/*
* MarginalLikelihoodAnalysis.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.trace;
import dr.math.LogTricks;
import dr.math.MathUtils;
import dr.util.TaskListener;
import java.util.ArrayList;
import java.util.List;
/**
* @author Marc Suchard
* @author Alexei Drummond
* <p/>
* Source translated from model_P.c (a component of BAli-Phy by Benjamin Redelings and Marc Suchard)
*/
public class MarginalLikelihoodAnalysis {

    /** Analysis type string selecting the AICM calculation. */
    private static final String TYPE_AICM = "aicm";
    /** Analysis type string selecting the smoothed harmonic mean estimator. */
    private static final String TYPE_SMOOTHED = "smoothed";
    /** Analysis type string selecting the arithmetic mean estimator. */
    private static final String TYPE_ARITHMETIC = "arithmetic";

    /** Proportion of pseudo-samples from the prior used by the smoothed estimator. */
    private static final double SMOOTHING_DELTA = 0.01; // todo make class adjustable by accessor/setter
    /** The smoothed estimator stops iterating once a step is no larger than this. */
    private static final double SMOOTHING_TOLERANCE = 1E-3; // todo make class adjustable by accessor/setter
    /** The smoothed estimator gives up after more than this many iterations. */
    private static final int SMOOTHING_MAX_ITERATIONS = 400; // todo make class adjustable by accessor/setter

    private final String traceName;
    private final List<Double> sample;
    private final int burnin;
    /**
     * "smoothed" for smoothed harmonic mean, "aicm" for AICM, "arithmetic" for arithmetic mean;
     * any other value (conventionally "harmonic") selects the harmonic mean.
     */
    private final String analysisType;
    private final int bootstrapLength;

    /** Set once {@link #calculate()} has run, so the getters can compute lazily. */
    private boolean marginalLikelihoodCalculated = false;
    private double logMarginalLikelihood;
    private double bootstrappedSE;

    /** Optional receiver of progress notifications; {@code null} when none is registered. */
    private TaskListener listener = null;

    /**
     * Constructor
     *
     * @param sample          the values the estimate is calculated from; the list is stored as given, not copied
     * @param traceName       used for 'toString' display purposes only
     * @param burnin          used for 'toString' display purposes only
     * @param analysisType    "smoothed", "aicm" or "arithmetic"; any other value selects the harmonic mean
     * @param bootstrapLength number of bootstrap replicates; bootstrapping only runs when this is greater than one
     */
    public MarginalLikelihoodAnalysis(List<Double> sample, String traceName, int burnin, String analysisType, int bootstrapLength) {
        this.sample = sample;
        this.traceName = traceName;
        this.burnin = burnin;
        this.analysisType = analysisType;
        this.bootstrapLength = bootstrapLength;
    }

    /**
     * @return the trace name supplied at construction
     */
    public String getTraceName() {
        return traceName;
    }

    /**
     * @return the burnin supplied at construction
     */
    public int getBurnin() {
        return burnin;
    }

    /**
     * Registers the listener that is notified of progress while {@link #calculate()} runs.
     *
     * @param listener the listener to notify, or {@code null} to stop notifications
     */
    public void setTaskListener(TaskListener listener) {
        this.listener = listener;
    }

    /**
     * Applies the estimator selected by the analysis type to the given sample.
     *
     * @param sample a posterior sample of logLikelihoods
     * @return the value produced by the selected estimator (the AICM when the analysis type is "aicm")
     */
    public double calculateLogMarginalLikelihood(List<Double> sample) {
        // Constant-first comparisons: a null analysis type falls through to the harmonic
        // mean default, exactly like any other unrecognised string.
        if (TYPE_AICM.equals(analysisType)) {
            return logMarginalLikelihoodAICM(sample);
        } else if (TYPE_SMOOTHED.equals(analysisType)) {
            return logMarginalLikelihoodSmoothed(sample);
        } else if (TYPE_ARITHMETIC.equals(analysisType)) {
            return logMarginalLikelihoodArithmetic(sample);
        } else {
            return logMarginalLikelihoodHarmonic(sample);
        }
    }

    /**
     * Calculates the log marginal likelihood of a model using the arithmetic mean estimator.
     * NaN and infinite entries are ignored, both in the sum and in the sample count.
     *
     * @param v a posterior sample of logLikelihoods
     * @return the log marginal likelihood, or NaN when {@code v} holds no finite value
     */
    public double logMarginalLikelihoodArithmetic(List<Double> v) {
        // Count usable entries separately from the loop bound, so that skipping an
        // invalid entry never shortens the iteration over the remaining samples.
        int validCount = 0;
        double sum = LogTricks.logZero;
        for (double logLikelihood : v) {
            if (Double.isNaN(logLikelihood) || Double.isInfinite(logLikelihood)) {
                continue;
            }
            sum = LogTricks.logSum(sum, logLikelihood);
            validCount++;
        }
        if (validCount == 0) {
            // The mean of zero samples is undefined.
            return Double.NaN;
        }
        return sum - StrictMath.log(validCount);
    }

    /**
     * Calculates the log marginal likelihood of a model using Newton and Raftery's harmonic mean estimator
     *
     * @param v a posterior sample of logLikelihoods
     * @return the log marginal likelihood
     */
    public double logMarginalLikelihoodHarmonic(List<Double> v) {
        final int size = v.size();

        double sum = 0;
        for (int i = 0; i < size; i++) {
            sum += v.get(i);
        }

        // Every term is shifted by 'sum' before being accumulated in log space; the same
        // shift is added back in the return expression, so it cancels algebraically.
        double denominator = LogTricks.logZero;
        for (int i = 0; i < size; i++) {
            denominator = LogTricks.logSum(denominator, sum - v.get(i));
        }

        return sum - denominator + StrictMath.log(size);
    }

    /**
     * Calculates the AICM of a model using method-of-moments from Raftery et al. (2007)
     *
     * @param v a posterior sample of logLikelihoods
     * @return the AICM (lower values are better)
     */
    public double logMarginalLikelihoodAICM(List<Double> v) {
        final int size = v.size();

        double sum = 0;
        for (int i = 0; i < size; i++) {
            sum += v.get(i);
        }
        final double mean = sum / (double) size;

        // Sample variance, using the (size - 1) denominator.
        double var = 0;
        for (int i = 0; i < size; i++) {
            final double deviation = v.get(i) - mean;
            var += deviation * deviation;
        }
        var /= (double) size - 1;

        return 2 * var - 2 * mean;
    }

    /**
     * Computes the estimate for the stored sample and, when more than one bootstrap
     * replicate was requested, the standard deviation of the estimate across replicates.
     * Progress is reported to the registered task listener, if any, before each replicate
     * and once more with 1.0 when the calculation has finished.
     */
    public void calculate() {

        logMarginalLikelihood = calculateLogMarginalLikelihood(sample);

        if (bootstrapLength > 1) {
            final int sampleLength = sample.size();
            // Scratch list reused for every replicate.
            final List<Double> bsSample = new ArrayList<Double>(sampleLength);
            final double[] bootstrappedLogML = new double[bootstrapLength];

            double sum = 0;
            double progress = 0.0;
            final double delta = 1.0 / bootstrapLength;

            for (int i = 0; i < bootstrapLength; i++) {
                fireProgress(progress);
                progress += delta;

                // NOTE(review): presumably sampleLength indices drawn with replacement from
                // [0, sampleLength) -- confirm against MathUtils.
                final int[] indices = MathUtils.sampleIndicesWithReplacement(sampleLength);
                for (int k = 0; k < sampleLength; k++) {
                    bsSample.add(sample.get(indices[k]));
                }

                bootstrappedLogML[i] = calculateLogMarginalLikelihood(bsSample);
                sum += bootstrappedLogML[i];
                bsSample.clear();
            }

            final double bootstrappedAverage = sum / bootstrapLength;

            // Summarize bootstrappedLogML: sample variance over the replicates
            double var = 0;
            for (int i = 0; i < bootstrapLength; i++) {
                final double deviation = bootstrappedLogML[i] - bootstrappedAverage;
                var += deviation * deviation;
            }
            var /= (bootstrapLength - 1.0);
            bootstrappedSE = Math.sqrt(var);
        }

        fireProgress(1.0);
        marginalLikelihoodCalculated = true;
    }

    /**
     * Calculates the log marginal likelihood of a model using Newton and Raftery's smoothed estimator
     *
     * @param v     a posterior sample of logLikelihood
     * @param delta proportion of pseudo-samples from the prior
     * @param Pdata current estimate of the log marginal likelihood
     * @return the log marginal likelihood
     */
    @SuppressWarnings({"SuspiciousNameCombination"})
    public double logMarginalLikelihoodSmoothed(List<Double> v, double delta, double Pdata) {

        final double logDelta = StrictMath.log(delta);
        final double logInvDelta = StrictMath.log(1.0 - delta);

        final int n = v.size();
        final double logN = StrictMath.log(n);

        final double offset = logInvDelta - Pdata;

        // Numerator and denominator are accumulated in log space.
        double bottom = logN + logDelta - logInvDelta;
        double top = bottom + Pdata;

        for (int i = 0; i < n; i++) {
            final double logLikelihood = v.get(i);
            final double weight = -LogTricks.logSum(logDelta, offset + logLikelihood);
            top = LogTricks.logSum(top, weight + logLikelihood);
            bottom = LogTricks.logSum(bottom, weight);
        }

        return top - bottom;
    }

    /**
     * Calculates the log marginal likelihood of a model using Newton and Raftery's smoothed
     * estimator, iterating from the harmonic mean estimate until a step is no larger than
     * the tolerance.
     *
     * @param v a posterior sample of logLikelihoods
     * @return the log marginal likelihood, or {@code LogTricks.logZero} when the iteration
     *         has not converged within the iteration limit
     */
    public double logMarginalLikelihoodSmoothed(List<Double> v) {

        final double delta = SMOOTHING_DELTA;

        // Start with harmonic estimator as first guess
        double Pdata = logMarginalLikelihoodHarmonic(v);

        double deltaP = 1.0;
        int iterations = 0;
        double dx;

        while (Math.abs(deltaP) > SMOOTHING_TOLERANCE) {
            // g(x) = estimate(x) - x; the loop looks for a point where g vanishes.
            double g1 = logMarginalLikelihoodSmoothed(v, delta, Pdata) - Pdata;
            double Pdata2 = Pdata + g1;

            dx = g1 * 10.0;
            double g2 = logMarginalLikelihoodSmoothed(v, delta, Pdata + dx) - (Pdata + dx);
            double dgdx = (g2 - g1) / dx; // find derivative at Pdata

            double Pdata3 = Pdata - g1 / dgdx; // find new evaluation point
            if (Pdata3 < 2.0 * Pdata || Pdata3 > 0 || Pdata3 > 0.5 * Pdata) // step is too large
                Pdata3 = Pdata + 10.0 * g1;

            double g3 = logMarginalLikelihoodSmoothed(v, delta, Pdata3) - Pdata3;

            // Try to do a Newton's method step
            if (Math.abs(g3) <= Math.abs(g2) && ((g3 > 0) || (Math.abs(dgdx) > 0.01))) {
                deltaP = Pdata3 - Pdata;
                Pdata = Pdata3;
            } // otherwise try to go 10 times as far as one step
            else if (Math.abs(g2) <= Math.abs(g1)) {
                Pdata2 += g2;
                deltaP = Pdata2 - Pdata;
                Pdata = Pdata2;
            } // otherwise go just one step
            else {
                deltaP = g1;
                Pdata += g1;
            }

            iterations++;

            if (iterations > SMOOTHING_MAX_ITERATIONS) {
                System.err.println("Probabilities are not converging!!!"); // todo should throw exception
                return LogTricks.logZero;
            }
        }
        return Pdata;
    }

    /**
     * Returns the estimate for the stored sample, running {@link #calculate()} first if it
     * has not run yet.
     *
     * @return the value produced by the selected estimator
     */
    public double getLogMarginalLikelihood() {
        if (!marginalLikelihoodCalculated) {
            calculate();
        }
        return logMarginalLikelihood;
    }

    /**
     * Returns the bootstrap standard error, running {@link #calculate()} first if it has not
     * run yet.
     *
     * @return the standard deviation of the estimate across bootstrap replicates; the field
     *         keeps its default of zero when bootstrapping did not run
     */
    public double getBootstrappedSE() {
        if (!marginalLikelihoodCalculated) {
            calculate();
        }
        return bootstrappedSE;
    }

    /**
     * Builds a one-line summary: estimator name, trace name, estimate, bootstrap error and
     * replicate count (when bootstrapping is enabled), and burnin. Triggers the calculation
     * if it has not run yet.
     *
     * @return the summary line
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();

        if (TYPE_SMOOTHED.equals(analysisType)) {
            sb.append("log marginal likelihood (using smoothed harmonic mean)");
        } else if (TYPE_AICM.equals(analysisType)) {
            sb.append("AICM");
        } else if (TYPE_ARITHMETIC.equals(analysisType)) {
            sb.append("log marginal likelihood (using arithmetic mean)");
        } else {
            sb.append("log marginal likelihood (using harmonic mean)");
        }

        sb.append(" from ")
                .append(traceName)
                .append(" = ")
                .append(String.format("%5.4f", getLogMarginalLikelihood()));

        if (bootstrapLength > 1) {
            sb.append(" +/- ")
                    .append(String.format("%5.4f", getBootstrappedSE()));
        } else {
            sb.append("           ");
        }

        sb.append(" burnin=").append(burnin);
        if (bootstrapLength > 1) {
            sb.append(" replicates=").append(bootstrapLength);
        }

        return sb.toString();
    }

    /**
     * Forwards a progress value to the registered listener, if there is one.
     *
     * @param progress the progress value to report; {@link #calculate()} passes values from 0.0 up to 1.0
     */
    private void fireProgress(double progress) {
        if (listener != null) {
            listener.progress(progress);
        }
    }
}