/*
* TestThreadedCompoundLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.model;
import dr.util.NumberFormatter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
/**
 * A likelihood function which is simply the product of a set of likelihood functions,
 * i.e. its log likelihood is the sum of the component log likelihoods.
 * <p>
 * Each component likelihood is evaluated on its own long-lived worker thread. The workers
 * are created lazily by {@link #getLogLikelihood()} and are daemon threads, so they never
 * keep the JVM alive on their own.
 * <p>
 * {@link #getLogLikelihood()} itself is not designed to be invoked by several threads at
 * the same time on the same instance.
 *
 * @author Marc Suchard
 * @author Andrew Rambaut
 * @version $Id: CompoundLikelihood.java,v 1.19 2005/05/25 09:14:36 rambaut Exp $
 */
public class TestThreadedCompoundLikelihood implements Likelihood {

    public static final boolean DEBUG = false;

    /**
     * Creates an empty compound likelihood; components are added with
     * {@link #addLikelihood(Likelihood)}.
     */
    public TestThreadedCompoundLikelihood() {
    }

    /**
     * Creates a compound likelihood from the given components.
     *
     * @param likelihoods the component likelihoods, added in iteration order
     */
    public TestThreadedCompoundLikelihood(List<Likelihood> likelihoods) {
        for (Likelihood likelihood : likelihoods) {
            addLikelihood(likelihood);
        }
    }

    /**
     * Adds a component likelihood unless it is already present. Its model (if any) is
     * registered with the compound model.
     *
     * @param likelihood the component to add
     */
    public void addLikelihood(Likelihood likelihood) {
        if (likelihoods.contains(likelihood)) {
            return;
        }

        likelihoods.add(likelihood);
        if (likelihood.getModel() != null) {
            compoundModel.addModel(likelihood.getModel());
        }
        likelihoodCallers.add(new LikelihoodCaller(likelihood));
    }

    /**
     * @return the number of component likelihoods
     */
    public int getLikelihoodCount() {
        return likelihoods.size();
    }

    /**
     * @param i index of the component
     * @return the i'th component likelihood
     */
    public final Likelihood getLikelihood(int i) {
        return likelihoods.get(i);
    }

    /**
     * @return a new set holding every component likelihood together with the contents of
     *         each component's own likelihood set
     */
    @Override
    public Set<Likelihood> getLikelihoodSet() {
        Set<Likelihood> set = new HashSet<Likelihood>();
        for (Likelihood l : likelihoods) {
            set.add(l);
            set.addAll(l.getLikelihoodSet());
        }
        return set;
    }

    // **************************************************************
    // Likelihood IMPLEMENTATION
    // **************************************************************

    /**
     * @return the compound model collecting the models of all components
     */
    public Model getModel() {
        return compoundModel;
    }

    /**
     * Evaluates every component likelihood and returns the sum of their log likelihoods.
     * <p>
     * Components that report an already-known likelihood are read directly on the calling
     * thread; all others are handed to their worker thread. The calling thread then blocks
     * (without spinning) until each result has been published.
     * <p>
     * If a component throws while being evaluated on a worker, that exception is rethrown
     * here instead of leaving the caller waiting forever.
     *
     * @return the sum of the component log likelihoods (0.0 if there are none)
     */
    public double getLogLikelihood() {
        ensureWorkers();

        // Dispatch everything first so that the slow evaluations overlap...
        for (int i = 0; i < threads.length; i++) {
            LikelihoodCaller caller = likelihoodCallers.get(i);
            if (caller.isLikelihoodKnown()) {
                threads[i].setReturnValue(caller.call());
            } else {
                threads[i].setCaller(caller);
            }
        }

        // ...then collect the results in component order.
        double logLikelihood = 0.0;
        for (LikelihoodThread thread : threads) {
            logLikelihood += thread.awaitResult();
        }

        return logLikelihood; // * weightFactor;
    }

    /**
     * Makes sure there is one started worker per component likelihood. Workers are created
     * on first use and the pool grows when components have been added since the previous
     * evaluation, so late additions are not silently left out of the sum.
     */
    private void ensureWorkers() {
        int required = likelihoodCallers.size();
        int existing = (threads == null) ? 0 : threads.length;
        if (threads != null && existing >= required) {
            return;
        }

        LikelihoodThread[] grown = new LikelihoodThread[required];
        if (existing > 0) {
            System.arraycopy(threads, 0, grown, 0, existing);
        }
        for (int i = existing; i < required; i++) {
            grown[i] = new LikelihoodThread();
            grown[i].start();
        }
        threads = grown;
    }

    public boolean evaluateEarly() {
        return false;
    }

    /**
     * Marks every component likelihood as dirty.
     */
    public void makeDirty() {
        for (Likelihood likelihood : likelihoods) {
            likelihood.makeDirty();
        }
    }

    public String prettyName() {
        return Abstract.getPrettyName(this);
    }

    /**
     * Builds a comma separated list of {@code id=value} entries, one per component.
     * Components without an id are labelled with their simple class name; nested
     * instances of this class are expanded recursively inside parentheses.
     *
     * @return the diagnostic message (empty if there are no components)
     */
    public String getDiagnosis() {
        StringBuilder message = new StringBuilder();
        boolean first = true;

        for (Likelihood lik : likelihoods) {
            if (first) {
                first = false;
            } else {
                message.append(", ");
            }

            String id = lik.getId();
            if (id == null || id.trim().length() == 0) {
                String[] parts = lik.getClass().getName().split("\\.");
                id = parts[parts.length - 1];
            }
            message.append(id).append("=");

            if (lik instanceof TestThreadedCompoundLikelihood) {
                String d = ((TestThreadedCompoundLikelihood) lik).getDiagnosis();
                if (d != null && d.length() > 0) {
                    message.append("(").append(d).append(")");
                }
            } else {
                // Evaluate once: the value is needed for every branch below.
                double value = lik.getLogLikelihood();
                if (value == Double.NEGATIVE_INFINITY) {
                    message.append("-Inf");
                } else if (Double.isNaN(value)) {
                    message.append("NaN");
                } else {
                    NumberFormatter nf = new NumberFormatter(6);
                    message.append(nf.formatDecimal(value, 4));
                }
            }
        }

        return message.toString();
    }

    /**
     * @return the current log likelihood as a string (this triggers an evaluation)
     */
    @Override
    public String toString() {
        return Double.toString(getLogLikelihood());
    }

    /**
     * Stores a weight factor. NOTE: the factor is only stored; it is not currently applied
     * by {@link #getLogLikelihood()}.
     *
     * @param w the weight factor
     */
    public void setWeightFactor(double w) {
        weightFactor = w;
    }

    /**
     * @return the stored weight factor (1.0 unless changed)
     */
    public double getWeightFactor() {
        return weightFactor;
    }

    // **************************************************************
    // Loggable IMPLEMENTATION
    // **************************************************************

    /**
     * @return the log columns.
     */
    public dr.inference.loggers.LogColumn[] getColumns() {
        return new dr.inference.loggers.LogColumn[]{
                new LikelihoodColumn(getId())
        };
    }

    /**
     * A log column reporting the log likelihood of the enclosing compound likelihood.
     */
    private class LikelihoodColumn extends dr.inference.loggers.NumberColumn {
        public LikelihoodColumn(String label) {
            super(label);
        }

        public double getDoubleValue() {
            return getLogLikelihood();
        }
    }

    // **************************************************************
    // Identifiable IMPLEMENTATION
    // **************************************************************

    private String id = null;

    public void setId(String id) {
        this.id = id;
    }

    public String getId() {
        return id;
    }

    /** Worker threads, index-aligned with {@link #likelihoodCallers}; null until first use. */
    private LikelihoodThread[] threads;

    private final ArrayList<Likelihood> likelihoods = new ArrayList<Likelihood>();
    private final CompoundModel compoundModel = new CompoundModel("compoundModel");
    private final List<LikelihoodCaller> likelihoodCallers = new ArrayList<LikelihoodCaller>();

    private double weightFactor = 1.0;

    /**
     * Wraps one component likelihood so that it can be evaluated by a worker thread.
     */
    class LikelihoodCaller {

        public LikelihoodCaller(Likelihood likelihood) {
            this.likelihood = likelihood;
        }

        /**
         * @return the log likelihood of the wrapped component
         */
        public double call() {
            return likelihood.getLogLikelihood();
        }

        private final Likelihood likelihood;

        /**
         * @return true only if the wrapped component is a {@link ThreadAwareLikelihood}
         *         that reports its likelihood as known
         */
        public boolean isLikelihoodKnown() {
            return ((likelihood instanceof ThreadAwareLikelihood) &&
                    ((ThreadAwareLikelihood) likelihood).isLikelihoodKnown());
        }
    }

    /**
     * A long-lived worker that evaluates one {@link LikelihoodCaller} at a time.
     * <p>
     * All shared state is guarded by {@code lock}; {@code condition} is signalled whenever
     * that state changes (work assigned, result published). The slow likelihood evaluation
     * itself runs without the lock held.
     */
    class LikelihoodThread extends Thread {

        public LikelihoodThread() {
            // The run loop never terminates, so the worker must not keep the JVM alive.
            setDaemon(true);
        }

        /**
         * Hands a caller to this worker for evaluation and discards any previous result.
         * If the worker is still evaluating an earlier caller, this waits for it to finish
         * first so that a stale result cannot be mistaken for the new one.
         *
         * @param caller the caller to evaluate
         */
        public void setCaller(LikelihoodCaller caller) {
            lock.lock();
            try {
                awaitIdle();
                resultAvailable = false;
                failure = null;
                this.caller = caller;
                condition.signalAll();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Main run loop: wait for a caller, evaluate it, publish the outcome, repeat.
         */
        @Override
        public void run() {
            while (true) {
                LikelihoodCaller task = takeTask();

                double value = Double.NaN;
                Throwable error = null;
                try {
                    value = task.call(); // SLOW
                } catch (Throwable t) {
                    // Relay the failure to the waiting thread; letting this worker die
                    // would leave that thread waiting forever.
                    error = t;
                }

                publish(value, error);
            }
        }

        /**
         * Blocks until a caller has been assigned, then claims it and marks the worker busy.
         */
        private LikelihoodCaller takeTask() {
            lock.lock();
            try {
                while (caller == null) {
                    try {
                        condition.await();
                    } catch (InterruptedException ignored) {
                        // Interrupts are deliberately ignored (as they always have been
                        // here): the worker simply resumes waiting for work.
                    }
                }
                LikelihoodCaller task = caller;
                caller = null;
                busy = true;
                return task;
            } finally {
                lock.unlock();
            }
        }

        /**
         * Stores the outcome of an evaluation and wakes up any thread waiting for it.
         */
        private void publish(double value, Throwable error) {
            lock.lock();
            try {
                result = value;
                failure = error;
                resultAvailable = true;
                busy = false;
                condition.signalAll();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Non-blocking poll for the result.
         *
         * @return the result if one is available (it is consumed by this call), or
         *         {@code null} if the worker has not published one yet
         */
        public Double getResult() {
            lock.lock();
            try {
                if (!resultAvailable) {
                    return null;
                }
                return Double.valueOf(consumeResult());
            } finally {
                lock.unlock();
            }
        }

        /**
         * Blocks, without spinning, until a result is available and then consumes it.
         * Interruption of the waiting thread does not abort the wait.
         *
         * @return the published log likelihood
         */
        public double awaitResult() {
            lock.lock();
            try {
                while (!resultAvailable) {
                    condition.awaitUninterruptibly();
                }
                return consumeResult();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Supplies a result directly, bypassing evaluation on this worker. Any caller that
         * was assigned but not yet started is dropped.
         *
         * @param logLikelihood the value to publish
         */
        public void setReturnValue(double logLikelihood) {
            lock.lock();
            try {
                awaitIdle();
                caller = null;
                result = logLikelihood;
                failure = null;
                resultAvailable = true;
                condition.signalAll();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Waits until the worker is not in the middle of an evaluation.
         * Must be called with {@code lock} held.
         */
        private void awaitIdle() {
            while (busy) {
                condition.awaitUninterruptibly();
            }
        }

        /**
         * Consumes the published outcome: returns the value, or rethrows the failure raised
         * by the evaluation. Must be called with {@code lock} held and a result available.
         */
        private double consumeResult() {
            resultAvailable = false;
            Throwable error = failure;
            failure = null;
            // LikelihoodCaller.call() declares no checked exceptions, so a failure is
            // always one of these two kinds.
            if (error instanceof RuntimeException) {
                throw (RuntimeException) error;
            }
            if (error instanceof Error) {
                throw (Error) error;
            }
            return result;
        }

        // All fields below are guarded by lock.
        private LikelihoodCaller caller = null;
        private double result = Double.NaN;
        private Throwable failure = null;
        private boolean resultAvailable = false;
        private boolean busy = false;

        private final ReentrantLock lock = new ReentrantLock();
        private final Condition condition = lock.newCondition();
    }

    public boolean isUsed() {
        return isUsed;
    }

    /**
     * Marks this likelihood, and every component likelihood, as used.
     */
    public void setUsed() {
        isUsed = true;
        for (Likelihood l : likelihoods) {
            l.setUsed();
        }
    }

    private boolean isUsed = false;
}