/*
 * MCMC.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.inference.mcmc;

import dr.inference.state.Factory;
import dr.inference.state.StateLoader;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.Logger;
import dr.inference.markovchain.MarkovChain;
import dr.inference.markovchain.MarkovChainListener;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.operators.*;
import dr.util.Identifiable;
import dr.util.NumberFormatter;
import dr.xml.Spawnable;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;

/**
 * An MCMC analysis that estimates parameters of a probabilistic model.
 * <p>
 * Typical use: construct, call one of the {@code init} overloads, then call
 * {@link #chain()} (or {@link #run()}, which simply delegates to it).
 *
 * @author Alexei Drummond
 * @author Andrew Rambaut
 * @version $Id: MCMC.java,v 1.41 2005/07/11 14:06:25 rambaut Exp $
 */
public class MCMC implements Identifiable, Spawnable, Loggable {

    /**
     * Creates an un-initialised analysis; one of the {@code init} methods must
     * be called before the chain is run.
     *
     * @param id the identifier of this analysis
     */
    public MCMC(String id) {
        this.id = id;
    }

    /**
     * Must be called before calling chain.
     * <p>
     * Builds the underlying {@link MarkovChain} from the supplied components and,
     * when a state {@link Factory} instance is installed, registers its
     * state-saver listeners on the new chain.
     *
     * @param options    the options for this MCMC analysis
     * @param likelihood the likelihood for this MCMC
     * @param schedule   operator schedule to be used in chain.
     * @param loggers    an array of loggers to record output of this MCMC run
     */
    public void init(
            MCMCOptions options,
            Likelihood likelihood,
            OperatorSchedule schedule,
            Logger[] loggers) {

        MCMCCriterion criterion = new MCMCCriterion();
        criterion.setTemperature(options.getTemperature());

        mc = new MarkovChain(likelihood, schedule, criterion,
                options.getFullEvaluationCount(),
                options.minOperatorCountForFullEvaluation(),
                options.getEvaluationTestThreshold(),
                options.useCoercion());

        this.options = options;
        this.loggers = loggers;
        this.schedule = schedule;

        // initialize transients
        currentState = 0;

        if (Factory.INSTANCE != null) {
            for (MarkovChainListener listener : Factory.INSTANCE.getStateSaverChainListeners()) {
                mc.addMarkovChainListener(listener);
            }
        }
    }

    /**
     * Must be called before calling chain.
     * <p>
     * Convenience overload: wraps the chain length in a fresh {@link MCMCOptions},
     * places the operators in a {@link SimpleOperatorSchedule} and delegates to
     * {@link #init(MCMCOptions, Likelihood, OperatorSchedule, Logger[])}.
     *
     * @param chainlength chain length
     * @param likelihood  the likelihood for this MCMC
     * @param operators   an array of MCMC operators
     * @param loggers     an array of loggers to record output of this MCMC run
     */
    public void init(long chainlength,
                     Likelihood likelihood,
                     MCMCOperator[] operators,
                     Logger[] loggers) {

        MCMCOptions options = new MCMCOptions(chainlength);

        // The acceptance criterion is created by the delegated init(...) from the
        // options, so none is constructed here.
        OperatorSchedule schedule = new SimpleOperatorSchedule();
        for (MCMCOperator operator : operators) {
            schedule.addOperator(operator);
        }

        init(options, likelihood, schedule, loggers);
    }

    /**
     * @return the underlying Markov chain ({@code null} until {@code init} has been called)
     */
    public MarkovChain getMarkovChain() {
        return mc;
    }

    /**
     * @return the loggers passed to {@code init}
     */
    public Logger[] getLoggers() {
        return loggers;
    }

    /**
     * @return the options passed to (or created by) {@code init}
     */
    public MCMCOptions getOptions() {
        return options;
    }

    /**
     * @return the operator schedule passed to (or created by) {@code init}
     */
    public OperatorSchedule getOperatorSchedule() {
        return schedule;
    }

    /**
     * Runs the analysis; equivalent to {@link #chain()}.
     */
    public void run() {
        chain();
    }

    /**
     * This method actually initiates the MCMC analysis.
     * <p>
     * Starts the timer and the loggers, optionally restores a saved state through
     * the installed {@link Factory}, runs the coercion-delay phase followed by the
     * remainder of the chain, and finally stops the timer.
     */
    public void chain() {

        stopping = false;
        currentState = 0;

        timer.start();

        if (loggers != null) {
            for (Logger logger : loggers) {
                logger.startLogging();
            }
        }

        if (!stopping) {

            long loadedState = 0;

            if (Factory.INSTANCE != null) {
                StateLoader initialStateLoader = Factory.INSTANCE.getInitialStateLoader();
                if (initialStateLoader != null) {
                    double[] savedLnL = new double[1];

                    // Keep the state number reported by the loader: it positions the
                    // chain and decides below whether the coercion delay still applies.
                    // NOTE(review): relies on loadState returning the restored state
                    // number as a long - confirm against StateLoader.
                    loadedState = initialStateLoader.loadState(mc, savedLnL);
                    mc.setCurrentLength(loadedState);

                    // Re-evaluate and let the loader compare against the saved value.
                    double lnL = mc.evaluate();
                    initialStateLoader.checkLoadState(savedLnL[0], lnL);
                }
            }

            mc.addMarkovChainListener(chainListener);

            long chainLength = getChainLength();

            //this also potentially gets the new coercionDelay of a possibly increased chain length
            final long coercionDelay = getCoercionDelay();

            //assume that dumped state has passed the coercionDelay
            //TODO: discuss whether we want to dump the coercionDelay or chainLength to file
            if (coercionDelay > loadedState) {
                // second argument 'true' = coercion disabled for these steps
                // (per the original inline note on this call)
                mc.runChain(coercionDelay - loadedState, true);
                chainLength -= coercionDelay;
            }

            // GB: operator acceptance levels used to be reset here via
            // schedule.getOperator(i).reset(); they are now restored instead,
            // so the reset is intentionally not performed.

            mc.runChain(chainLength, false);

            mc.terminateChain();

            mc.removeMarkovChainListener(chainListener);
        }
        timer.stop();
    }

    /**
     * @return a single log column labelled "time" whose formatted value is the
     *         timer reading in seconds; its label and minimum width cannot be changed
     */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[] { new LogColumn() {
            @Override
            public void setLabel(String label) {
                // fixed label - ignored
            }

            @Override
            public String getLabel() {
                return "time";
            }

            @Override
            public void setMinimumWidth(int minimumWidth) {
                // fixed width - ignored
            }

            @Override
            public int getMinimumWidth() {
                return 0;
            }

            @Override
            public String getFormatted() {
                return Double.toString(getTimer().toSeconds());
            }
        } };
    }

    private final MarkovChainListener chainListener = new MarkovChainListener() {

        // MarkovChainListener interface *******************************************
        // for receiving messages from subordinate MarkovChain

        /**
         * Called to update the current model keepEvery states.
         * Records the state number and forwards it to every logger.
         */
        @Override
        public void currentState(long state, MarkovChain markovChain, Model currentModel) {

            currentState = state;

            if (loggers != null) {
                for (Logger logger : loggers) {
                    logger.log(state);
                }
            }
        }

        /**
         * Called when a new best posterior state is found.
         */
        @Override
        public void bestState(long state, MarkovChain markovChain, Model bestModel) {
            // nothing to do
        }

        /**
         * cleans up when the chain finishes (possibly early).
         * Stops the loggers and emits the operator analysis to standard output
         * and/or to the configured operator analysis file.
         */
        @Override
        public void finished(long chainLength, MarkovChain markovChain) {
            currentState = chainLength;

            if (loggers != null) {
                for (Logger logger : loggers) {
                    logger.stopLogging();
                }
            }

            // OperatorAnalysisPrinter class can do the job now
            if (showOperatorAnalysis) {
                OperatorAnalysisPrinter.showOperatorAnalysis(System.out, getOperatorSchedule(), options.useCoercion());
            }

            if (operatorAnalysisFile != null) {
                PrintStream out = null;
                try {
                    out = new PrintStream(new FileOutputStream(operatorAnalysisFile));
                    // Write to the file stream (not System.out) so the file actually
                    // receives the analysis.
                    OperatorAnalysisPrinter.showOperatorAnalysis(out, getOperatorSchedule(), options.useCoercion());
                    out.flush();
                } catch (IOException e) {
                    e.printStackTrace();
                } finally {
                    // Always release the file handle, even if printing failed.
                    if (out != null) {
                        out.close();
                    }
                }
            }

            // How should premature finish be flagged?
        }
    };

    /**
     * @return the likelihood function.
     */
    public Likelihood getLikelihood() {
        return mc.getLikelihood();
    }

    /**
     * @return the timer.
     */
    public dr.util.Timer getTimer() {
        return timer;
    }

    /**
     * @return the length of this analysis.
     */
    public final long getChainLength() {
        return options.getChainLength();
    }

    // TRANSIENT PUBLIC METHODS *****************************************

    /**
     * @return the current state of the MCMC analysis.
     */
    public final long getCurrentState() {
        return currentState;
    }

    /**
     * @return the progress (0 to 1) of the MCMC analysis.
     */
    public final double getProgress() {
        return (double) currentState / (double) options.getChainLength();
    }

    /**
     * @return true if this MCMC is currently adapting the operators.
     */
    public final boolean isAdapting() {
        return isAdapting;
    }

    /**
     * Requests that the MCMC chain stop prematurely.
     */
    public void pleaseStop() {
        stopping = true;
        mc.pleaseStop();
    }

    /**
     * @return true if Markov chain is stopped
     */
    public boolean isStopped() {
        return mc.isStopped();
    }

    /**
     * @return the spawnable flag (true unless changed via {@link #setSpawnable(boolean)})
     */
    public boolean getSpawnable() {
        return spawnable;
    }

    private boolean spawnable = true;

    /**
     * @param spawnable the new value of the spawnable flag
     */
    public void setSpawnable(boolean spawnable) {
        this.spawnable = spawnable;
    }

    // COERCION / CONFIGURATION METHODS *****************************************

    /**
     * Determines how many initial steps are run before operator coercion starts.
     * <p>
     * The delay comes from the options; a negative configured value is replaced by
     * one hundredth of the chain length. The delay is only returned when coercion
     * is enabled in the options or at least one scheduled operator is a
     * {@link CoercableMCMCOperator} in mode {@link CoercionMode#COERCION_ON}.
     *
     * @return the coercion delay, or -1 if no coercion is in use
     */
    protected long getCoercionDelay() {

        long delay = options.getCoercionDelay();
        if (delay < 0) {
            delay = (long) (options.getChainLength() / 100);
        }
        if (options.useCoercion()) {
            return delay;
        }

        for (int i = 0; i < schedule.getOperatorCount(); i++) {
            MCMCOperator op = schedule.getOperator(i);

            if (op instanceof CoercableMCMCOperator) {
                if (((CoercableMCMCOperator) op).getMode() == CoercionMode.COERCION_ON) {
                    return delay;
                }
            }
        }

        return -1;
    }

    /**
     * @param soa whether the operator analysis is printed to standard output
     *            when the chain finishes
     */
    public void setShowOperatorAnalysis(boolean soa) {
        showOperatorAnalysis = soa;
    }

    /**
     * @param operatorAnalysisFile file that receives the operator analysis when the
     *                             chain finishes, or {@code null} to write no file
     */
    public void setOperatorAnalysisFile(File operatorAnalysisFile) {
        this.operatorAnalysisFile = operatorAnalysisFile;
    }

    /**
     * @return the identifier of this analysis
     */
    public String getId() {
        return id;
    }

    /**
     * @param id the new identifier of this analysis
     */
    public void setId(String id) {
        this.id = id;
    }

    // PRIVATE TRANSIENTS

    //private FileLogger operatorLogger = null;
    protected final boolean isAdapting = true;
    protected boolean stopping = false;
    protected boolean showOperatorAnalysis = true;
    protected File operatorAnalysisFile = null;
    protected final dr.util.Timer timer = new dr.util.Timer();
    protected long currentState = 0;
    //private int stepsPerReport = 1000;
    protected final NumberFormatter formatter = new NumberFormatter(8);

    /**
     * this markov chain does most of the work.
     */
    protected MarkovChain mc;

    /**
     * the options of this MCMC analysis
     */
    protected MCMCOptions options;

    protected Logger[] loggers;
    protected OperatorSchedule schedule;

    private String id = null;
}