/*
* MarginalLikelihoodEstimator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.mcmc;
import com.sun.tools.javac.util.Options;
import dr.inference.loggers.Logger;
import dr.inference.loggers.MCLogger;
import dr.inference.markovchain.MarkovChain;
import dr.inference.markovchain.MarkovChainListener;
import dr.inference.model.Model;
import dr.inference.model.PathLikelihood;
import dr.inference.operators.*;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.Identifiable;
import dr.xml.*;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.BetaDistributionImpl;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* @author Andrew Rambaut
* @author Alex Alekseyenko
* @author Marc Suchard
* @author Guy Baele
*/
public class MarginalLikelihoodEstimator implements Runnable, Identifiable, Citable {
/**
 * Creates an estimator and the Markov chain it drives.
 *
 * @param id             identifier reported by {@link #getId()}
 * @param chainLength    number of iterations passed to the chain for each path step after burn-in
 * @param burninLength   burn-in iterations per path step; -1 selects 10% of {@code chainLength}
 * @param pathSteps      number of path steps handed to the step-based integrators
 * @param fixedRunValues path parameter values used by the FIXED scheme (may be null for other schemes)
 * @param scheme         path scheme that selects the integrator in {@link #run()}
 * @param pathLikelihood likelihood whose path parameter is varied
 * @param schedule       operator schedule used by the chain
 * @param loggers        loggers started in {@link #run()} and written to by the chain listener
 */
public MarginalLikelihoodEstimator(String id, int chainLength, int burninLength, int pathSteps, double[] fixedRunValues,
                                   PathScheme scheme,
                                   PathLikelihood pathLikelihood,
                                   OperatorSchedule schedule,
                                   List<MCLogger> loggers) {
    // identification and run lengths
    this.id = id;
    this.chainLength = chainLength;
    this.burninLength = burninLength;

    // path configuration
    this.scheme = scheme;
    this.pathSteps = pathSteps;
    this.fixedRunValues = fixedRunValues;
    this.pathDelta = 1.0 / pathSteps;
    this.pathParameter = 1.0;

    // collaborators
    this.schedule = schedule;
    this.loggers = loggers;
    this.pathLikelihood = pathLikelihood;

    // the path parameter is set on the likelihood before the chain is constructed
    pathLikelihood.setPathParameter(this.pathParameter);
    this.mc = new MarkovChain(pathLikelihood, schedule, new MCMCCriterion(), 0, 0, 0.0, true);
}
/**
 * Resolves the effective burn-in: a requested length of -1 means
 * "use 10% of the chain length", anything else is taken as is.
 */
private void setDefaultBurnin() {
    final boolean useDefault = (burninLength == -1);
    burnin = useDefault ? (int) (0.1 * chainLength) : burninLength;
}
/**
 * Runs the chain once for every path parameter produced by the given integrator.
 * <p>
 * For each value the path likelihood and every {@link PathDependentOperator} in the
 * schedule are updated, a burn-in run is performed with the chain length counter
 * temporarily set to 0, the counter is restored and the sampling run follows.
 * The loop ends as soon as the integrator returns a negative value.
 *
 * @param scheme the integrator that yields successive path parameters
 */
public void integrate(Integrator scheme) {
    setDefaultBurnin();
    mc.setCurrentLength(burnin);
    scheme.init();

    // the schedule is expected to be a CombinedOperatorSchedule; the cast fails otherwise
    final CombinedOperatorSchedule combinedSchedule = (CombinedOperatorSchedule) schedule;
    combinedSchedule.reset();

    while ((pathParameter = scheme.nextPathParameter()) >= 0) {
        pathLikelihood.setPathParameter(pathParameter);
        reportIteration(pathParameter, chainLength, burnin, scheme.pathSteps, scheme.step);

        // push the new path parameter to the operators that depend on it
        for (int index = 0; index < schedule.getOperatorCount(); ++index) {
            final MCMCOperator candidate = schedule.getOperator(index);
            if (!(candidate instanceof PathDependentOperator)) {
                continue;
            }
            ((PathDependentOperator) candidate).setPathParameter(pathParameter);
        }

        // burn-in runs from length 0; the previous length is put back before sampling
        final long savedLength = mc.getCurrentLength();
        mc.setCurrentLength(0);
        mc.runChain(burnin, false);
        mc.setCurrentLength(savedLength);
        mc.runChain(chainLength, false);

        if (SHOW_OPERATOR_ANALYSIS) {
            OperatorAnalysisPrinter.showOperatorAnalysis(System.out, schedule, false);
        }
        combinedSchedule.reset();
    }
}
/**
 * Base class for path schemes: a stateful generator of path parameters.
 * Subclasses advance {@code step} in {@link #nextPathParameter()}; the
 * integrators in this file return a negative value once exhausted.
 */
public abstract class Integrator {

    /** Total number of path steps this integrator was configured with. */
    protected int pathSteps;

    /** Position within the path; reset to 0 by {@link #init()}. */
    protected int step;

    /**
     * @param pathSteps total number of path steps
     */
    protected Integrator(int pathSteps) {
        this.pathSteps = pathSteps;
    }

    /**
     * @return the next path parameter of the scheme
     */
    abstract double nextPathParameter();

    /** Rewinds the integrator to its first step. */
    public void init() {
        this.step = 0;
    }
}
/**
 * Integrator that replays a caller-supplied list of path parameters in order.
 */
public class FixedThetaRun extends Integrator {

    private double[] value;

    /**
     * @param value the path parameters to visit; its length becomes the number of path steps
     */
    public FixedThetaRun(double[] value) {
        super(value.length);
        this.value = value;
    }

    /**
     * @return the next supplied value, or -1.0 once all values have been returned
     */
    double nextPathParameter() {
        if (step >= value.length) {
            return -1.0;
        }
        final double theta = value[step];
        step++;
        return theta;
    }
}
/**
 * Integrator that moves from 1.0 down to 0.0 in equally sized steps,
 * producing {@code pathSteps + 1} values.
 */
public class LinearIntegrator extends Integrator {

    /**
     * @param pathSteps number of equal intervals between 1.0 and 0.0
     */
    public LinearIntegrator(int pathSteps) {
        super(pathSteps);
    }

    /**
     * @return {@code 1 - step / pathSteps} for the current step, or -1 once past the last step
     */
    double nextPathParameter() {
        if (step <= pathSteps) {
            final double theta = 1.0 - (double) step / (double) pathSteps;
            step += 1;
            return theta;
        }
        return -1;
    }
}
/**
 * Integrator whose interior path parameters follow a sigmoid of the remaining
 * path fraction; the first value is exactly 1.0 and the value at
 * {@code step == pathSteps} is exactly 0.0.
 */
public class SigmoidIntegrator extends Integrator {

    private double alpha;

    /**
     * @param alpha     multiplier applied to the centred path fraction inside the exponentials
     * @param pathSteps number of path steps
     */
    public SigmoidIntegrator(double alpha, int pathSteps) {
        super(pathSteps);
        this.alpha = alpha;
    }

    /**
     * @return the next path parameter, or -1.0 once past the last step
     */
    double nextPathParameter() {
        // end points are returned exactly rather than computed
        if (step == 0) {
            step++;
            return 1.0;
        }
        if (step == pathSteps) {
            step++;
            return 0.0;
        }
        if (step > pathSteps) {
            return -1.0;
        }

        // remaining fraction of the path, shifted so that it is centred on 0
        final double centred = ((pathSteps - step) / ((double) pathSteps)) - 0.5;
        step++;
        final double rising = Math.exp(alpha * centred);
        final double falling = Math.exp(-alpha * centred);
        return rising / (rising + falling);
    }
}
/**
 * Integrator that raises the remaining path fraction to the power {@code 1 / alpha}.
 */
public class BetaQuantileIntegrator extends Integrator {

    private double alpha;

    /**
     * @param alpha     the exponent applied to the remaining fraction is {@code 1.0 / alpha}
     * @param pathSteps number of path steps
     */
    public BetaQuantileIntegrator(double alpha, int pathSteps) {
        super(pathSteps);
        this.alpha = alpha;
    }

    /**
     * @return {@code ((pathSteps - step) / pathSteps) ^ (1 / alpha)}, or -1 once past the last step
     */
    double nextPathParameter() {
        if (step > pathSteps) {
            return -1;
        }
        final double remainingFraction = (pathSteps - step) / ((double) pathSteps);
        final double theta = Math.pow(remainingFraction, 1.0 / alpha);
        step++;
        return theta;
    }
}
/**
 * Integrator that places path parameters at quantiles of a Beta distribution.
 * <p>
 * The first value is 1.0; for {@code 0 < step < pathSteps - 1} the value is
 * {@code 1 - Q(step / (pathSteps - 1))}, where Q is the inverse CDF of the Beta
 * distribution; the calls made at {@code step == pathSteps - 1} and
 * {@code step == pathSteps} both return 0.0; after that -1 is returned.
 */
public class BetaIntegrator extends Integrator {

    private BetaDistributionImpl betaDistribution;

    /**
     * @param alpha     first shape parameter of the Beta distribution
     * @param beta      second shape parameter of the Beta distribution
     * @param pathSteps number of path steps
     */
    public BetaIntegrator(double alpha, double beta, int pathSteps) {
        super(pathSteps);
        this.betaDistribution = new BetaDistributionImpl(alpha, beta);
    }

    /**
     * @return the next path parameter, or -1 once past the last step
     * @throws IllegalStateException if the Beta quantile cannot be computed
     */
    double nextPathParameter() {
        if (step > pathSteps) {
            return -1;
        }

        // Advance exactly once per call. The previous code incremented the counter a
        // second time when the quantile computation failed, silently skipping a step.
        final int current = step;
        step += 1;

        if (current == 0) {
            return 1.0;
        }
        if (current + 1 < pathSteps) {
            final double ratio = (double) current / (double) (pathSteps - 1);
            try {
                return 1.0 - betaDistribution.inverseCumulativeProbability(ratio);
            } catch (MathException e) {
                // Falling back to 0.0 here would run the chain at the wrong path parameter
                // without the caller noticing, so the failure is surfaced instead.
                throw new IllegalStateException(
                        "Unable to compute Beta quantile for ratio " + ratio
                                + " at path step " + current + " of " + pathSteps, e);
            }
        }
        return 0.0;
    }
}
/**
 * Integrator that halves the path parameter at every step (1, 1/2, 1/4, ...)
 * and returns exactly 0 at {@code step == pathSteps}.
 */
public class GeometricIntegrator extends Integrator {

    /**
     * @param pathSteps number of path steps
     */
    public GeometricIntegrator(int pathSteps) {
        super(pathSteps);
    }

    /**
     * @return {@code 2^-step} for steps before {@code pathSteps}, 0 at {@code pathSteps},
     * and -1 afterwards
     */
    double nextPathParameter() {
        if (step > pathSteps) {
            return -1;
        }
        final int exponent = step;
        step += 1;
        // the final step (pathSteps, not pathSteps - 1) is pinned to zero
        return (exponent == pathSteps) ? 0 : Math.pow(2, -exponent);
    }
}
/*public void linearIntegration() {
setDefaultBurnin();
mc.setCurrentLength(0);
for (int step = 0; step < pathSteps; step++) {
pathLikelihood.setPathParameter(pathParameter);
reportIteration(pathParameter, chainLength, burnin);
//mc.runChain(chainLength + burnin, false, 0);
mc.runChain(chainLength + burnin, false);
pathParameter -= pathDelta;
}
pathLikelihood.setPathParameter(0.0);
reportIteration(pathParameter, chainLength, burnin);
//mc.runChain(chainLength + burnin, false, 0);
mc.runChain(chainLength + burnin, false);
}*/
/*public void betaIntegration(double alpha, double beta) {
setDefaultBurnin();
mc.setCurrentLength(0);
BetaDistributionImpl betaDistribution = new BetaDistributionImpl(alpha, beta);
for (int step = 0; step < pathSteps; step++) {
if (step == 0) {
pathParameter = 1.0;
} else if (step + 1 < pathSteps) {
double ratio = (double) step / (double) (pathSteps - 1);
try {
pathParameter = 1.0 - betaDistribution.inverseCumulativeProbability(ratio);
} catch (MathException e) {
e.printStackTrace();
}
} else {
pathParameter = 0.0;
}
pathLikelihood.setPathParameter(pathParameter);
reportIteration(pathParameter, chainLength, burnin);
//mc.runChain(chainLength + burnin, false, 0);
mc.runChain(chainLength + burnin, false);
(new OperatorAnalysisPrinter(schedule)).showOperatorAnalysis(System.out);
((CombinedOperatorSchedule) schedule).reset();
}
}*/
/**
 * Prints a one-line progress message for the path step about to be run.
 * For the FIXED scheme the denominator is {@code totalSteps}; for every other
 * scheme it is {@code totalSteps + 1}.
 *
 * @param pathParameter the path parameter about to be used
 * @param chainLength   iterations to be run after burn-in
 * @param burnin        burn-in iterations
 * @param totalSteps    number of path steps of the integrator
 * @param steps         current step counter of the integrator
 */
private void reportIteration(double pathParameter, long chainLength, long burnin, long totalSteps, long steps) {
    final boolean fixedScheme = (scheme == PathScheme.FIXED);
    final String prefix = fixedScheme ? "Attempting fixed theta (" : "Attempting theta (";
    final long denominator = fixedScheme ? totalSteps : totalSteps + 1;
    System.out.println(prefix + steps + "/" + denominator + ") = " + pathParameter
            + " for " + chainLength + " iterations + " + burnin + " burnin.");
}
/**
 * Starts the loggers, attaches the chain listener, runs the integration that
 * corresponds to the configured path scheme and detaches the listener again.
 *
 * @throws RuntimeException if the path scheme is not one of the supported values
 */
public void run() {
    for (MCLogger mcLogger : loggers) {
        mcLogger.startLogging();
    }
    mc.addMarkovChainListener(chainListener);

    // pick the integrator for the configured scheme, then run it
    final Integrator integrator;
    switch (scheme) {
        case FIXED:
            integrator = new FixedThetaRun(fixedRunValues);
            break;
        case LINEAR:
            integrator = new LinearIntegrator(pathSteps);
            break;
        case GEOMETRIC:
            integrator = new GeometricIntegrator(pathSteps);
            break;
        case ONE_SIDED_BETA:
            integrator = new BetaIntegrator(1.0, betaFactor, pathSteps);
            break;
        case BETA:
            integrator = new BetaIntegrator(alphaFactor, betaFactor, pathSteps);
            break;
        case BETA_QUANTILE:
            integrator = new BetaQuantileIntegrator(alphaFactor, pathSteps);
            break;
        case SIGMOID:
            integrator = new SigmoidIntegrator(alphaFactor, pathSteps);
            break;
        default:
            throw new RuntimeException("Illegal path scheme");
    }
    integrate(integrator);

    mc.removeMarkovChainListener(chainListener);
}
/**
 * Listener attached to the Markov chain for the duration of {@link #run()}.
 * It records the chain state and writes to the loggers.
 */
private final MarkovChainListener chainListener = new MarkovChainListener() {

    /**
     * Records the reported state and, once the state number has reached the
     * burn-in length, passes it to every logger.
     */
    @Override
    public void currentState(long state, MarkovChain markovChain, Model currentModel) {
        currentState = state;
        if (state < burnin) {
            return;
        }
        for (MCLogger mcLogger : loggers) {
            mcLogger.log(state);
        }
    }

    /**
     * Called when a new best posterior state is found; only the state number is recorded.
     */
    @Override
    public void bestState(long state, MarkovChain markovChain, Model bestModel) {
        currentState = state;
    }

    /**
     * Cleans up when the chain finishes (possibly early): prints the operator
     * analysis and stops every logger.
     */
    @Override
    public void finished(long chainLength, MarkovChain markovChain) {
        currentState = chainLength;
        OperatorAnalysisPrinter.showOperatorAnalysis(System.out, schedule, false);
        for (MCLogger mcLogger : loggers) {
            mcLogger.stopLogging();
        }
    }
};
/**
* @return the citation category of this estimator, always {@code Citation.Category.FRAMEWORK}
*/
@Override
public Citation.Category getCategory() {
return Citation.Category.FRAMEWORK;
}
/**
* @return a short description of the methods covered by the citations returned from getCitations()
*/
@Override
public String getDescription() {
return "Marginal likelihood estimation using path sampling / stepping-stone sampling (first 2 citations) and generalized stepping-stone sampling (3rd citation)";
}
/**
 * @return the three publications to cite for this estimator, in the order
 * referred to by {@link #getDescription()}
 */
@Override
public List<Citation> getCitations() {
    final Citation baele2012 = new Citation(
            new Author[]{
                    new Author("G", "Baele"),
                    new Author("P", "Lemey"),
                    new Author("T", "Bedford"),
                    new Author("A", "Rambaut"),
                    new Author("MA", "Suchard"),
                    new Author("AV", "Alekseyenko")
            },
            "Improving the accuracy of demographic and molecular clock model comparison while accommodating phylogenetic uncertainty",
            2012,
            "Mol. Biol. Evol.",
            29,
            2157, 2167,
            Citation.Status.PUBLISHED
    );

    final Citation baele2013 = new Citation(
            new Author[]{
                    new Author("G", "Baele"),
                    new Author("WLS", "Li"),
                    new Author("AJ", "Drummond"),
                    new Author("MA", "Suchard"),
                    new Author("P", "Lemey")
            },
            "Accurate model selection of relaxed molecular clocks in Bayesian phylogenetics",
            2013,
            "Mol. Biol. Evol.",
            30,
            239, 243,
            Citation.Status.PUBLISHED
    );

    final Citation baele2016 = new Citation(
            new Author[]{
                    new Author("G", "Baele"),
                    new Author("P", "Lemey"),
                    new Author("MA", "Suchard")
            },
            "Genealogical working distributions for Bayesian model testing with phylogenetic uncertainty",
            2016,
            "Syst. Biol.",
            65,
            250, 264,
            Citation.Status.PUBLISHED
    );

    return Arrays.asList(baele2012, baele2013, baele2016);
}
// TRANSIENT PUBLIC METHODS *****************************************
/**
* @return the spawnable flag of this estimator; it is true unless changed through setSpawnable
*/
public boolean getSpawnable() {
return spawnable;
}
// Spawnable flag; the XML parser in this file clears it when the "spawn" attribute is false.
private boolean spawnable = true;
/**
* Sets the spawnable flag.
*
* @param spawnable the new value of the flag
*/
public void setSpawnable(boolean spawnable) {
this.spawnable = spawnable;
}
/**
* Sets the alpha factor, which run() passes to the BETA, BETA_QUANTILE and SIGMOID integrators.
*
* @param alpha the new alpha factor
*/
public void setAlphaFactor(double alpha) {
alphaFactor = alpha;
}
/**
* Sets the beta factor, which run() passes to the ONE_SIDED_BETA and BETA integrators.
*
* @param beta the new beta factor
*/
public void setBetaFactor(double beta) {
betaFactor = beta;
}
/**
* @return the alpha factor (0.5 unless changed through setAlphaFactor)
*/
public double getAlphaFactor() {
return alphaFactor;
}
/**
* @return the beta factor (0.5 unless changed through setBetaFactor)
*/
public double getBetaFactor() {
return betaFactor;
}
// XML PARSER *****************************************
/**
 * XML parser that builds a {@link MarginalLikelihoodEstimator} from a
 * {@code marginalLikelihoodEstimator} element.
 */
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

    public String getParserName() {
        return MARGINAL_LIKELIHOOD_ESTIMATOR;
    }

    /**
     * Builds the estimator described by the XML element. When a positive prerun
     * length is given, each nested MCMC chain is run for that many iterations
     * (with its loggers stopped) as a side effect of parsing.
     *
     * @param xo the XML element to parse
     * @return the configured {@link MarginalLikelihoodEstimator}
     * @throws XMLParseException if neither path steps nor fixed values are given,
     *                           if the path scheme name is unknown, or if the fixed
     *                           scheme is selected without fixed values
     */
    public Object parseXMLObject(XMLObject xo) throws XMLParseException {
        PathLikelihood pathLikelihood = (PathLikelihood) xo.getChild(PathLikelihood.class);

        // every MCLogger child is handed to the estimator
        List<MCLogger> loggerList = new ArrayList<MCLogger>();
        for (int i = 0; i < xo.getChildCount(); i++) {
            if (xo.getChild(i) instanceof MCLogger) {
                loggerList.add((MCLogger) xo.getChild(i));
            }
        }

        int chainLength = xo.getIntegerAttribute(CHAIN_LENGTH);

        if (!xo.hasAttribute(PATH_STEPS) && !xo.hasAttribute(FIXED_VALUE)) {
            throw new XMLParseException("Either a number of path steps or predefined beta values need to be provided.");
        }
        // Both attributes are read when present. Previously the fixed values were ignored
        // whenever pathSteps was also given, which left the FIXED scheme without values.
        int pathSteps = -1;
        if (xo.hasAttribute(PATH_STEPS)) {
            pathSteps = xo.getIntegerAttribute(PATH_STEPS);
        }
        double[] fixedRunValues = null;
        if (xo.hasAttribute(FIXED_VALUE)) {
            fixedRunValues = xo.getDoubleArrayAttribute(FIXED_VALUE);
        }

        if (xo.hasAttribute(PRINT_OPERATOR_ANALYSIS)) {
            SHOW_OPERATOR_ANALYSIS = xo.getBooleanAttribute(PRINT_OPERATOR_ANALYSIS);
        }

        int burninLength = -1;
        if (xo.hasAttribute(BURNIN)) {
            burninLength = xo.getIntegerAttribute(BURNIN);
        }

        int prerunLength = -1;
        if (xo.hasAttribute(PRERUN)) {
            prerunLength = xo.getIntegerAttribute(PRERUN);
        }

        // deprecated: the boolean "linear" attribute chooses between LINEAR and GEOMETRIC
        boolean linear = xo.getAttribute(LINEAR, true);
        PathScheme scheme = linear ? PathScheme.LINEAR : PathScheme.GEOMETRIC;

        // new approach: an explicit pathScheme attribute overrides the deprecated flag
        if (xo.hasAttribute(PATH_SCHEME)) { // change to: getAttribute once deprecated approach removed
            String schemeText = xo.getAttribute(PATH_SCHEME, PathScheme.LINEAR.getText());
            scheme = PathScheme.parseFromString(schemeText);
            if (scheme == null) {
                // parseFromString returns null for unknown names; fail here with a clear message
                throw new XMLParseException("Unknown " + PATH_SCHEME + " '" + schemeText + "' in "
                        + MARGINAL_LIKELIHOOD_ESTIMATOR + " element.");
            }
        }

        if (scheme == PathScheme.FIXED && fixedRunValues == null) {
            // the fixed scheme iterates over these values and cannot run without them
            throw new XMLParseException("The '" + PathScheme.FIXED.getText() + "' " + PATH_SCHEME
                    + " requires the " + FIXED_VALUE + " attribute.");
        }

        CombinedOperatorSchedule os = new CombinedOperatorSchedule();

        XMLObject mcmcXML = xo.getChild(MCMC);
        for (int i = 0; i < mcmcXML.getChildCount(); ++i) {
            if (mcmcXML.getChild(i) instanceof MCMC) {
                MCMC mcmc = (MCMC) mcmcXML.getChild(i);
                if (prerunLength > 0) {
                    java.util.logging.Logger.getLogger("dr.inference").info("Path Sampling Marginal Likelihood Estimator:\n\tEquilibrating chain " + mcmc.getId() + " for " + prerunLength + " iterations.");
                    for (Logger log : mcmc.getLoggers()) { // Stop the loggers, so nothing gets written to normal output
                        log.stopLogging();
                    }
                    mcmc.getMarkovChain().runChain(prerunLength, false);
                }
                // an operator schedule given directly on this element takes precedence
                if (xo.getChild(OperatorSchedule.class) != null) {
                    os.addOperatorSchedule((OperatorSchedule) xo.getChild(OperatorSchedule.class));
                } else {
                    os.addOperatorSchedule(mcmc.getOperatorSchedule());
                }
            }
        }

        if (os.getScheduleCount() == 0) {
            System.err.println("Error: no mcmc objects provided in construction. Bayes Factor estimation will likely fail.");
        }

        MarginalLikelihoodEstimator mle = new MarginalLikelihoodEstimator(MARGINAL_LIKELIHOOD_ESTIMATOR, chainLength,
                burninLength, pathSteps, fixedRunValues, scheme, pathLikelihood, os, loggerList);

        if (!xo.getAttribute(SPAWN, true)) {
            mle.setSpawnable(false);
        }
        if (xo.hasAttribute(ALPHA)) {
            mle.setAlphaFactor(xo.getAttribute(ALPHA, 0.5));
        }
        if (xo.hasAttribute(BETA)) {
            mle.setBetaFactor(xo.getAttribute(BETA, 0.5));
        }

        // scheme parameters shown next to the scheme name in the log message
        String alphaBetaText = "";
        if (scheme == PathScheme.ONE_SIDED_BETA) {
            alphaBetaText += "(1," + mle.getBetaFactor() + ")";
        } else if (scheme == PathScheme.BETA) {
            alphaBetaText += "(" + mle.getAlphaFactor() + "," + mle.getBetaFactor() + ")";
        } else if (scheme == PathScheme.BETA_QUANTILE) {
            alphaBetaText += "(" + mle.getAlphaFactor() + ")";
        } else if (scheme == PathScheme.SIGMOID) {
            alphaBetaText += "(" + mle.getAlphaFactor() + ")";
        }

        java.util.logging.Logger.getLogger("dr.inference").info("\nCreating the Marginal Likelihood Estimator chain:" +
                "\n chainLength=" + chainLength +
                "\n pathSteps=" + pathSteps +
                "\n pathScheme=" + scheme.getText() + alphaBetaText);

        return mle;
    }

    //************************************************************************
    // AbstractXMLObjectParser implementation
    //************************************************************************

    public String getParserDescription() {
        return "This element returns an MCMC chain and runs the chain as a side effect.";
    }

    // NOTE(review): parseXMLObject returns a MarginalLikelihoodEstimator while MCMC.class is
    // reported here; left unchanged because the users of getReturnType are not visible in
    // this file - confirm before changing.
    public Class getReturnType() {
        return MCMC.class;
    }

    public XMLSyntaxRule[] getSyntaxRules() {
        return rules;
    }

    // Only chainLength is declared without the optional flag; the "samplers" element must
    // hold at least one MCMC, and at least one MCLogger child is required.
    private final XMLSyntaxRule[] rules = {
            AttributeRule.newIntegerRule(CHAIN_LENGTH),
            AttributeRule.newIntegerRule(PATH_STEPS, true),
            AttributeRule.newIntegerRule(BURNIN, true),
            AttributeRule.newIntegerRule(PRERUN, true),
            AttributeRule.newBooleanRule(LINEAR, true),
            AttributeRule.newBooleanRule(LACING, true),
            AttributeRule.newBooleanRule(SPAWN, true),
            AttributeRule.newBooleanRule(PRINT_OPERATOR_ANALYSIS, true),
            AttributeRule.newStringRule(PATH_SCHEME, true),
            AttributeRule.newDoubleArrayRule(FIXED_VALUE, true),
            AttributeRule.newDoubleRule(ALPHA, true),
            AttributeRule.newDoubleRule(BETA, true),
            new ElementRule(MCMC,
                    new XMLSyntaxRule[]{new ElementRule(MCMC.class, 1, Integer.MAX_VALUE)}, false),
            new ElementRule(PathLikelihood.class),
            new ElementRule(MCLogger.class, 1, Integer.MAX_VALUE)
    };
};
/**
* @return the identifier of this estimator
*/
public String getId() {
return id;
}
/**
* Replaces the identifier of this estimator.
*
* @param id the new identifier
*/
public void setId(String id) {
this.id = id;
}
/**
 * The supported path schemes, each with the text used to name it in XML.
 */
enum PathScheme {
    FIXED("fixed"),
    LINEAR("linear"),
    GEOMETRIC("geometric"),
    BETA("beta"),
    ONE_SIDED_BETA("oneSidedBeta"),
    BETA_QUANTILE("betaQuantile"),
    SIGMOID("sigmoid");

    private final String text;

    PathScheme(String text) {
        this.text = text;
    }

    /**
     * @return the textual name of this scheme
     */
    public String getText() {
        return text;
    }

    /**
     * Looks up a scheme by its textual name, ignoring case.
     *
     * @param text the name to look up; may be null
     * @return the matching scheme, or null if {@code text} is null or matches no scheme
     */
    public static PathScheme parseFromString(String text) {
        for (PathScheme scheme : PathScheme.values()) {
            // equalsIgnoreCase is null-safe, so a null argument yields "no match"
            // instead of a NullPointerException
            if (scheme.getText().equalsIgnoreCase(text)) {
                return scheme;
            }
        }
        return null;
    }
}
// PRIVATE TRANSIENTS
/**
* this markov chain does most of the work.
*/
private final MarkovChain mc;
// Operator schedule given to the constructor; integrate() casts it to CombinedOperatorSchedule.
private OperatorSchedule schedule;
private String id = null;
// Last state number reported to the chain listener.
private long currentState;
// Iterations run for each path step after burn-in.
private final long chainLength;
// Effective burn-in per path step, resolved by setDefaultBurnin().
private long burnin;
// Requested burn-in; -1 means "use 10% of chainLength".
private final long burninLength;
private int pathSteps;
// private final boolean linear;
// private final boolean lacing;
private final PathScheme scheme;
// Scheme parameters passed to the integrators in run(); both default to 0.5.
private double alphaFactor = 0.5;
private double betaFactor = 0.5;
// Path parameter values used by the FIXED scheme.
private double[] fixedRunValues;
// 1.0 / pathSteps, computed in the constructor.
private final double pathDelta;
// Path parameter currently being run; starts at 1.0.
private double pathParameter;
private final List<MCLogger> loggers;
private final PathLikelihood pathLikelihood;
// XML element and attribute names used by PARSER.
public static final String MARGINAL_LIKELIHOOD_ESTIMATOR = "marginalLikelihoodEstimator";
public static final String CHAIN_LENGTH = "chainLength";
public static final String PATH_STEPS = "pathSteps";
public static final String FIXED = "fixed";
public static final String LINEAR = "linear";
// Accepted by the syntax rules; the code that read it is commented out.
public static final String LACING = "lacing";
public static final String SPAWN = "spawn";
public static final String BURNIN = "burnin";
// Name of the element that wraps the MCMC children.
public static final String MCMC = "samplers";
public static final String PATH_SCHEME = "pathScheme";
public static final String FIXED_VALUE = "fixedValues";
public static final String ALPHA = "alpha";
public static final String BETA = "beta";
public static final String PRERUN = "prerun";
public static final String PRINT_OPERATOR_ANALYSIS = "printOperatorAnalysis";
// Set by PARSER from the printOperatorAnalysis attribute; static, so shared by all estimators.
private static boolean SHOW_OPERATOR_ANALYSIS = false;
}