/*
* KnownVarianceNormalPeriodPriorDistribution.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.epidemiology.casetocase.periodpriors;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.math.distributions.NormalDistribution;
import dr.xml.*;
import java.util.ArrayList;
import java.util.Arrays;
/**
The assumption here is that the periods are drawn from a normal distribution with unknown mean and variance.
The hyperprior is the conjugate, normal-gamma distribution.
@author Matthew Hall
*/
public class KnownVarianceNormalPeriodPriorDistribution extends AbstractPeriodPriorDistribution {
public static final String NORMAL = "knownVarianceNormalPeriodPriorDistribution";
public static final String LOG = "log";
public static final String ID = "id";
// This gets confusing. The data is assumed to be normally distributed with mean mu and stdev sigma. Sigma is
// known. The prior on mu is that it is _also_ normally distributed with mean mu_0 and stdev sigma_0.
public static final String MU_0 = "mu0";
public static final String SIGMA = "sigma";
public static final String SIGMA_0 = "sigma0";
private NormalDistribution hyperprior;
private Parameter posteriorMean;
private Parameter posteriorVariance;
private double sigma;
private ArrayList<Double> dataValues;
private double[] currentParameters;
public KnownVarianceNormalPeriodPriorDistribution(String name, boolean log, double sigma,
NormalDistribution hyperprior){
super(name, log);
this.hyperprior = hyperprior;
posteriorVariance = new Parameter.Default(1);
posteriorMean = new Parameter.Default(1);
addVariable(posteriorVariance);
addVariable(posteriorMean);
this.sigma = sigma;
}
public KnownVarianceNormalPeriodPriorDistribution(String name, boolean log, double sigma,
double mu_0, double sigma_0){
this(name, log, sigma, new NormalDistribution(mu_0, sigma_0));
}
public void reset(){
dataValues = new ArrayList<Double>();
currentParameters[0] = hyperprior.getMean();
currentParameters[1] = hyperprior.getSD();
logL = 0;
}
public double calculateLogPosteriorProbability(double newValue, double minValue){
double out = calculateLogPosteriorPredictiveProbability(newValue);
if(minValue != Double.NEGATIVE_INFINITY){
out -= calculateLogPosteriorPredictiveCDF(minValue, true);
}
logL += out;
update(newValue);
return out;
}
public double calculateLogPosteriorCDF(double limit, boolean upper) {
return calculateLogPosteriorPredictiveCDF(limit, upper);
}
public double calculateLogPosteriorPredictiveProbability(double value){
double mean = currentParameters[0];
double sd = currentParameters[1];
return NormalDistribution.logPdf(value, mean, Math.sqrt(Math.pow(sd, 2) + Math.pow(sigma, 2)));
}
public double calculateLogPosteriorPredictiveCDF(double value, boolean upperTail){
double mean = currentParameters[0];
double sd = currentParameters[1];
double scaledValue = (value - mean)/Math.sqrt(Math.pow(sd, 2) + Math.pow(sigma, 2));
return upperTail ? NormalDistribution.standardCDF(-scaledValue, true) :
NormalDistribution.standardCDF(scaledValue, true);
}
private void update(double newData){
dataValues.add(newData);
double originalMean = hyperprior.getMean();
double originalSD = hyperprior.getSD();
double count = dataValues.size();
double dataMean = 0;
for(double value: dataValues){
dataMean += value;
}
dataMean /= count;
double newSD = Math.sqrt(1/(count/Math.pow(sigma,2) + 1/Math.pow(originalSD,2)));
double newMean = Math.pow(newSD,2)*(originalMean/Math.pow(originalSD,2) + count*dataMean/Math.pow(sigma,2));
currentParameters = new double[]{newMean, newSD};
}
public double calculateLogLikelihood(double[] values){
int count = values.length;
double mu_0 = hyperprior.getMean();
double sigma_0 = hyperprior.getSD();
double var = Math.pow(sigma, 2);
double var_0 = Math.pow(sigma_0, 2);
double sum = 0;
double sumOfSquares = 0;
for (Double infPeriod : values) {
sum += infPeriod;
sumOfSquares += Math.pow(infPeriod, 2);
}
double mean = sum/count;
posteriorMean.setParameterValue(0, ((mu_0/var_0) + sum/var)/(1/var_0 + count/var));
posteriorVariance.setParameterValue(0, 1/(1/var_0 + count/var));
logL = Math.log(sigma)
- count * Math.log(Math.sqrt(2*Math.PI)*sigma)
- Math.log(Math.sqrt(count*var_0 + var))
+ -sumOfSquares/(2*var) - Math.pow(mu_0, 2)/(2*var_0)
+ (Math.pow(sigma_0*count*mean/sigma, 2) + Math.pow(sigma*mu_0/sigma_0, 2) + 2*count*mean*mu_0)
/(2*(count*var_0 + var));
return logL;
}
public LogColumn[] getColumns() {
ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.getColumns()));
columns.add(new LogColumn.Abstract(getModelName()+"_posteriorMean"){
protected String getFormattedValue() {
return String.valueOf(posteriorMean.getParameterValue(0));
}
});
columns.add(new LogColumn.Abstract(getModelName()+"_posteriorVariance"){
protected String getFormattedValue() {
return String.valueOf(posteriorVariance.getParameterValue(0));
}
});
return columns.toArray(new LogColumn[columns.size()]);
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return NORMAL;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
String id = (String) xo.getAttribute(ID);
double mu_0 = xo.getDoubleAttribute(MU_0);
double sigma = xo.getDoubleAttribute(SIGMA);
double sigma_0 = xo.getDoubleAttribute(SIGMA_0);
boolean log;
log = xo.hasAttribute(LOG) ? xo.getBooleanAttribute(LOG) : false;
return new KnownVarianceNormalPeriodPriorDistribution(id, log, sigma, mu_0, sigma_0);
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
AttributeRule.newBooleanRule(LOG, true),
AttributeRule.newStringRule(ID, false),
AttributeRule.newDoubleRule(MU_0, false),
AttributeRule.newDoubleRule(SIGMA, false),
AttributeRule.newDoubleRule(SIGMA_0, false),
};
public String getParserDescription() {
return "Calculates the probability of a set of doubles being drawn from the prior posterior distribution" +
"of a normal distribution of unknown mean and known standard deviation sigma";
}
public Class getReturnType() {
return KnownVarianceNormalPeriodPriorDistribution.class;
}
};
}