/*
* VeryOldCoalescentLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.arg.coalescent;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Units;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.*;
import dr.math.Binomial;
import dr.util.ComparableDouble;
import dr.util.HeapSort;
import dr.xml.*;

import java.util.ArrayList;
import java.util.Arrays;
/**
* A likelihood function for the coalescent. Takes a tree and a demographic model.
* <p/>
* Parts of this class were derived from C++ code provided by Oliver Pybus.
*
* @author Andrew Rambaut
* @author Alexei Drummond
* @version $Id: VeryOldCoalescentLikelihood.java,v 1.43 2006/07/28 11:27:32 rambaut Exp $
*/
public class VeryOldCoalescentLikelihood extends AbstractModelLikelihood implements Units {
// PUBLIC STUFF
public static final String COALESCENT_LIKELIHOOD = "veryOldCoalescentLikelihood";
public static final String ANALYTICAL = "analytical";
public static final String MODEL = "model";
public static final String POPULATION_TREE = "populationTree";
/**
* Denotes an interval after which a coalescent event is observed
* (i.e. the number of lineages is smaller in the next interval)
*/
public static final int COALESCENT = 0;
/**
* Denotes an interval at the end of which a new sample addition is
* observed (i.e. the number of lineages is larger in the next interval).
*/
public static final int NEW_SAMPLE = 1;
/**
* Denotes an interval at the end of which nothing is
* observed (i.e. the number of lineages is the same in the next interval).
*/
public static final int NOTHING = 2;
public VeryOldCoalescentLikelihood(Tree tree, DemographicModel demoModel) {
this(COALESCENT_LIKELIHOOD, tree, demoModel, true);
}
public VeryOldCoalescentLikelihood(String name, Tree tree, DemographicModel demoModel, boolean setupIntervals) {
super(name);
this.tree = tree;
this.demoModel = demoModel;
if (tree instanceof TreeModel) {
addModel((TreeModel) tree);
}
if (demoModel != null) {
addModel(demoModel);
}
if (setupIntervals) setupIntervals();
addStatistic(new DeltaStatistic());
}
VeryOldCoalescentLikelihood(String name) {
super(name);
}
// **************************************************************
// Extendable methods
// **************************************************************
/**
* @return the node ref of the MRCA of this coalescent prior in the given tree.
*/
public NodeRef getMRCAOfCoalescent(Tree tree) {
return tree.getRoot();
}
/**
* @return an array of noderefs that represent the MRCAs of subtrees to exclude from coalescent prior.
* May return null if no subtrees should be excluded.
*/
public NodeRef[] getExcludedMRCAs(Tree tree) {
return null;
}
// **************************************************************
// ModelListener IMPLEMENTATION
// **************************************************************
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (model == tree) {
// treeModel has changed so recalculate the intervals
intervalsKnown = false;
} else {
// demoModel has changed so we don't need to recalculate the intervals
}
likelihoodKnown = false;
}
// **************************************************************
// VariableListener IMPLEMENTATION
// **************************************************************
protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
} // No parameters to respond to
// **************************************************************
// Model IMPLEMENTATION
// **************************************************************
/**
* Stores the precalculated state: in this case the intervals
*/
protected void storeState() {
System.arraycopy(intervals, 0, storedIntervals, 0, intervals.length);
System.arraycopy(lineageCounts, 0, storedLineageCounts, 0, lineageCounts.length);
storedIntervalsKnown = intervalsKnown;
storedIntervalCount = intervalCount;
storedLikelihoodKnown = likelihoodKnown;
storedLogLikelihood = logLikelihood;
}
/**
* Restores the precalculated state: that is the intervals of the tree.
*/
protected void restoreState() {
System.arraycopy(storedIntervals, 0, intervals, 0, storedIntervals.length);
System.arraycopy(storedLineageCounts, 0, lineageCounts, 0, storedLineageCounts.length);
intervalsKnown = storedIntervalsKnown;
intervalCount = storedIntervalCount;
likelihoodKnown = storedLikelihoodKnown;
logLikelihood = storedLogLikelihood;
if (!intervalsKnown) {
likelihoodKnown = false;
}
}
protected final void acceptState() {
} // nothing to do
/**
* Adopt the state of the model from source.
*/
protected final void adoptState(Model source) {
// all we need to do is force a recalculation of intervals
makeDirty();
}
// **************************************************************
// Likelihood IMPLEMENTATION
// **************************************************************
public final Model getModel() {
return this;
}
public double getLogLikelihood() {
if (!likelihoodKnown) {
logLikelihood = calculateLogLikelihood();
likelihoodKnown = true;
}
return logLikelihood;
}
public final void makeDirty() {
likelihoodKnown = false;
intervalsKnown = false;
}
/**
* Calculates the log likelihood of this set of coalescent intervals,
* given a demographic model.
*/
public double calculateLogLikelihood() {
if ( !intervalsKnown ) setupIntervals();
if (demoModel == null) return calculateAnalyticalLogLikelihood();
double logL = 0.0;
double currentTime = 0.0;
DemographicFunction demoFunction = demoModel.getDemographicFunction();
for (int j = 0; j < intervalCount; j++) {
logL += calculateIntervalLikelihood(demoFunction, intervals[j], currentTime, lineageCounts[j],
getIntervalType(j));
// insert zero-length coalescent intervals
int diff = getCoalescentEvents(j) - 1;
for (int k = 0; k < diff; k++) {
logL += calculateIntervalLikelihood(demoFunction, 0.0, currentTime, lineageCounts[j] - k - 1, COALESCENT);
}
currentTime += intervals[j];
}
return logL;
}
private double calculateAnalyticalLogLikelihood() {
final double lambda = getLambda();
final int n = tree.getExternalNodeCount();
// assumes a 1/theta prior
//logLikelihood = Math.log(1.0/Math.pow(lambda,n));
// assumes a flat prior
double logL = Math.log(1.0 / Math.pow(lambda, n - 1));
return logL;
}
/**
* Returns the likelihood of a given *coalescent* interval
*/
public final double calculateIntervalLikelihood(DemographicFunction demoFunction, double width, double timeOfPrevCoal, int lineageCount) {
return calculateIntervalLikelihood(demoFunction, width, timeOfPrevCoal, lineageCount, COALESCENT);
}
/**
* Returns the likelihood of a given interval,coalescent or otherwise.
*/
public final double calculateIntervalLikelihood(DemographicFunction demoFunction, double width, double timeOfPrevCoal,
int lineageCount, int type) {
//binom.setMax(lineageCount);
double timeOfThisCoal = width + timeOfPrevCoal;
// System.err.printf("s: %7.6f f: %7.6f, %d, %d\n", timeOfPrevCoal, timeOfThisCoal, lineageCount, type);
double intervalArea = demoFunction.getIntegral(timeOfPrevCoal, timeOfThisCoal);
double like = 0;
switch (type) {
case COALESCENT:
like =
-Math.log(demoFunction.getDemographic(timeOfThisCoal)) -
(Binomial.choose2(lineageCount) * intervalArea);
break;
case NEW_SAMPLE:
like = -(Binomial.choose2(lineageCount) * intervalArea);
break;
}
return like;
}
/**
* Returns a factor lambda such that the likelihood can be expressed as
* 1/theta^(n-1) * exp(-lambda/theta). This allows theta to be integrated
* out analytically. :-)
*/
private double getLambda() {
double lambda = 0.0;
for (int i = 0; i < getIntervalCount(); i++) {
lambda += (intervals[i] * lineageCounts[i]);
}
lambda /= 2;
return lambda;
}
/**
* Recalculates all the intervals from the tree model.
*/
protected final void setupIntervals() {
double MULTIFURCATION_LIMIT = 1e-9;
ArrayList times = new ArrayList();
ArrayList<Integer> childs = new ArrayList<Integer>();
collectAllTimes(tree, getMRCAOfCoalescent(tree), getExcludedMRCAs(tree), times, childs);
int[] indices = new int[times.size()];
HeapSort.sort(times, indices);
int maxIntervalCount = tree.getNodeCount();
if (intervals == null) {
intervals = new double[maxIntervalCount];
lineageCounts = new int[maxIntervalCount];
storedIntervals = new double[maxIntervalCount];
storedLineageCounts = new int[maxIntervalCount];
}
// start is the time of the first tip
double start = ((ComparableDouble) times.get(indices[0])).doubleValue();
int numLines = 0;
int i = 0;
intervalCount = 0;
while (i < times.size()) {
int lineagesRemoved = 0;
int lineagesAdded = 0;
double finish = ((ComparableDouble) times.get(indices[i])).doubleValue();
double next = finish;
while (Math.abs(next - finish) < MULTIFURCATION_LIMIT) {
int children = childs.get(indices[i]);
if (children == 0) {
lineagesAdded += 1;
} else {
lineagesRemoved += (children - 1);
}
i += 1;
if (i < times.size()) {
next = ((ComparableDouble) times.get(indices[i])).doubleValue();
} else break;
}
//System.out.println("time = " + finish + " removed = " + lineagesRemoved + " added = " + lineagesAdded);
if (lineagesAdded > 0) {
if (intervalCount > 0 || ((finish - start) > MULTIFURCATION_LIMIT)) {
intervals[intervalCount] = finish - start;
lineageCounts[intervalCount] = numLines;
intervalCount += 1;
}
start = finish;
}
// add sample event
numLines += lineagesAdded;
if (lineagesRemoved > 0) {
intervals[intervalCount] = finish - start;
lineageCounts[intervalCount] = numLines;
intervalCount += 1;
start = finish;
}
// coalescent event
numLines -= lineagesRemoved;
}
intervalsKnown = true;
}
/**
* extract coalescent times and tip information into ArrayList times from tree.
*
* @param node the node to start from
* @param excludeBelow an optional array of nodes to exclude (corresponding subtrees) from density.
*/
private static void collectAllTimes(Tree tree, NodeRef node, NodeRef[] excludeBelow, ArrayList times, ArrayList<Integer> childs) {
times.add(new ComparableDouble(tree.getNodeHeight(node)));
childs.add(tree.getChildCount(node));
for (int i = 0; i < tree.getChildCount(node); i++) {
NodeRef child = tree.getChild(node, i);
if (excludeBelow == null) {
collectAllTimes(tree, child, excludeBelow, times, childs);
} else {
// check if this subtree is included in the coalescent density
boolean include = true;
for(NodeRef anExcludeBelow : excludeBelow) {
if( anExcludeBelow.getNumber() == child.getNumber() ) {
include = false;
break;
}
}
if (include) collectAllTimes(tree, child, excludeBelow, times, childs);
}
}
}
/**
* get number of intervals
*/
public final int getIntervalCount() {
return intervalCount;
}
/**
* Gets an interval.
*/
public final double getInterval(int i) {
if (i >= intervalCount) throw new IllegalArgumentException();
return intervals[i];
}
/**
* Returns the number of uncoalesced lineages within this interval.
* Required for s-coalescents, where new lineages are added as
* earlier samples are come across.
*/
public final int getLineageCount(int i) {
if (i >= intervalCount) throw new IllegalArgumentException();
return lineageCounts[i];
}
/**
* Returns the number coalescent events in an interval
*/
public final int getCoalescentEvents(int i) {
if (i >= intervalCount) throw new IllegalArgumentException();
if (i < intervalCount - 1) {
return lineageCounts[i] - lineageCounts[i + 1];
} else {
return lineageCounts[i] - 1;
}
}
/**
* Returns the type of interval observed.
*/
public final int getIntervalType(int i) {
if (i >= intervalCount) throw new IllegalArgumentException();
int numEvents = getCoalescentEvents(i);
if (numEvents > 0) return COALESCENT;
else if (numEvents < 0) return NEW_SAMPLE;
else return NOTHING;
}
/**
* get the total height of the genealogy represented by these
* intervals.
*/
public final double getTotalHeight() {
double height = 0.0;
for (int j = 0; j < intervalCount; j++) {
height += intervals[j];
}
return height;
}
/**
* Checks whether this set of coalescent intervals is fully resolved
* (i.e. whether is has exactly one coalescent event in each
* subsequent interval)
*/
public final boolean isBinaryCoalescent() {
for (int i = 0; i < intervalCount; i++) {
if (getCoalescentEvents(i) != 1) return false;
}
return true;
}
/**
* Checks whether this set of coalescent intervals coalescent only
* (i.e. whether is has exactly one or more coalescent event in each
* subsequent interval)
*/
public final boolean isCoalescentOnly() {
for (int i = 0; i < intervalCount; i++) {
if (getCoalescentEvents(i) < 1) return false;
}
return true;
}
public String toString() {
return Double.toString(getLogLikelihood());
}
// **************************************************************
// Units IMPLEMENTATION
// **************************************************************
public final Type getUnits() {
return demoModel.getUnits();
}
public void setUnits(Type units) {
demoModel.setUnits(units);
}
// ****************************************************************
// Inner classes
// ****************************************************************
public class DeltaStatistic extends Statistic.Abstract {
public DeltaStatistic() {
super("delta");
}
public int getDimension() {
return 1;
}
public double getStatisticValue(int i) {
throw new RuntimeException("Not implemented");
// return IntervalList.Utils.getDelta(intervals);
}
}
// ****************************************************************
// Private and protected stuff
// ****************************************************************
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return COALESCENT_LIKELIHOOD;
}
public Object parseXMLObject(XMLObject xo) {
XMLObject cxo = xo.getChild(MODEL);
DemographicModel demoModel = (DemographicModel) cxo.getChild(DemographicModel.class);
cxo = xo.getChild(POPULATION_TREE);
TreeModel treeModel = (TreeModel) cxo.getChild(TreeModel.class);
return new VeryOldCoalescentLikelihood(treeModel, demoModel);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "This element represents the likelihood of the tree given the demographic function.";
}
public Class getReturnType() {
return Likelihood.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
new ElementRule(MODEL, new XMLSyntaxRule[]{
new ElementRule(DemographicModel.class)
}),
new ElementRule(POPULATION_TREE, new XMLSyntaxRule[]{
new ElementRule(TreeModel.class)
}),
};
};
// ****************************************************************
// Private and protected stuff
// ****************************************************************
/* public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() { return COALESCENT_LIKELIHOOD; }
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
DemographicModel demoModel = null;
if (xo.hasAttribute(MODEL)) {
demoModel = (DemographicModel)xo.getAttribute(MODEL);
}
TreeModel treeModel = (TreeModel)xo.getAttribute(TREE);
return new VeryOldCoalescentLikelihood(treeModel, demoModel);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "This element represents the likelihood of the tree given the demographic function.";
}
public Class getReturnType() { return Likelihood.class; }
public XMLSyntaxRule[] getSyntaxRules() { return rules; }
private XMLSyntaxRule[] rules = new XMLSyntaxRule[] {
new XORRule(
new EnumAttributeRule(ANALYTICAL, new String[] { "constant" }),
new AttributeRule(MODEL, DemographicModel.class)
),
new AttributeRule(TREE, TreeModel.class)
};
};*/
/**
* The demographic model.
*/
DemographicModel demoModel = null;
/**
* The tree.
*/
Tree tree = null;
/**
* The widths of the intervals.
*/
double[] intervals;
private double[] storedIntervals;
/**
* The number of uncoalesced lineages within a particular interval.
*/
int[] lineageCounts;
private int[] storedLineageCounts;
boolean intervalsKnown = false;
protected boolean storedIntervalsKnown = false;
double logLikelihood;
protected double storedLogLikelihood;
boolean likelihoodKnown = false;
protected boolean storedLikelihoodKnown = false;
int intervalCount = 0;
private int storedIntervalCount = 0;
}