/*
* MCMCMC.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.mcmcmc;
import dr.inference.loggers.LogFormatter;
import dr.inference.loggers.Logger;
import dr.inference.loggers.MCLogger;
import dr.inference.markovchain.MarkovChain;
import dr.inference.markovchain.MarkovChainListener;
import dr.inference.mcmc.MCMC;
import dr.inference.mcmc.MCMCCriterion;
import dr.inference.mcmc.MCMCOptions;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.operators.CoercableMCMCOperator;
import dr.inference.operators.CoercionMode;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.OperatorSchedule;
import dr.math.MathUtils;
import dr.util.NumberFormatter;
import java.util.Collections;
import java.util.List;
/**
 * A Metropolis-coupled MCMC (MC^3) analysis that estimates parameters of a
 * probabilistic model.
 * <p>
 * Several {@link MCMC} chains are run in parallel, each assigned a temperature
 * from {@code MCMCMCOptions.getChainTemperatures()}. The chains are advanced in
 * segments (the segment length comes from {@code MCMCMCOptions.getSwapChainsEvery()});
 * between segments a swap of temperatures, together with the operators' tuning
 * state, is proposed between two randomly chosen chains. Only the chain that is
 * currently cold (temperature = 1.0) writes to the log destinations: when the
 * cold role moves to another chain, the log formatters move with it so output
 * always goes to the same place.
 *
 * @author Andrew Rambaut
 * @version $Id: ParallelMCMC.java,v 1.12 2005/01/10 10:56:59 rambaut Exp $
 */
public class MCMCMC implements Runnable {

    /** When true, chain scores and accepted swaps are printed to stdout. */
    public final static boolean DEBUG = false;

    /**
     * Couples the given chains into a single MC^3 analysis.
     *
     * @param mcmcs         the chains to couple; the first one must be the cold chain and its
     *                      options (e.g. chain length) govern the whole analysis
     * @param mcmcmcOptions per-chain temperatures and the swap interval
     * @throws IllegalArgumentException if no chains are given or there are fewer temperatures
     *                                  than chains
     * @throws RuntimeException         if the first temperature is not 1.0 or a chain returns
     *                                  no logger array
     */
    public MCMCMC(MCMC[] mcmcs, MCMCMCOptions mcmcmcOptions) {
        if (mcmcs == null || mcmcs.length == 0) {
            throw new IllegalArgumentException("At least one MCMC chain is required.");
        }
        if (mcmcmcOptions.getChainTemperatures().length < mcmcs.length) {
            throw new IllegalArgumentException("There are " + mcmcs.length
                    + " MCMC chains but only " + mcmcmcOptions.getChainTemperatures().length
                    + " chain temperatures.");
        }
        this.mcmcmcOptions = mcmcmcOptions;

        if (mcmcmcOptions.getChainTemperatures()[0] != 1.0) {
            throw new RuntimeException("The first chain in the array should be cold (temperature = 1.0)");
        }

        coldChain = 0;
        this.mcmcOptions = mcmcs[coldChain].getOptions();

        // Get all the loggers out of all the chains. Only the cold chain's loggers
        // write output, but the formatters are swapped around so that whichever
        // chain is cold always writes to the same destination.
        mcLoggers = new MCLogger[mcmcs.length][];
        for (int i = 0; i < mcmcs.length; i++) {
            Logger[] loggers = mcmcs[i].getLoggers();
            // Must be checked before loggers.length is read.
            if (loggers == null) {
                throw new RuntimeException("There are no loggers in the MCMC chains.");
            }
            mcLoggers[i] = new MCLogger[loggers.length];
            for (int j = 0; j < loggers.length; j++) {
                // Formatter swapping needs the MCLogger API.
                mcLoggers[i][j] = (MCLogger) loggers[j];
            }
        }

        // Operator schedules: their tuning values are swapped along with the temperatures.
        schedules = new OperatorSchedule[mcmcs.length];
        for (int i = 0; i < schedules.length; i++) {
            schedules[i] = mcmcs[i].getOperatorSchedule();
        }

        // The cold chain (index 0) keeps its acceptor's existing temperature;
        // every other chain is heated to its configured temperature.
        chains = new MarkovChain[mcmcs.length];
        chains[0] = mcmcs[0].getMarkovChain();
        for (int i = 1; i < chains.length; i++) {
            chains[i] = mcmcs[i].getMarkovChain();
            MCMCCriterion acceptor = ((MCMCCriterion) chains[i].getAcceptor());
            acceptor.setTemperature(mcmcmcOptions.getChainTemperatures()[i]);
        }
    }

    /**
     * Runs all chains to the configured chain length, proposing a temperature swap
     * after each segment and routing log output to whichever chain is cold.
     * <p>
     * Waiting is not abandoned on interruption; if the calling thread is interrupted
     * while waiting, its interrupt status is restored before this method returns.
     */
    public void run() {
        currentState = 0;
        timer.start();
        boolean interrupted = false;

        // Pre-burnin is currently disabled; kept for reference (it uses runChains/resetChains).
//        if (isPreBurninNeeded()) {
//            long preBurnin = mcmcOptions.getCoercionDelay();
//            if (preBurnin > 0) {
//                MarkovChainListener burninListener = new BurninListener(preBurnin);
//
//                chains[coldChain].addMarkovChainListener(burninListener);
//                runChains(preBurnin, true);
//                chains[coldChain].removeMarkovChainListener(burninListener);
//                resetChains();
//            }
//        }

        MCLogger[] coldChainLoggers = mcLoggers[coldChain];
        // Generic array creation is unchecked; every element is a List<LogFormatter> from getFormatters().
        @SuppressWarnings("unchecked")
        List<LogFormatter>[] logFormatters = new List[coldChainLoggers.length];

        for (int i = 0; i < coldChainLoggers.length; i++) {
            // Start the logging for the cold chain
            coldChainLoggers[i].startLogging();

            // Remember the cold chain's formatters (destinations) so they can follow the cold role
            logFormatters[i] = coldChainLoggers[i].getFormatters();
        }

        // Silence every other chain by giving it no formatters.
        for (int j = 0; j < mcLoggers.length; j++) {
            if (j != coldChain) {
                for (int i = 0; i < mcLoggers[j].length; i++) {
                    mcLoggers[j][i].setFormatters(Collections.<LogFormatter>emptyList());
                }
            }
        }

        chains[coldChain].addMarkovChainListener(chainListener);

        MCMCMCRunner[] threads = new MCMCMCRunner[chains.length];
        for (int i = 0; i < chains.length; i++) {
            threads[i] = new MCMCMCRunner(chains[i], mcmcmcOptions.getSwapChainsEvery(), getChainLength(), false);
            threads[i].start();
        }

        while (chains[coldChain].getCurrentLength() < getChainLength()) {

            // Wait for every runner to finish its current segment.
            boolean allDone;
            do {
                try {
                    Thread.sleep(10);
                } catch (InterruptedException e) {
                    // Keep waiting for the chains; restore the interrupt on exit.
                    interrupted = true;
                }

                allDone = true;
                for (int i = 0; i < chains.length; i++) {
                    if (!threads[i].isChainDone()) {
                        allDone = false;
                    }
                }
            } while (!allDone);

            if (chains[coldChain].getCurrentLength() < getChainLength()) {
                int oldColdChain = coldChain;

                // attempt to swap two chains' temperatures
                coldChain = swapChainTemperatures();

                // If the cold role moved, move the logging listener and the
                // log destinations from the old cold chain to the new one.
                if (coldChain != oldColdChain) {
                    chains[oldColdChain].removeMarkovChainListener(chainListener);

                    for (int i = 0; i < mcLoggers[coldChain].length; i++) {
                        mcLoggers[coldChain][i].setFormatters(logFormatters[i]);
                    }

                    for (int i = 0; i < mcLoggers[oldColdChain].length; i++) {
                        mcLoggers[oldColdChain][i].setFormatters(Collections.<LogFormatter>emptyList());
                    }

                    chains[coldChain].addMarkovChainListener(chainListener);
                }

                for (int i = 0; i < chains.length; i++) {
                    threads[i].continueChain();
                }
            }
        }

        // Stop the timer before finish() reports the elapsed time.
        timer.stop();
        finish();

        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Runs every chain for {@code length} states in its own runner thread and waits
     * for all of them to terminate. Currently only referenced by the disabled
     * pre-burnin code in {@link #run()}.
     *
     * @param length        number of states to run each chain for
     * @param disableCoerce currently unused
     */
    private void runChains(long length, boolean disableCoerce) {
        Thread[] threads = new Thread[chains.length];
        for (int i = 0; i < chains.length; i++) {
            threads[i] = new MCMCMCRunner(chains[i], length, length, false);
            threads[i].start();
        }

        boolean interrupted = false;
        // Wait doggedly for every thread to die; an interrupt does not abort the
        // wait but is remembered and restored once all threads have finished.
        for (int i = 0; i < chains.length; i++) {
            while (threads[i].isAlive()) {
                try {
                    threads[i].join();
                } catch (InterruptedException ie) {
                    interrupted = true;
                }
            }
        }

        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Proposes swapping the temperatures (and operator tuning state) of two
     * randomly chosen distinct chains, accepting with the Metropolis rule.
     *
     * @return the index of the cold chain after the (possibly rejected) swap
     */
    private int swapChainTemperatures() {
        if (DEBUG) {
            System.out.print("Current scores: ");
            for (int i = 0; i < chains.length; i++) {
                System.out.print("\t");
                if (i == coldChain) {
                    System.out.print("[");
                }
                System.out.print(chains[i].getCurrentScore());
                if (i == coldChain) {
                    System.out.print("]");
                }
            }
            System.out.println();
        }

        // A swap needs two distinct chains; with only one the index-selection
        // loop below would never terminate.
        if (chains.length < 2) {
            return coldChain;
        }

        int newColdChain = coldChain;

        int index1 = MathUtils.nextInt(chains.length);
        int index2 = MathUtils.nextInt(chains.length);
        while (index1 == index2) {
            index2 = MathUtils.nextInt(chains.length);
        }

        double score1 = chains[index1].getCurrentScore();
        MCMCCriterion acceptor1 = ((MCMCCriterion) chains[index1].getAcceptor());
        double temperature1 = acceptor1.getTemperature();

        double score2 = chains[index2].getCurrentScore();
        MCMCCriterion acceptor2 = ((MCMCCriterion) chains[index2].getAcceptor());
        double temperature2 = acceptor2.getTemperature();

        // NOTE(review): this form assumes MCMCCriterion's "temperature" multiplies
        // score differences (i.e. behaves as an inverse temperature) — confirm
        // against MCMCCriterion.accept.
        double logRatio = ((score2 - score1) * temperature1) + ((score1 - score2) * temperature2);

        boolean swap = (Math.log(MathUtils.nextDouble()) < logRatio);

        if (swap) {
            if (DEBUG) {
                System.out.println("Swapping chain " + index1 + " and chain " + index2);
            }

            acceptor1.setTemperature(temperature2);
            acceptor2.setTemperature(temperature1);

            swapOperatorState(schedules[index1], schedules[index2]);

            if (index1 == coldChain) {
                newColdChain = index2;
            } else if (index2 == coldChain) {
                newColdChain = index1;
            }
        }

        return newColdChain;
    }

    /**
     * Exchanges accept/reject counts, summed deviations and (for coercable operators)
     * tuning parameters between corresponding operators of two schedules.
     * Assumes both schedules list the same operators in the same order.
     */
    private static void swapOperatorState(OperatorSchedule schedule1, OperatorSchedule schedule2) {
        for (int i = 0; i < schedule1.getOperatorCount(); i++) {
            MCMCOperator operator1 = schedule1.getOperator(i);
            MCMCOperator operator2 = schedule2.getOperator(i);

            long tmp = operator1.getAcceptCount();
            operator1.setAcceptCount(operator2.getAcceptCount());
            operator2.setAcceptCount(tmp);

            tmp = operator1.getRejectCount();
            operator1.setRejectCount(operator2.getRejectCount());
            operator2.setRejectCount(tmp);

            double tmp2 = operator1.getSumDeviation();
            operator1.setSumDeviation(operator2.getSumDeviation());
            operator2.setSumDeviation(tmp2);

            if (operator1 instanceof CoercableMCMCOperator) {
                tmp2 = ((CoercableMCMCOperator) operator1).getCoercableParameter();
                ((CoercableMCMCOperator) operator1).setCoercableParameter(((CoercableMCMCOperator) operator2).getCoercableParameter());
                ((CoercableMCMCOperator) operator2).setCoercableParameter(tmp2);
            }
        }
    }

    /**
     * Resets every chain. Currently only referenced by the disabled pre-burnin code.
     */
    private void resetChains() {
        for (MarkovChain chain : chains) {
            chain.reset();
        }
    }

    /**
     * Cleans up when the chain finishes (possibly early).
     * <p>
     * Writes a final log entry and stops logging for the cold chain's loggers,
     * prints the time taken and, if enabled, an operator analysis for the cold
     * chain's operator schedule.
     */
    private void finish() {

        NumberFormatter formatter = new NumberFormatter(8);

        MCLogger[] loggers = mcLoggers[coldChain];
        for (MCLogger logger : loggers) {
            logger.log(currentState);
            logger.stopLogging();
        }

        System.out.println();
        System.out.println("Time taken: " + timer.toString());

        if (showOperatorAnalysis) {
            System.out.println();
            System.out.println("Operator analysis");
            System.out.println(
                    formatter.formatToFieldWidth("Operator", 30) +
                            formatter.formatToFieldWidth("", 8) +
                            formatter.formatToFieldWidth("Pr(accept)", 11) +
                            " Performance suggestion");

            for (int i = 0; i < schedules[coldChain].getOperatorCount(); i++) {

                MCMCOperator op = schedules[coldChain].getOperator(i);
                double acceptanceProb = MCMCOperator.Utils.getAcceptanceProbability(op);
                String message = "good";
                if (acceptanceProb < op.getMinimumGoodAcceptanceLevel()) {
                    if (acceptanceProb < (op.getMinimumAcceptanceLevel() / 10.0)) {
                        message = "very low";
                    } else if (acceptanceProb < op.getMinimumAcceptanceLevel()) {
                        message = "low";
                    } else {
                        message = "slightly low";
                    }
                } else if (acceptanceProb > op.getMaximumGoodAcceptanceLevel()) {
                    double reallyHigh = 1.0 - ((1.0 - op.getMaximumAcceptanceLevel()) / 10.0);
                    if (acceptanceProb > reallyHigh) {
                        message = "very high";
                    } else if (acceptanceProb > op.getMaximumAcceptanceLevel()) {
                        message = "high";
                    } else {
                        message = "slightly high";
                    }
                }

                String suggestion = op.getPerformanceSuggestion();

                // Only coercable operators have a tuning parameter to report.
                String pString = "        ";
                if (op instanceof CoercableMCMCOperator) {
                    pString = formatter.formatToFieldWidth(formatter.formatDecimal(((CoercableMCMCOperator) op).getRawParameter(), 3), 8);
                }

                System.out.println(
                        formatter.formatToFieldWidth(op.getOperatorName(), 30) +
                                pString +
                                formatter.formatToFieldWidth(formatter.formatDecimal(acceptanceProb, 4), 11) +
                                " " + message + "\t" + suggestion);
            }
            System.out.println();
        }
    }

    /**
     * Listener attached to whichever chain is currently cold. It records the
     * current state, prints all chains' scores every 1000 states (cold chain in
     * brackets) and writes the cold chain's log entries.
     */
    private final MarkovChainListener chainListener = new MarkovChainListener() {

        // MarkovChainListener interface *******************************************
        // for receiving messages from subordinate MarkovChain

        /**
         * Called to update the current model keepEvery states.
         */
        public synchronized void currentState(long state, MarkovChain markovChain, Model currentModel) {

            currentState = state;

            if (state % 1000 == 0) {
                NumberFormatter formatter = new NumberFormatter(8);
                formatter.setPadding(false);

                for (int i = 0; i < chains.length; i++) {
                    String score;
                    if (i == coldChain) {
                        score = "[" + formatter.format(chains[i].getCurrentScore()) + "]";
                    } else {
                        score = formatter.format(chains[i].getCurrentScore());
                    }
                    score += " ";
                    System.out.print(formatter.formatToFieldWidth(score, 12));
                }
                System.out.println();
            }

            MCLogger[] loggers = mcLoggers[coldChain];
            for (MCLogger logger : loggers) {
                logger.log(state);
            }
        }

        /**
         * Called when a new best posterior state is found; only records the state number.
         */
        public synchronized void bestState(long state, MarkovChain markovChain, Model bestModel) {
            currentState = state;
        }

        /**
         * Called when the chain finishes; nothing to do here because {@link #finish()}
         * handles cleanup once all chains have completed.
         */
        public synchronized void finished(long chainLength, MarkovChain markovChain) {
        }
    };

    /**
     * @return the index of the chain that is currently cold
     */
    public int getColdChain() {
        return coldChain;
    }

    /**
     * @return the likelihood function of the currently cold chain.
     */
    public Likelihood getLikelihood() {
        return chains[coldChain].getLikelihood();
    }

    /**
     * @return the timer measuring the run.
     */
    public dr.util.Timer getTimer() {
        return timer;
    }

    /**
     * @return the length of this analysis, taken from the cold chain's MCMC options.
     */
    public final long getChainLength() {
        return mcmcOptions.getChainLength();
    }

    // TRANSIENT PUBLIC METHODS *****************************************

    /**
     * @return the current state of the MCMC analysis.
     */
    public final long getCurrentState() {
        return currentState;
    }

    /**
     * @return the progress (0 to 1) of the MCMC analysis.
     */
    public final double getProgress() {
        return (double) currentState / (double) mcmcOptions.getChainLength();
    }

    /**
     * Requests that all chains stop prematurely.
     */
    public void pleaseStop() {
        for (MarkovChain chain : chains) {
            chain.pleaseStop();
        }
    }

    /**
     * @param soa whether {@link #finish()} prints the operator analysis
     */
    public void setShowOperatorAnalysis(boolean soa) {
        showOperatorAnalysis = soa;
    }

    // PRIVATE TRANSIENTS

    /** Options of the cold chain; supplies the chain length for the whole analysis. */
    private final MCMCOptions mcmcOptions;
    /** Temperatures and swap interval. */
    private final MCMCMCOptions mcmcmcOptions;
    private boolean showOperatorAnalysis = true;
    private final dr.util.Timer timer = new dr.util.Timer();
    private long currentState = 0;
    private final MarkovChain[] chains;
    /** Loggers per chain: mcLoggers[chain][logger]. */
    private final MCLogger[][] mcLoggers;
    /** Operator schedule per chain, indexed like {@link #chains}. */
    private final OperatorSchedule[] schedules;
    /** Index of the chain that currently has temperature 1.0. */
    private int coldChain;
}