/*
* NormalPeriodPriorDistribution.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.epidemiology.casetocase.periodpriors;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.math.distributions.NormalDistribution;
import dr.math.distributions.NormalGammaDistribution;
import dr.math.functionEval.GammaFunction;
import dr.xml.*;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.TDistributionImpl;
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Period prior in which the periods are assumed to be drawn from a normal distribution with unknown mean and
 * variance. The hyperprior on (mean, precision) is the conjugate normal-gamma distribution, so the posterior
 * predictive distribution of a new period is a location-scale Student's t distribution with 2*alpha degrees of
 * freedom, location mu and scale sqrt(beta * (lambda + 1) / (alpha * lambda)). Once 2*alpha exceeds
 * {@link #normalApproximationThreshold} a normal distribution is used in place of the t.
 *
 * <p>Normal-gamma parameter arrays are read in the order {mu, lambda, alpha, beta}; this matches the argument
 * order passed to the {@link NormalGammaDistribution} constructor below.
 *
 * @author Matthew Hall
 */
public class NormalPeriodPriorDistribution extends AbstractPeriodPriorDistribution {

    public static final String NORMAL = "normalPeriodPriorDistribution";
    public static final String LOG = "log";
    public static final String ID = "id";
    public static final String MU = "mu";
    public static final String LAMBDA = "lambda";
    public static final String ALPHA = "alpha";
    public static final String BETA = "beta";

    // Positions of each normal-gamma parameter in the hyperprior and current-posterior arrays.
    private static final int MU_INDEX = 0;
    private static final int LAMBDA_INDEX = 1;
    private static final int ALPHA_INDEX = 2;
    private static final int BETA_INDEX = 3;

    private final NormalGammaDistribution hyperprior;

    // Summaries of the current posterior, registered as model variables and written to the log columns.
    private final Parameter posteriorMean;
    private final Parameter posteriorBeta;
    private final Parameter posteriorExpectedPrecision;

    /**
     * Degrees of freedom (2 * alpha) of the posterior predictive above which a normal distribution is used
     * instead of Student's t.
     */
    double normalApproximationThreshold = 30;

    // Values added one at a time through calculateLogPosteriorProbability since the last reset().
    private ArrayList<Double> dataValues;
    // Current posterior normal-gamma parameters {mu, lambda, alpha, beta}.
    private double[] currentParameters;

    /**
     * Creates the prior from an explicit normal-gamma hyperprior.
     *
     * @param name       model name (also used as a prefix for log column names)
     * @param log        passed through to the superclass
     * @param hyperprior conjugate normal-gamma hyperprior on the mean and precision
     */
    public NormalPeriodPriorDistribution(String name, boolean log, NormalGammaDistribution hyperprior) {
        super(name, log);
        this.hyperprior = hyperprior;

        posteriorBeta = new Parameter.Default(1);
        posteriorMean = new Parameter.Default(1);
        posteriorExpectedPrecision = new Parameter.Default(1);

        addVariable(posteriorBeta);
        addVariable(posteriorMean);
        addVariable(posteriorExpectedPrecision);

        // Initialise the sequential-update state so that every constructor leaves the object usable;
        // without this, dataValues and currentParameters would be null.
        reset();
    }

    /**
     * Creates the prior from the four normal-gamma hyperparameters.
     */
    public NormalPeriodPriorDistribution(String name, boolean log, double mu_0, double lambda_0,
                                         double alpha_0, double beta_0) {
        this(name, log, new NormalGammaDistribution(mu_0, lambda_0, alpha_0, beta_0));
    }

    /**
     * Discards all sequentially added data, restores the current posterior to the hyperprior and sets the
     * running log likelihood {@code logL} to zero.
     */
    public void reset() {
        dataValues = new ArrayList<Double>();
        currentParameters = hyperprior.getParameters();
        logL = 0;
    }

    /**
     * Returns the log posterior predictive density of {@code newValue}, optionally truncated below at
     * {@code minValue}. The result is added to the running total {@code logL}, and the posterior is then
     * updated to condition on {@code newValue}.
     *
     * @param newValue the newly observed value
     * @param minValue lower truncation point, or {@code Double.NEGATIVE_INFINITY} for no truncation
     * @return the log (possibly truncated) posterior predictive density of {@code newValue}
     */
    public double calculateLogPosteriorProbability(double newValue, double minValue) {
        double out = calculateLogPosteriorPredictiveProbability(newValue);
        if (minValue != Double.NEGATIVE_INFINITY) {
            // Truncated density: renormalise by P(X > minValue) under the current predictive distribution.
            out -= calculateLogPosteriorPredictiveCDF(minValue, true);
        }
        logL += out;
        update(newValue);
        return out;
    }

    /**
     * Returns the log of the posterior predictive tail probability at {@code limit}.
     *
     * @param limit the evaluation point
     * @param upper {@code true} for P(X > limit), {@code false} for P(X <= limit)
     */
    public double calculateLogPosteriorCDF(double limit, boolean upper) {
        return calculateLogPosteriorPredictiveCDF(limit, upper);
    }

    /**
     * Returns the log density of {@code value} under the current posterior predictive distribution.
     *
     * @param value the point at which to evaluate the density
     * @return the log predictive density, including the 1/scale Jacobian of the location-scale family
     */
    public double calculateLogPosteriorPredictiveProbability(double value) {
        double scale = predictiveScale();
        double scaledValue = (value - currentParameters[MU_INDEX]) / scale;
        double degreesOfFreedom = predictiveDegreesOfFreedom();

        double logStandardDensity;
        if (degreesOfFreedom <= normalApproximationThreshold) {
            TDistributionImpl tDist = new TDistributionImpl(degreesOfFreedom);
            logStandardDensity = Math.log(tDist.density(scaledValue));
        } else {
            logStandardDensity = NormalDistribution.logPdf(scaledValue, 0, 1);
        }
        // x = mu + scale * t, so f_X(x) = f_T((x - mu) / scale) / scale.
        return logStandardDensity - Math.log(scale);
    }

    /**
     * Returns the log of a tail probability of the current posterior predictive distribution.
     *
     * @param value     the evaluation point
     * @param upperTail {@code true} for log P(X > value), {@code false} for log P(X <= value)
     * @return the log tail probability
     * @throws RuntimeException wrapping any {@link MathException} raised by the t distribution
     */
    public double calculateLogPosteriorPredictiveCDF(double value, boolean upperTail) {
        double scaledValue = (value - currentParameters[MU_INDEX]) / predictiveScale();
        // The standardised distribution is symmetric, so P(X > value) = P(T < -scaledValue).
        double point = upperTail ? -scaledValue : scaledValue;
        double degreesOfFreedom = predictiveDegreesOfFreedom();

        if (degreesOfFreedom <= normalApproximationThreshold) {
            TDistributionImpl tDist = new TDistributionImpl(degreesOfFreedom);
            try {
                return Math.log(tDist.cumulativeProbability(point));
            } catch (MathException e) {
                // Keep the original exception as the cause rather than flattening it to a string.
                throw new RuntimeException(e);
            }
        }
        // NOTE(review): the `true` argument presumably requests a log-scale CDF, matching the t branch above —
        // confirm against dr.math.distributions.NormalDistribution.
        return NormalDistribution.standardCDF(point, true);
    }

    /** Returns the scale of the current posterior predictive: sqrt(beta * (lambda + 1) / (alpha * lambda)). */
    private double predictiveScale() {
        double lambda = currentParameters[LAMBDA_INDEX];
        return Math.sqrt(currentParameters[BETA_INDEX] * (lambda + 1)
                / (currentParameters[ALPHA_INDEX] * lambda));
    }

    /** Returns the degrees of freedom (2 * alpha) of the current posterior predictive. */
    private double predictiveDegreesOfFreedom() {
        return 2 * currentParameters[ALPHA_INDEX];
    }

    /**
     * Conditions the current normal-gamma posterior on one more observation and refreshes the logged
     * posterior summaries.
     */
    private void update(double newData) {
        dataValues.add(newData);
        double lambda_0 = hyperprior.getParameters()[LAMBDA_INDEX];

        double oldMu = currentParameters[MU_INDEX];
        double oldLambda = currentParameters[LAMBDA_INDEX];
        double oldAlpha = currentParameters[ALPHA_INDEX];
        double oldBeta = currentParameters[BETA_INDEX];
        double count = dataValues.size();

        double deviation = newData - oldMu;
        // lambda_0 + count equals oldLambda + 1 after `count` single-observation updates.
        double newMu = deviation / (lambda_0 + count) + oldMu;
        double newLambda = oldLambda + 1;
        double newAlpha = oldAlpha + 0.5;
        double newBeta = oldBeta + oldLambda * Math.pow(deviation, 2) / (2 * (oldLambda + 1));

        posteriorMean.setParameterValue(0, newMu);
        posteriorBeta.setParameterValue(0, newBeta);
        posteriorExpectedPrecision.setParameterValue(0, newAlpha / newBeta);
        currentParameters = new double[]{newMu, newLambda, newAlpha, newBeta};
    }

    /**
     * Computes, in one batch, the log marginal likelihood of {@code values}. The model is normal, with mean and
     * precision integrated out against the normal-gamma hyperprior. It also updates the logged posterior
     * summaries and stores the result in {@code logL}. The sequential state (see {@link #reset()}) is not
     * touched.
     *
     * @param values the observed values; an empty array yields 0
     * @return the log marginal likelihood
     */
    public double calculateLogLikelihood(double[] values) {
        int count = values.length;

        double[] prior = hyperprior.getParameters();
        double mu_0 = prior[MU_INDEX];
        double lambda_0 = prior[LAMBDA_INDEX];
        double alpha_0 = prior[ALPHA_INDEX];
        double beta_0 = prior[BETA_INDEX];

        double lambda_n = lambda_0 + count;
        // Floating-point division: integer division would drop half a unit of alpha for odd sample sizes.
        double alpha_n = alpha_0 + count / 2.0;

        double sum = 0;
        for (double value : values) {
            sum += value;
        }

        double sumOfSquaredDeviations = 0;
        double priorMeanTerm = 0;
        if (count > 0) {
            // Guarded so that an empty sample does not produce 0/0 = NaN for the mean.
            double mean = sum / count;
            for (double value : values) {
                sumOfSquaredDeviations += Math.pow(value - mean, 2);
            }
            priorMeanTerm = lambda_0 * count * Math.pow(mean - mu_0, 2) / (2 * (lambda_0 + count));
        }

        posteriorMean.setParameterValue(0, (lambda_0 * mu_0 + sum) / lambda_n);
        double beta_n = beta_0 + 0.5 * sumOfSquaredDeviations + priorMeanTerm;
        posteriorBeta.setParameterValue(0, beta_n);
        posteriorExpectedPrecision.setParameterValue(0, alpha_n / beta_n);

        logL = GammaFunction.logGamma(alpha_n)
                - GammaFunction.logGamma(alpha_0)
                + alpha_0 * Math.log(beta_0)
                - alpha_n * Math.log(beta_n)
                + 0.5 * Math.log(lambda_0)
                - 0.5 * Math.log(lambda_n)
                - (count / 2.0) * Math.log(2 * Math.PI);
        return logL;
    }

    /**
     * Returns the superclass log columns plus the posterior mean, posterior beta and posterior expected
     * precision.
     */
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.getColumns()));

        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorMean") {
            protected String getFormattedValue() {
                return String.valueOf(posteriorMean.getParameterValue(0));
            }
        });

        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorBeta") {
            protected String getFormattedValue() {
                return String.valueOf(posteriorBeta.getParameterValue(0));
            }
        });

        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorExpectedPrecision") {
            protected String getFormattedValue() {
                return String.valueOf(posteriorExpectedPrecision.getParameterValue(0));
            }
        });

        return columns.toArray(new LogColumn[columns.size()]);
    }

    /** XML parser for {@code <normalPeriodPriorDistribution id=".." mu=".." lambda=".." alpha=".." beta=".."/>}. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String getParserName() {
            return NORMAL;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String id = (String) xo.getAttribute(ID);
            double mu = xo.getDoubleAttribute(MU);
            double lambda = xo.getDoubleAttribute(LAMBDA);
            double alpha = xo.getDoubleAttribute(ALPHA);
            double beta = xo.getDoubleAttribute(BETA);
            // The "log" attribute is optional and defaults to false.
            boolean log = xo.hasAttribute(LOG) ? xo.getBooleanAttribute(LOG) : false;

            return new NormalPeriodPriorDistribution(id, log, mu, lambda, alpha, beta);
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newBooleanRule(LOG, true),
                AttributeRule.newStringRule(ID, false),
                AttributeRule.newDoubleRule(MU, false),
                AttributeRule.newDoubleRule(LAMBDA, false),
                AttributeRule.newDoubleRule(ALPHA, false),
                AttributeRule.newDoubleRule(BETA, false)
        };

        public String getParserDescription() {
            return "Calculates the probability of a set of doubles being drawn from the prior posterior distribution " +
                    "of a normal distribution of unknown mean and variance";
        }

        public Class getReturnType() {
            return NormalPeriodPriorDistribution.class;
        }
    };
}