/*
* ParallelCompoundLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.model;
import dr.inference.parallel.MPIServices;
import dr.xml.*;
import java.util.*;
/**
* @author Marc A. Suchard
*/
public class ParallelCompoundLikelihood extends CompoundLikelihood {
public static final String PARALLEL_COMPOUND_LIKELIHOOD = "parallelCompoundLikelihood";
public static final String LOCAL_CHECK = "doLocalCheck";
public static final String RUN_PARALLEL = "doInParallel";
public ParallelCompoundLikelihood(Collection<Likelihood> likelihoods, boolean doParallel, boolean checkLocal) {
super(1, likelihoods);
this.doParallel = doParallel;
this.checkLocal = checkLocal;
}
private boolean doParallel = true;
private boolean checkLocal = false;
public double getLogLikelihood() {
double logLikelihood = 0;
if (doParallel) {
logLikelihood = getLogLikelihoodRemote();
if (checkLocal) {
super.makeDirty();
double logLikelihoodLocal = super.getLogLikelihood();
System.err.printf("Local: %5.4f Remote: %5.4f\n", logLikelihoodLocal, logLikelihood);
}
} else
logLikelihood = super.getLogLikelihood();
return logLikelihood;
}
private double getLogLikelihoodRemote() {
double logLikelihood = 0.0;
final int N = getLikelihoodCount();
List<ParallelLikelihood> likelihoodsDistributed = new ArrayList<ParallelLikelihood>();
List<Integer> processorList = new ArrayList<Integer>();
for (int i = 0; i < N; i++) {
ParallelLikelihood likelihood = (ParallelLikelihood) getLikelihood(i);
if (!likelihood.getLikelihoodKnown()) {
// if (true) {
final int processor = i + 1;
// MPIServices.requestLikelihood(processor);
// ((AbstractModel) getLikelihood(i).getModel()).sendState(processor);
likelihoodsDistributed.add(likelihood);
processorList.add(processor);
} else {
logLikelihood += likelihood.getLogLikelihood();
}
}
final int size = likelihoodsDistributed.size();
if (size == 1) { // only one, so do locally
logLikelihood += likelihoodsDistributed.get(0).getLogLikelihood();
} else if (size > 1) {
// Distribute calculations
int index = 0;
for (ParallelLikelihood likelihood : likelihoodsDistributed) {
int processor = processorList.get(index++);
MPIServices.requestLikelihood(processor);
((AbstractModel) likelihood.getModel()).sendState(processor);
}
// Implicit barrier
// Collect calculations
index = 0;
for (ParallelLikelihood likelihood : likelihoodsDistributed) {
int processor = processorList.get(index++);
double l = MPIServices.receiveDouble(processor);
logLikelihood += l;
likelihood.setLikelihood(l); // todo don't we need to set all of the submodels ????
}
// todo Use Gather instead of blocking loop
}
return logLikelihood;
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return PARALLEL_COMPOUND_LIKELIHOOD;
}
// public String[] getParserNames() { return new String[] { getParserName(), "posterior", "prior", "likelihood" }; }
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
boolean doParallel = true;
boolean checkLocal = false;
if (xo.hasAttribute(LOCAL_CHECK)) {
checkLocal = xo.getBooleanAttribute(LOCAL_CHECK);
}
if (xo.hasAttribute(RUN_PARALLEL)) {
doParallel = xo.getBooleanAttribute(RUN_PARALLEL);
}
List<Likelihood> likelihoods = new ArrayList<Likelihood>();
for (int i = 0; i < xo.getChildCount(); i++) {
if (xo.getChild(i) instanceof Likelihood) {
likelihoods.add((Likelihood) xo.getChild(i));
} else {
throw new XMLParseException("An element which is not a likelihood has been added to a " + PARALLEL_COMPOUND_LIKELIHOOD + " element");
}
}
ParallelCompoundLikelihood compoundLikelihood = new ParallelCompoundLikelihood(likelihoods, doParallel, checkLocal);
return compoundLikelihood;
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "A likelihood function which is simply the product of its component likelihood functions.";
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
new ElementRule(Likelihood.class, 1, Integer.MAX_VALUE),
};
public Class getReturnType() {
return ParallelCompoundLikelihood.class;
}
};
}