/*
* UniformNodeHeightPrior.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.tree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodelxml.tree.UniformNodeHeightPriorParser;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.LogTricks;
import dr.math.MathUtils;
import dr.math.Polynomial;
import java.util.*;
import java.util.logging.Logger;
//import org.jscience.mathematics.number.Rational;
/**
* Two priors for the tree that are relatively non-informative on the internal node heights given the root height.
* The first further assumes that the root height is truncated uniform, see Nicholls, G. & R.D. Gray (2004) for details.
* The second allows any marginal specification over the root height given that it is larger than the oldest
* sampling time (Bloomquist and Suchard, unpublished).
*
* @author Alexei Drummond
* @author Andrew Rambaut
* @author Erik Bloomquist
* @author Marc Suchard
* @version $Id: UniformRootPrior.java,v 1.10 2005/05/24 20:25:58 rambaut Exp $
*/
public class UniformNodeHeightPrior extends AbstractModelLikelihood {
// PUBLIC STUFF
/** Tip-count constant; it is not referenced anywhere else in this class (presumably a planned cap on the analytic path - TODO confirm). */
public static final int MAX_ANALYTIC_TIPS = 60; // TODO Determine this value!
/** Monte Carlo sample size used by the constructor that does not take an explicit sample size. */
public static final int DEFAULT_MC_SAMPLE = 100000;
/** Consecutive interval boundaries closer together than this are pruned in the constructor. */
private static final double tolerance = 1E-6;
/**
* Number of internal nodes minus one. Assigned only when all tips share one height or the
* leading-term form was requested; otherwise it keeps its initial value of 0.
*/
private int k = 0;
/** log(k!), computed together with {@link #k} in the constructor. */
private double logFactorialK;
/** Upper bound on the root height; only assigned by the Nicholls constructors. */
private double maxRootHeight;
/** true selects the Nicholls and Gray form in calculateLogLikelihood, false the Bloomquist and Suchard form. */
private boolean isNicholls;
/** When true the density is evaluated analytically, otherwise by Monte Carlo integration. */
private boolean useAnalytic;
/** With useAnalytic, selects the polynomial-based evaluation instead of recursivelyComputeDensity. */
private boolean useMarginal;
/** Forces the closed-form expression log(k!) - k * log(rootHeight) regardless of the tip heights. */
private boolean leadingTerm;
/** Number of Monte Carlo draws per node when useAnalytic is false. */
private int mcSampleSize;
/** Distinct tip heights of the tree in ascending order (TreeSet ordering), filled by the constructor. */
Set<Double> tipDates = new TreeSet<Double>();
/** Tip heights in descending order after pruning; filled only on the non-closed-form route and not read elsewhere in this class. */
List<Double> reversedTipDateList = new ArrayList<Double>();
/** Never populated or read within this class. */
Map<Double, Integer> intervals = new TreeMap<Double, Integer>();
/**
* Creates the non-Nicholls form of the prior using {@link #DEFAULT_MC_SAMPLE} Monte Carlo draws.
*
* @param tree        the tree whose node heights the prior is placed on
* @param useAnalytic true to evaluate the density analytically rather than by Monte Carlo integration
* @param marginal    true to use the polynomial-based evaluation when useAnalytic is set
* @param leadingTerm true to force the closed-form expression normally used for contemporaneous tips
*/
public UniformNodeHeightPrior(Tree tree, boolean useAnalytic, boolean marginal, boolean leadingTerm) {
this(UniformNodeHeightPriorParser.UNIFORM_NODE_HEIGHT_PRIOR, tree, useAnalytic, DEFAULT_MC_SAMPLE, marginal, leadingTerm);
}
/**
* Creates the non-Nicholls form of the prior with an explicit Monte Carlo sample size; the
* marginal and leading-term options are both switched off.
*
* @param tree         the tree whose node heights the prior is placed on
* @param useAnalytic  true to evaluate the density analytically rather than by Monte Carlo integration
* @param mcSampleSize number of Monte Carlo draws per node when useAnalytic is false
*/
public UniformNodeHeightPrior(Tree tree, boolean useAnalytic, int mcSampleSize) {
this(UniformNodeHeightPriorParser.UNIFORM_NODE_HEIGHT_PRIOR,tree,useAnalytic,mcSampleSize, false, false);
}
/**
 * Shared constructor for the non-Nicholls form of the prior.
 * <p>
 * Collects the distinct tip heights, then either precomputes the constants of the closed-form
 * expression (all tips at one height, or leading-term requested) or prepares the interval
 * boundaries and, for Monte Carlo evaluation, the work buffers. Finally picks the polynomial
 * representation from the number of tips.
 *
 * @param name         model name handed to the superclass
 * @param tree         the tree whose node heights the prior is placed on
 * @param useAnalytic  true for analytic evaluation, false for Monte Carlo integration
 * @param mcSampleSize number of Monte Carlo draws per node
 * @param marginal     true to use the polynomial-based evaluation when useAnalytic is set
 * @param leadingTerm  true to force the closed-form expression
 */
private UniformNodeHeightPrior(String name, Tree tree, boolean useAnalytic, int mcSampleSize,
                               boolean marginal, boolean leadingTerm) {
    super(name);

    this.tree = tree;
    this.isNicholls = false;
    this.useAnalytic = useAnalytic;
    this.useMarginal = marginal;
    this.mcSampleSize = mcSampleSize;
    this.leadingTerm = leadingTerm;

    if (tree instanceof TreeModel) {
        addModel((TreeModel) tree);
    }

    final Logger logger = Logger.getLogger("dr.evomodel");

    // Gather the distinct tip heights (the set removes duplicates and keeps them sorted).
    final int tipCount = tree.getExternalNodeCount();
    for (int tip = 0; tip < tipCount; tip++) {
        tipDates.add(tree.getNodeHeight(tree.getExternalNode(tip)));
    }

    final boolean closedForm = leadingTerm || tipDates.size() == 1;
    if (closedForm) {
        // the tips are contemporaneous (or treated as such) so these are constant...
        k = tree.getInternalNodeCount() - 1;
        logger.info("Uniform Node Height Prior, Intervals = " + (k + 1));
        logFactorialK = logFactorial(k);
    } else {
        reversedTipDateList.addAll(tipDates);
        Collections.reverse(reversedTipDateList);

        // Prune out intervals smaller in length than tolerance: walk downwards from the root
        // height and remember the upper boundary of every interval that is too short.
        final List<Double> tooClose = new ArrayList<Double>();
        double upper = tree.getNodeHeight(tree.getRoot());
        for (int i = 0; i < reversedTipDateList.size(); i++) {
            final Double lower = reversedTipDateList.get(i);
            if (upper - lower < tolerance) {
                tooClose.add(upper);
            }
            upper = lower;
        }
        // Entries of the list are distinct, so this matches removing each date one at a time.
        reversedTipDateList.removeAll(tooClose);

        if (!useAnalytic) {
            // Work buffers for the Monte Carlo integration.
            logLikelihoods = new double[mcSampleSize];
            final int totalNodes = tree.getNodeCount();
            drawNodeHeights = new double[totalNodes][mcSampleSize];
            minNodeHeights = new double[totalNodes];
        }
    }

    // Leading coefficient on tree polynomial is X = (# internal nodes)!
    // To keep X > 10E-40, should use log-space polynomials for more than ~30 tips.
    // (An arbitrary-precision type was once considered for 45 or more tips; log-space is used there too.)
    polynomialType = (tipCount < 30)
            ? Polynomial.Type.DOUBLE   // Much faster
            : Polynomial.Type.LOG_DOUBLE;

    logger.info("Using "+polynomialType+" polynomials!");
}
/**
* Creates the Nicholls and Gray form of the prior.
*
* @param tree          the tree whose node heights the prior is placed on
* @param maxRootHeight upper bound on the root height used by calculateLogLikelihood
*/
public UniformNodeHeightPrior(Tree tree, double maxRootHeight) {
this(UniformNodeHeightPriorParser.UNIFORM_NODE_HEIGHT_PRIOR, tree, maxRootHeight);
}
/**
* Shared constructor for the Nicholls and Gray form: stores the tree and the root-height bound
* and flags the instance so that calculateLogLikelihood takes the Nicholls branch.
*
* @param name          model name handed to the superclass
* @param tree          the tree whose node heights the prior is placed on
* @param maxRootHeight upper bound on the root height
*/
private UniformNodeHeightPrior(String name, Tree tree, double maxRootHeight) {
super(name);
this.tree = tree;
this.maxRootHeight = maxRootHeight;
isNicholls = true;
// Only TreeModel instances are passed to addModel (presumably so their change events
// reach handleModelChangedEvent - confirm against AbstractModel).
if (tree instanceof TreeModel) {
addModel((TreeModel) tree);
}
}
/**
* Package-private constructor that only sets the model name. The tree and every option field
* keep their default values, so the caller is responsible for any further set-up.
*
* @param name model name handed to the superclass
*/
UniformNodeHeightPrior(String name) {
super(name);
}
// **************************************************************
// Extendable methods
// **************************************************************
// **************************************************************
// ModelListener IMPLEMENTATION
// **************************************************************
/**
 * Reacts to a change notification from a model by discarding both the cached log-likelihood
 * and the cached tree polynomials. The arguments are not inspected.
 *
 * @param model  the model that changed (ignored)
 * @param object event payload (ignored)
 * @param index  index associated with the event (ignored)
 */
protected final void handleModelChangedEvent(Model model, Object object, int index) {
    // Coarse invalidation: every event throws away all cached state.
    treePolynomialKnown = false;
    likelihoodKnown = false;

    // A finer-grained scheme was sketched here in the past but is disabled:
    //  - reset treePolynomialKnown only when the topology changes;
    //  - reset likelihoodKnown only when the topology or the root height changes,
    //    ignoring TreeModel.TreeChangedEvents for non-root heights and node parameters;
    //  - treat a Parameter.Default event as a root change when its value equals the root height;
    //  - throw a RuntimeException("Unexpected event!") for anything else.
    // TODO Why are not all node height changes invoking TreeChangedEvents?
}
// **************************************************************
// VariableListener IMPLEMENTATION
// **************************************************************
/**
* Intentionally empty: variable change notifications are ignored by this prior.
*
* @param variable the variable that changed (ignored)
* @param index    index of the changed element (ignored)
* @param type     kind of change (ignored)
*/
protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
}
// **************************************************************
// Model IMPLEMENTATION
// **************************************************************
/**
* Saves the cached log-likelihood and its validity flag so that restoreState() can bring them back.
* The tree-polynomial cache is not saved; the code that would do so is commented out below.
*/
protected final void storeState() {
storedLikelihoodKnown = likelihoodKnown;
storedLogLikelihood = logLikelihood;
// storedTreePolynomialKnown = treePolynomialKnown;
// if (treePolynomial != null)
// storedTreePolynomial = treePolynomial.copy(); // TODO Swap pointers
}
/**
* Restores the cached log-likelihood and its validity flag from the values saved by storeState().
* The tree-polynomial cache and its flag are left as they are; the code that would restore
* them is commented out below.
*/
protected final void restoreState() {
likelihoodKnown = storedLikelihoodKnown;
logLikelihood = storedLogLikelihood;
// treePolynomialKnown = storedTreePolynomialKnown;
// treePolynomial = storedTreePolynomial;
}
/**
* Called when a proposed state is accepted; this prior keeps no extra state to commit.
*/
protected final void acceptState() {
} // nothing to do
// **************************************************************
// Likelihood IMPLEMENTATION
// **************************************************************
/**
* Returns the model behind this likelihood, which is this object itself.
*
* @return this instance
*/
public final Model getModel() {
return this;
}
/**
 * Returns the log prior density, recomputing it through {@link #calculateLogLikelihood()}
 * only when the cached value has been invalidated.
 *
 * @return the (possibly cached) log density
 */
public double getLogLikelihood() {
    if (likelihoodKnown) {
        // Cache hit: nothing changed since the last evaluation.
        return logLikelihood;
    }

    logLikelihood = calculateLogLikelihood();
    likelihoodKnown = true;
    return logLikelihood;
}
/**
* Invalidates both the cached log-likelihood and the cached tree polynomials so that the next
* call to getLogLikelihood() recomputes them.
*/
public final void makeDirty() {
likelihoodKnown = false;
treePolynomialKnown = false;
}
/**
 * Computes the log prior density of the current tree from scratch (no caching).
 * <p>
 * Nicholls variant: returns negative infinity when the root height is negative or above
 * 0.999 * maxRootHeight, otherwise the expression attributed to Nicholls and Gray (2004).
 * <p>
 * Bloomquist and Suchard variant: the sampling times and the root height specify the boundaries
 * between a fixed number of intervals. Internal node heights are equally likely to fall in any
 * of these intervals and are uniformly distributed in an interval before sorting (i.e. the
 * intercoalescent times in an interval form a scaled Dirichlet(1,1,...,1)). This is a
 * conditional density on the root height, so a marginal distribution on the root height can be
 * specified given that it is greater than the oldest sampling time. Evaluation is one of:
 * <ul>
 * <li>closed form log(k!) - k * log(rootHeight) when the constructor precomputed k
 *     (contemporaneous tips or leading-term approximation);</li>
 * <li>analytic, via root polynomials (marginal) or the recursive density;</li>
 * <li>Monte Carlo integration over simulated node heights.</li>
 * </ul>
 *
 * @return the log density; negative infinity for states the prior rules out
 */
public double calculateLogLikelihood() {
    final double rootHeight = tree.getNodeHeight(tree.getRoot());

    if (isNicholls) {
        final int nodeCount = tree.getExternalNodeCount();
        if (rootHeight < 0 || rootHeight > (0.999 * maxRootHeight)) return Double.NEGATIVE_INFINITY;
        // from Nicholls, G. & R.D. Gray (2004)
        return rootHeight * (2 - nodeCount) - Math.log(maxRootHeight - rootHeight);
    }

    // the Bloomquist & Suchard variant
    // TODO Rewrite description above to discuss this new prior
    final double logLike;
    if (k > 0) { // Also valid for leading-term approximation
        // the tips are contemporaneous
        logLike = logFactorialK - (double) k * Math.log(rootHeight);
    } else if (useAnalytic) {
        if (useMarginal) {
            logLike = logDensityFromRootPolynomials(rootHeight);
        } else {
            tmpLogLikelihood = 0;
            recursivelyComputeDensity(tree, tree.getRoot(), 0);
            logLike = tmpLogLikelihood;
        }
    } else if (k == 0 && (leadingTerm || tipDates.size() == 1)) {
        // The constructor took the closed-form route (same condition as here) for a tree whose
        // only internal node is the root, so k == 0 and the Monte Carlo buffers were never
        // allocated. Falling through to the Monte Carlo code would dereference null arrays.
        // With no internal node below the root the product is empty: log(0!) - 0 = 0.
        logLike = 0.0;
    } else {
        logLike = logDensityByMonteCarlo(rootHeight);
    }

    assert !Double.isInfinite(logLike) && !Double.isNaN(logLike);
    return logLike;
}

/**
 * Analytic marginal evaluation: minus the sum of the log-evaluated polynomials of the two
 * subtrees below the root, rebuilding the cached polynomials if they are stale.
 */
private double logDensityFromRootPolynomials(double rootHeight) {
    if (!treePolynomialKnown) {
        treePolynomials = constructRootPolyonmials(tree, polynomialType); // Each polynomial is of lower degree
        treePolynomialKnown = true;
    }

    double logLike = -treePolynomials[0].logEvaluate(rootHeight) - treePolynomials[1].logEvaluate(rootHeight);
    if (Double.isNaN(logLike)) {
        // Try using Horner's method
        // TODO this could be causing the problem!!
        logLike = -treePolynomials[0].logEvaluateHorner(rootHeight) - treePolynomials[1].logEvaluateHorner(rootHeight);
        if (Double.isNaN(logLike)) {
            logLike = Double.NEGATIVE_INFINITY;
        }
    }
    return logLike;
}

/**
 * Monte Carlo evaluation: simulates node heights below the current root height and returns
 * minus the log of the sample average accumulated in logLikelihoods.
 */
private double logDensityByMonteCarlo(double rootHeight) {
    // Copy over current root height
    final double[] drawRootHeight = drawNodeHeights[tree.getRoot().getNumber()];
    Arrays.fill(drawRootHeight, rootHeight); // TODO Only update when rootHeight changes

    // Determine min heights for each node in tree
    recursivelyFindNodeMinHeights(tree, tree.getRoot()); // TODO Only update when topology changes

    // Simulate from prior
    Arrays.fill(logLikelihoods, 0.0);
    recursivelyComputeMCIntegral(tree, tree.getRoot(), tree.getRoot().getNumber()); // TODO Only update when topology or rootHeight changes

    // Take average
    return -LogTricks.logSum(logLikelihoods) + Math.log(mcSampleSize);
}
// Map<Double,Integer> boxCounts;
//
// private double recursivelyComputeMarcDensity(Tree tree, NodeRef node, double rootHeight) {
// if (tree.isExternal(node))
// return tree.getNodeHeight(node);
//
//// double thisHeight = tree.getNodeHeight(node);
//// double thisHeight = rootHeight;
// double heightChild1 = recursivelyComputeMarcDensity(tree, tree.getChild(node, 0), rootHeight);
// double heightChild2 = recursivelyComputeMarcDensity(tree, tree.getChild(node, 1), rootHeight);
// double minHeight = (heightChild1 > heightChild2) ? heightChild1 : heightChild2;
//
// if (!tree.isRoot(node)) {
// double diff = rootHeight - minHeight;
// if (diff <= 0)
// tmpLogLikelihood = Double.NEGATIVE_INFINITY;
// else
// tmpLogLikelihood -= Math.log(diff);
//
// Integer count = boxCounts.get(minHeight);
// if (count == null) {
// boxCounts.put(minHeight,1);
//// System.err.println("new height: "+minHeight);
// } else {
// boxCounts.put(minHeight,count+1);
//// System.err.println("old height: "+minHeight);
// }
// // TODO Could do the logFactorial right here
// } else {
// // Do nothing
// }
// return minHeight;
// }
/**
 * Post-order walk that accumulates the conditional log density in tmpLogLikelihood.
 * For every internal node other than the root it subtracts log(parentHeight - bound), where
 * bound is the largest tip height found below the node; a non-positive range sets the
 * accumulator to negative infinity.
 *
 * @param tree         the tree being traversed
 * @param node         the current node
 * @param parentHeight height of the parent of node (the value is unused for the root)
 * @return the height of node if it is a tip, otherwise the larger of the two values returned
 *         for its children
 */
private double recursivelyComputeDensity(Tree tree, NodeRef node, double parentHeight) {
    if (tree.isExternal(node)) {
        return tree.getNodeHeight(node);
    }

    final double height = tree.getNodeHeight(node);
    final double fromFirst = recursivelyComputeDensity(tree, tree.getChild(node, 0), height);
    final double fromSecond = recursivelyComputeDensity(tree, tree.getChild(node, 1), height);

    // Lower bound for this node's height: the larger of the two child bounds.
    final double bound;
    if (fromFirst > fromSecond) {
        bound = fromFirst;
    } else {
        bound = fromSecond;
    }

    if (tree.isRoot(node)) {
        // The root contributes no term.
        return bound;
    }

    final double range = parentHeight - bound;
    if (range <= 0) {
        tmpLogLikelihood = Double.NEGATIVE_INFINITY;
    } else {
        tmpLogLikelihood -= Math.log(range);
    }
    return bound;
}
/**
 * Fills minNodeHeights for the subtree rooted at node. A tip's entry is its own height; an
 * internal node's entry is the larger of its two children's entries, i.e. the smallest height
 * the node could take.
 *
 * @param tree the tree being traversed
 * @param node the current node
 * @return the value stored for node
 */
private double recursivelyFindNodeMinHeights(Tree tree, NodeRef node) {
    final double bound;
    if (!tree.isExternal(node)) {
        final double fromFirst = recursivelyFindNodeMinHeights(tree, tree.getChild(node, 0));
        final double fromSecond = recursivelyFindNodeMinHeights(tree, tree.getChild(node, 1));
        bound = (fromFirst > fromSecond) ? fromFirst : fromSecond;
    } else {
        bound = tree.getNodeHeight(node);
    }

    minNodeHeights[node.getNumber()] = bound;
    return bound;
}
/**
 * Pre-order walk that carries out the Monte Carlo integration. For every internal node other
 * than the root, and for every sample, it adds log(parentDraw - bound) to logLikelihoods and,
 * unless both children are tips, draws this node's height uniformly between bound and the
 * parent's draw, to be used by its descendants.
 *
 * @param tree             the tree being traversed
 * @param node             the current node
 * @param parentNodeNumber row of drawNodeHeights holding the parent's sampled heights
 */
private void recursivelyComputeMCIntegral(Tree tree, NodeRef node, int parentNodeNumber) {
    if (tree.isExternal(node)) {
        return;
    }

    final int self = node.getNumber();

    if (!tree.isRoot(node)) {
        final double[] parentDraws = drawNodeHeights[parentNodeNumber];
        final double[] ownDraws = drawNodeHeights[self];
        final double bound = minNodeHeights[self];
        // A node with two tip children has no internal descendants, so its own height is never needed.
        final boolean needsOwnDraw =
                !(tree.isExternal(tree.getChild(node, 0)) && tree.isExternal(tree.getChild(node, 1)));

        int sample = 0;
        while (sample < mcSampleSize) {
            final double range = parentDraws[sample] - bound;
            if (range <= 0) {
                logLikelihoods[sample] = Double.NEGATIVE_INFINITY;
                // NOTE(review): this abandons the remaining samples for this node, not just this one.
                break;
            }
            if (needsOwnDraw) {
                ownDraws[sample] = MathUtils.nextDouble() * range + bound;
            }
            logLikelihoods[sample] += Math.log(range);
            sample++;
        }
    }

    recursivelyComputeMCIntegral(tree, tree.getChild(node, 0), self);
    recursivelyComputeMCIntegral(tree, tree.getChild(node, 1), self);
}
/** Reciprocal of the rounding step: 10 means values are rounded to one decimal place. */
private static final double INV_PRECISION = 10;

/**
 * Rounds x to the nearest multiple of 1 / INV_PRECISION using {@link Math#round(double)}.
 *
 * @param x value to round
 * @return the rounded value
 */
private static double round(double x) {
    final long scaled = Math.round(x * INV_PRECISION);
    return scaled / INV_PRECISION;
}
/**
 * Builds one polynomial for each of the two subtrees hanging off the root.
 *
 * @param tree the tree to process
 * @param type numeric representation to use for the polynomials
 * @return a two-element array: index 0 for the root's first child, index 1 for its second
 */
private Polynomial[] constructRootPolyonmials(Tree tree, Polynomial.Type type) {
    final NodeRef root = tree.getRoot();
    final Polynomial[] perSubtree = new Polynomial[2];
    for (int child = 0; child < perSubtree.length; child++) {
        perSubtree[child] = recursivelyComputePolynomial(tree, tree.getChild(root, child), type).getPolynomial();
    }
    return perSubtree;
}
/**
 * Builds the labelled polynomial of the subtree rooted at node. A tip yields the constant
 * polynomial 1 labelled with its rounded height. An internal node yields the product of its
 * children's polynomials, integrated with the product's label as lower bound unless the node
 * is the root.
 *
 * @param tree the tree being traversed
 * @param node the current node
 * @param type numeric representation to use for the polynomials
 * @return the labelled polynomial for node
 */
private TipLabeledPolynomial recursivelyComputePolynomial(Tree tree, NodeRef node, Polynomial.Type type) {
    if (tree.isExternal(node)) {
        final double[] one = new double[]{1.0};
        final double roundedHeight = round(tree.getNodeHeight(node)); // Should help in numerical stability
        return new TipLabeledPolynomial(one, roundedHeight, type, true);
    }

    final TipLabeledPolynomial first = recursivelyComputePolynomial(tree, tree.getChild(node, 0), type);
    final TipLabeledPolynomial second = recursivelyComputePolynomial(tree, tree.getChild(node, 1), type);

    // TODO The product below *should* be cached in an efficient reuse scheme (at least for arbitrary precision)
    // See AbstractTreeLikelihood for an example of how to flag cached polynomials for re-evaluation
    final TipLabeledPolynomial product = first.multiply(second);

    if (tree.isRoot(node)) {
        return product;
    }
    return product.integrateWithLowerBound(product.label);
}
// private void test() {
//
// double[] value = new double[]{2, 0, 2};
// Polynomial a = new Polynomial.Double(value);
// Polynomial a2 = a.multiply(a);
// System.err.println("a :" + a);
// System.err.println("a*a: " + a2);
// System.err.println("eval :" + a2.evaluate(2));
// Polynomial intA = a.integrate();
// System.err.println("intA: " + intA);
// Polynomial intA2 = a.integrateWithLowerBound(2.0);
// System.err.println("intA2: " + intA2);
// System.err.println("");
//
// Polynomial b = new Polynomial.APDouble(value);
// System.err.println("b : " + b);
// Polynomial b2 = b.multiply(b);
// System.err.println("b2 : " + b2);
// System.err.println("eval : " + b2.evaluate(2));
// Polynomial intB = b.integrate();
// System.err.println("intB: " + intB);
// Polynomial intB2 = b.integrateWithLowerBound(2.0);
// System.err.println("intB2: " + intB2);
// System.err.println("");
//
// Polynomial c = new Polynomial.LogDouble(value);
// System.err.println("c : " + c);
// Polynomial c2 = c.multiply(c);
// System.err.println("c2 : " + c2);
// System.err.println("eval : " + c2.evaluate(2));
// Polynomial intC = c.integrate();
// System.err.println("intC: " + intC);
// Polynomial intC2 = c.integrateWithLowerBound(2.0);
// System.err.println("intC2: " + intC2);
// System.exit(-1);
// }
/**
 * A polynomial wrapped together with a height label. For a tip the label is the height passed
 * to the constructor; for a product it is the larger of the two operands' labels. The
 * enclosing class uses the label as the lower bound of integration. All arithmetic is
 * delegated to the wrapped polynomial.
 */
class TipLabeledPolynomial extends Polynomial.Abstract {

    /** Height label carried along with the polynomial. */
    private double label;
    /** The wrapped polynomial that does the actual work. */
    private Polynomial polynomial;
    /** Whether this instance was created for a tip; stored and passed along but not otherwise read here. */
    private boolean isTip;

    /**
     * Wraps a new polynomial of the requested representation built from coefficients.
     *
     * @throws RuntimeException if type is not DOUBLE, LOG_DOUBLE or BIG_DOUBLE
     */
    TipLabeledPolynomial(double[] coefficients, double label, Polynomial.Type type, boolean isTip) {
        final Polynomial backing;
        switch (type) {
            case DOUBLE:
                backing = new Polynomial.Double(coefficients);
                break;
            case LOG_DOUBLE:
                backing = new Polynomial.LogDouble(coefficients);
                break;
            case BIG_DOUBLE:
                backing = new Polynomial.BigDouble(coefficients);
                break;
            // APDOUBLE, RATIONAL and MARCRATIONAL representations are not supported here.
            default:
                throw new RuntimeException("Unknown polynomial type");
        }
        this.polynomial = backing;
        this.label = label;
        this.isTip = isTip;
    }

    /** Wraps an existing polynomial without copying it. */
    TipLabeledPolynomial(Polynomial polynomial, double label, boolean isTip) {
        this.polynomial = polynomial;
        this.label = label;
        this.isTip = isTip;
    }

    // ----- label-aware operations -----

    /** Returns a wrapper around a copy of the polynomial, with the same label and tip flag. */
    public TipLabeledPolynomial copy() {
        return new TipLabeledPolynomial(polynomial.copy(), this.label, this.isTip);
    }

    /** Multiplies by b; the result is labelled with the larger label and is never a tip. */
    public TipLabeledPolynomial multiply(TipLabeledPolynomial b) {
        final Polynomial product = polynomial.multiply(b);
        final double largerLabel = Math.max(label, b.label);
        return new TipLabeledPolynomial(product, largerLabel, false);
    }

    /** Integrates with the given lower bound, keeping the label and tip flag. */
    public TipLabeledPolynomial integrateWithLowerBound(double bound) {
        final Polynomial integrated = polynomial.integrateWithLowerBound(bound);
        return new TipLabeledPolynomial(integrated, label, isTip);
    }

    /** Returns the wrapped polynomial itself (not a copy). */
    public Polynomial getPolynomial() {
        return polynomial;
    }

    public String toString() {
        return polynomial.toString() + " {" + label + "}";
    }

    // ----- plain delegation to the wrapped polynomial -----

    public int getDegree() {
        return polynomial.getDegree();
    }

    public Polynomial multiply(Polynomial b) {
        return polynomial.multiply(b);
    }

    public Polynomial integrate() {
        return polynomial.integrate();
    }

    public void expand(double x) {
        polynomial.expand(x);
    }

    public double evaluate(double x) {
        return polynomial.evaluate(x);
    }

    public double logEvaluate(double x) {
        return polynomial.logEvaluate(x);
    }

    public double logEvaluateHorner(double x) {
        return polynomial.logEvaluateHorner(x);
    }

    public double getCoefficient(int n) {
        return polynomial.getCoefficient(n);
    }

    public void setCoefficient(int n, double x) {
        polynomial.setCoefficient(n, x);
    }

    public String getCoefficientString(int n) {
        return polynomial.getCoefficientString(n);
    }
}
/**
 * Computes log(n!) as log(n) + log(n - 1) + ... + log(2), summed in that order.
 *
 * @param n the argument of the factorial
 * @return log(n!); 0 for every n below 2, including negative values
 */
private double logFactorial(int n) {
    double sum = 0;
    // For n < 2 the loop body never runs, which covers log(0!) = log(1!) = 0.
    int factor = n;
    while (factor > 1) {
        sum += Math.log(factor);
        factor--;
    }
    return sum;
}
// **************************************************************
// XMLElement IMPLEMENTATION
// **************************************************************
/**
* Not supported: always throws.
*
* @param d the document an element would be created in (unused)
* @return never returns normally
* @throws RuntimeException always, with the message "createElement not implemented"
*/
public org.w3c.dom.Element createElement(org.w3c.dom.Document d) {
throw new RuntimeException("createElement not implemented");
}
// ****************************************************************
// Private and protected stuff
// ****************************************************************
/**
* The tree whose node heights this prior is placed on; assigned by the tree-taking
* constructors and left null by the name-only constructor.
*/
Tree tree = null;
// Cached log density and the copy saved by storeState().
double logLikelihood;
private double storedLogLikelihood;
// true while logLikelihood is up to date; cleared by makeDirty() and handleModelChangedEvent().
boolean likelihoodKnown = false;
private boolean storedLikelihoodKnown = false;
// true while treePolynomials matches the current tree; cleared together with likelihoodKnown.
private boolean treePolynomialKnown = false;
// Only referenced from commented-out code in this class.
private boolean storedTreePolynomialKnown = false;
// Only referenced from commented-out code in this class.
private Polynomial treePolynomial;
// Polynomials of the two subtrees below the root, built by constructRootPolyonmials().
private Polynomial[] treePolynomials;
// Only referenced from commented-out code in this class.
private Polynomial storedTreePolynomial;
// Accumulator written by recursivelyComputeDensity() and read by calculateLogLikelihood().
private double tmpLogLikelihood;
// private Iterator<Polynomial.Type> typeIterator = EnumSet.allOf(Polynomial.Type.class).iterator();
// private Polynomial.Type polynomialType = typeIterator.next();
// Numeric representation used for tree polynomials; chosen in the constructor from the tip count.
private Polynomial.Type polynomialType;
// Monte Carlo work buffers, allocated by the constructor only when useAnalytic is false and
// the closed-form route is not taken.
// One accumulated log term per sample.
private double[] logLikelihoods;
// Sampled heights, indexed [node number][sample].
private double[][] drawNodeHeights;
// Per-node lower bound on the height, filled by recursivelyFindNodeMinHeights().
private double[] minNodeHeights;
}
}