/*
* TemporalRooting.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.app.pathogen;
import dr.evolution.tree.*;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.*;
import dr.evolution.util.Date;
import dr.stats.Regression;
import dr.stats.DiscreteStatistics;
import dr.math.*;
import java.util.*;
/*
* @author Andrew Rambaut
*/
/**
 * Roots a tree by fitting a linear regression of root-to-tip divergence against
 * tip date, and provides the related root-to-tip regressions and per-tip values.
 *
 * <p>Tip dates come from the {@code "date"} attribute of each taxon in the
 * {@link TaxonList} given to the constructor. They are matched to tree tips by
 * taxon id.
 */
public class TemporalRooting {

    /** Objective used to score a candidate root position; the search minimises it. */
    public enum RootingFunction {
        /** Residual mean squared, optimised numerically along each branch. */
        HEURISTIC_RESIDUAL_MEAN_SQUARED("heuristic residual mean squared"),
        /** Residual mean squared, optimised in closed form along each branch (non-contemporaneous tips). */
        RESIDUAL_MEAN_SQUARED("residual mean squared"),
        /** Correlation coefficient (negated so that minimising maximises it). */
        CORRELATION("correlation"),
        /** R squared (negated so that minimising maximises it). */
        R_SQUARED("R squared");

        private final String name;

        RootingFunction(final String name) {
            this.name = name;
        }

        /** Returns the display name of this rooting function. */
        @Override
        public String toString() {
            return name;
        }
    }

    /** True when every tip date lies within 1e-8 of the others. */
    private boolean contemporaneous = false;
    private final TaxonList taxa;
    /** Tip date keyed by taxon id; undated taxa are stored as 0.0. */
    private final Map<String, Double> dates;
    /** Date precision keyed by taxon id; only dates with a positive precision are present. */
    private final Map<String, Double> precisions;

    // NOTE(review): nothing in this class ever sets useTargetRate to true (setTargetRate
    // only stores the value), so the target-rate branch in findGlobalRoot is unreachable
    // from here. Confirm whether setTargetRate was meant to enable it.
    private boolean useTargetRate = false;
    private double targetRate = 0.0;

    private double dateMin;
    private double dateMax;

    /** Index of the node whose branch findRoot is currently trying, and how many it will try. */
    private int currentRootBranch = 0;
    private int totalRootBranches = 0;

    private boolean forcePositiveRate = false;

    /**
     * Creates a rooting helper and caches the date of every taxon in {@code taxa}.
     *
     * @param taxa taxa whose {@code "date"} attributes supply the tip dates
     */
    public TemporalRooting(TaxonList taxa) {
        this.taxa = taxa;

        dates = new HashMap<String, Double>();
        precisions = new HashMap<String, Double>();

        dateMin = Double.MAX_VALUE;
        dateMax = -Double.MAX_VALUE;

        for (int i = 0; i < taxa.getTaxonCount(); i++) {
            Taxon taxon = taxa.getTaxon(i);
            Date date = (Date) taxon.getAttribute("date");
            double d = 0.0;
            if (date != null) {
                d = date.getAbsoluteTimeValue();
                if (date.getPrecision() > 0.0) {
                    // Shift by half the precision (presumably to the middle of the
                    // date's uncertainty range) and remember the precision itself.
                    d += date.getPrecision() / 2;
                    precisions.put(taxon.getId(), date.getPrecision());
                }
            }
            if (d > dateMax) {
                dateMax = d;
            }
            if (d < dateMin) {
                dateMin = d;
            }
            dates.put(taxon.getId(), d);
        }

        if (Math.abs(dateMax - dateMin) < 1.0E-8) {
            // probably contemporaneous tips
            contemporaneous = true;
        }
    }

    /**
     * Sets whether the numerical root search should apply the negative-rate score
     * adjustment (see findLocalRoot).
     */
    public void setForcePositiveRate(boolean forcePositiveRate) {
        this.forcePositiveRate = forcePositiveRate;
    }

    /** Stores a target rate; it is only consulted by findGlobalRoot when useTargetRate is true. */
    public void setTargetRate(double targetRate) {
        this.targetRate = targetRate;
    }

    /** Returns true if all tip dates were found to be (nearly) identical. */
    public boolean isContemporaneous() {
        return contemporaneous;
    }

    /** Returns the span between the latest and earliest cached tip dates. */
    public double getDateRange() {
        return dateMax - dateMin;
    }

    /**
     * Tries a root on every branch of {@code tree}, optimising the position along each,
     * and returns the best-scoring rooting.
     *
     * @param tree            tree to root
     * @param rootingFunction objective to minimise
     * @return a re-rooted copy of {@code tree}
     */
    public Tree findRoot(Tree tree, RootingFunction rootingFunction) {
        double[] dates = getTipDates(tree);
        return findGlobalRoot(tree, dates, rootingFunction, forcePositiveRate);
    }

    /**
     * Optimises only the position of the root along the branch that currently
     * joins the root's two children.
     *
     * @param tree            tree whose existing root branch is used
     * @param rootingFunction objective to minimise
     * @return a copy of {@code tree} with the two root-child branch lengths adjusted
     */
    public Tree findLocalRoot(Tree tree, RootingFunction rootingFunction) {
        double[] dates = getTipDates(tree);
        FlexibleTree bestTree = new FlexibleTree(tree);
        findLocalRoot(bestTree, dates, rootingFunction, forcePositiveRate);
        return bestTree;
    }

    /**
     * Regresses root-to-tip distance against tip date.
     *
     * @throws IllegalArgumentException if the tips are contemporaneous
     */
    public Regression getRootToTipRegression(Tree tree) {
        if (contemporaneous) {
            throw new IllegalArgumentException("Cannot do a root to tip regression on contemporaneous tips");
        }

        double[] dates = getTipDates(tree);
        double[] distances = getRootToTipDistances(tree);

        return new Regression(dates, distances);
    }

    /**
     * Regresses the number of nodes on each tip-to-root path against tip date.
     *
     * @throws IllegalArgumentException if the tips are contemporaneous
     */
    public Regression getNodeDensityRegression(Tree tree) {
        if (contemporaneous) {
            throw new IllegalArgumentException("Cannot do a node density regression on contemporaneous tips");
        }

        double[] dates = getTipDates(tree);
        double[] density = getNodeDensity(tree);

        return new Regression(dates, density);
    }

    /**
     * For each tip, takes the root-to-node distance of its parent, assigns it the date
     * that {@code regression} predicts for that distance, and regresses distance on that
     * date.
     *
     * @throws IllegalArgumentException if the tips are contemporaneous
     */
    public Regression getAncestorRootToTipRegression(Tree tree, Regression regression) {
        if (contemporaneous) {
            throw new IllegalArgumentException("Cannot do a root to tip regression on contemporaneous tips");
        }

        double[] dates = new double[tree.getExternalNodeCount()];
        double[] distances = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            NodeRef parent = tree.getParent(tip);
            distances[i] = getRootToTipDistance(tree, parent);
            // invert distance = gradient * (date - xIntercept)
            dates[i] = regression.getXIntercept() + (distances[i] / regression.getGradient());
        }

        return new Regression(dates, distances);
    }

    /** Returns the root-to-tip distance of each tip, in external-node order. */
    public double[] getRootToTipDistances(Tree tree) {
        double[] d = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            d[i] = getRootToTipDistance(tree, tip);
        }
        return d;
    }

    /** Returns the root-to-node distance of each tip's parent, in external-node order. */
    public double[] getParentRootToTipDistances(Tree tree) {
        double[] d = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            NodeRef parent = tree.getParent(tip);
            d[i] = getRootToTipDistance(tree, parent);
        }
        return d;
    }

    /**
     * Returns each tip's residual from {@code regression} (date vs root-to-tip distance),
     * in external-node order.
     *
     * @throws IllegalArgumentException if a tip's taxon was not in the constructor's taxon list
     */
    public double[] getRootToTipResiduals(Tree tree, Regression regression) {
        double[] r = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            Double date = dates.get(tree.getNodeTaxon(tip).getId());
            // same check as getTipDates; previously a missing taxon caused an unboxing NPE
            if (date == null) {
                throw new IllegalArgumentException("Taxon, " + tree.getNodeTaxon(tip) + ", not found in taxon list");
            }
            double d = getRootToTipDistance(tree, tip);

            r[i] = regression.getResidual(date, d);
        }
        return r;
    }

    /** Returns, for each tip, the number of nodes from the tip to the root inclusive. */
    public double[] getNodeDensity(Tree tree) {
        double[] d = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            d[i] = getNodeDensity(tree, tip);
        }
        return d;
    }

    /**
     * Returns the cached date of each tip, in external-node order.
     *
     * @throws IllegalArgumentException if a tip's taxon was not in the constructor's taxon list
     */
    public double[] getTipDates(Tree tree) {
        double[] d = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            Double date = dates.get(tree.getNodeTaxon(tip).getId());
            if (date == null) {
                throw new IllegalArgumentException("Taxon, " + tree.getNodeTaxon(tip) + ", not found in taxon list");
            }
            d[i] = date;
        }
        return d;
    }

    /** Returns each tip's date precision (0.0 when none was recorded), in external-node order. */
    public double[] getTipDatePrecisions(Tree tree) {
        double[] p = new double[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            Double precision = precisions.get(tree.getNodeTaxon(tip).getId());
            if (precision == null) {
                precision = 0.0;
            }
            p[i] = precision;
        }
        return p;
    }

    /** Returns each tip's taxon id, in external-node order. */
    public String[] getTipLabels(Tree tree) {
        String[] labels = new String[tree.getExternalNodeCount()];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            labels[i] = tree.getNodeTaxon(tip).getId();
        }
        return labels;
    }

    /**
     * Places the root at the midpoint of every non-root node's branch in turn, optimises
     * its position locally, and keeps the best tree. Updates currentRootBranch and
     * totalRootBranches as it goes.
     */
    private Tree findGlobalRoot(final Tree source, final double[] dates, RootingFunction rootingFunction, boolean forcePositiveRate) {

        FlexibleTree bestTree = new FlexibleTree(source);
        double minF = findLocalRoot(bestTree, dates, rootingFunction, forcePositiveRate);
        double minDiff = Double.MAX_VALUE;

        totalRootBranches = source.getNodeCount();
        for (currentRootBranch = 0; currentRootBranch < source.getNodeCount(); currentRootBranch++) {
            FlexibleTree tmpTree = new FlexibleTree(source);
            NodeRef node = tmpTree.getNode(currentRootBranch);
            if (!tmpTree.isRoot(node)) {
                double length = tmpTree.getBranchLength(node);
                tmpTree.changeRoot(node, length * 0.5, length * 0.5);

                double f = findLocalRoot(tmpTree, dates, rootingFunction, forcePositiveRate);
                if (useTargetRate) {
                    Regression r = getRootToTipRegression(tmpTree);
                    if (Math.abs(r.getGradient() - targetRate) < minDiff) {
                        minDiff = Math.abs(r.getGradient() - targetRate);
                        bestTree = tmpTree;
                    }
                } else {
                    if (f < minF) {
                        minF = f;
                        bestTree = tmpTree;
                    }
                }
            }
        }

        return bestTree;
    }

    /**
     * Optimises the split of the root branch (the two branches below the root) in place
     * and returns the minimised score.
     */
    private double findLocalRoot(final FlexibleTree tree,
                                 final double[] dates,
                                 final RootingFunction rootingFunction,
                                 final boolean forcePositiveRate) {

        // The closed form divides by the sum of squared date deviations, which is zero
        // for contemporaneous tips; those use the numerical search below, whose
        // contemporaneous objective is the variance of root-to-tip distances.
        if (rootingFunction == RootingFunction.RESIDUAL_MEAN_SQUARED && !contemporaneous) {
            return findAnalyticalLocalRoot(tree, dates, rootingFunction);
        }

        NodeRef node1 = tree.getChild(tree.getRoot(), 0);
        NodeRef node2 = tree.getChild(tree.getRoot(), 1);

        final double length1 = tree.getBranchLength(node1);
        final double length2 = tree.getBranchLength(node2);

        final double sumLength = length1 + length2;

        final Set<NodeRef> tipSet1 = TreeUtils.getExternalNodes(tree, node1);
        final Set<NodeRef> tipSet2 = TreeUtils.getExternalNodes(tree, node2);

        final double[] y = new double[tree.getExternalNodeCount()];

        // Root-to-tip distance of each tip minus its side's root-branch length. The tree
        // is not changed while the minimiser runs, so compute this once instead of
        // walking every tip-to-root path on each evaluation.
        final double[] partial = new double[tree.getExternalNodeCount()];
        for (NodeRef tip : tipSet1) {
            partial[tip.getNumber()] = getRootToTipDistance(tree, tip) - length1;
        }
        for (NodeRef tip : tipSet2) {
            partial[tip.getNumber()] = getRootToTipDistance(tree, tip) - length2;
        }

        UnivariateFunction f = new UnivariateFunction() {
            // argument is the fraction of the root branch assigned to node1's side
            public double evaluate(final double argument) {
                double l1 = argument * sumLength;

                for (NodeRef tip : tipSet1) {
                    y[tip.getNumber()] = partial[tip.getNumber()] + l1;
                }

                double l2 = (1.0 - argument) * sumLength;

                for (NodeRef tip : tipSet2) {
                    y[tip.getNumber()] = partial[tip.getNumber()] + l2;
                }

                double score;

                if (!contemporaneous) {
                    Regression r = new Regression(dates, y);

                    switch (rootingFunction) {
                        case CORRELATION:
                            score = -r.getCorrelationCoefficient();
                            break;
                        case R_SQUARED:
                            score = -r.getRSquared();
                            break;
                        case HEURISTIC_RESIDUAL_MEAN_SQUARED:
                        case RESIDUAL_MEAN_SQUARED:
                            score = r.getResidualMeanSquared();
                            break;
                        default:
                            throw new RuntimeException("Unknown enum value");
                    }

                    // NOTE(review): negating the score only makes negative-rate roots
                    // worse for R_SQUARED (score <= 0). For the residual-mean-squared
                    // functions (presumably score >= 0), and for CORRELATION (where a
                    // negative gradient already gives a positive score), the flip appears
                    // to make negative-rate roots look better. Confirm the intended penalty.
                    if (forcePositiveRate) {
                        score = (r.getGradient() < 0.0 ? -score : score);
                    }

                } else {
                    score = DiscreteStatistics.variance(y);
                }

                return score;
            }

            public int getNumArguments() {
                return 1;
            }

            public double getLowerBound() {
                return 0;
            }

            public double getUpperBound() {
                return 1.0;
            }
        };

        UnivariateMinimum minimum = new UnivariateMinimum();

        double x = minimum.findMinimum(f);

        double fminx = minimum.fminx;

        double l1 = x * sumLength;
        double l2 = (1.0 - x) * sumLength;

        tree.setBranchLength(node1, l1);
        tree.setBranchLength(node2, l2);

        return fminx;
    }

    /**
     * Closed-form optimum of the residual mean squared along the root branch. Sets the
     * two root-child branch lengths in place and returns the resulting residual mean
     * squared.
     *
     * <p>With c = 0 for tips under node1 and c = 1 for tips under node2, moving a fraction
     * x of the root branch to node2's side changes each divergence by (2c - 1)·x·S
     * (S = root-branch length). The regression residuals are therefore B + x·S·A, where B
     * is the residual of y and A the residual of (2c - 1). Minimising the sum of squares
     * gives x = -ΣAB / (S·ΣAA).
     *
     * @throws UnsupportedOperationException for any function other than RESIDUAL_MEAN_SQUARED
     */
    private double findAnalyticalLocalRoot(final FlexibleTree tree,
                                           final double[] t,
                                           final RootingFunction rootingFunction) {

        if (rootingFunction != RootingFunction.RESIDUAL_MEAN_SQUARED) {
            throw new UnsupportedOperationException("Analytical local root solution only for residual mean squared");
        }

        NodeRef node1 = tree.getChild(tree.getRoot(), 0);
        NodeRef node2 = tree.getChild(tree.getRoot(), 1);

        final double length1 = tree.getBranchLength(node1);
        final double length2 = tree.getBranchLength(node2);

        final double sumLength = length1 + length2;

        final Set<NodeRef> tipSet1 = TreeUtils.getExternalNodes(tree, node1);
        final Set<NodeRef> tipSet2 = TreeUtils.getExternalNodes(tree, node2);

        int N = tipSet1.size() + tipSet2.size();
        int n = tipSet2.size();

        // c[i] = 1 for tips below node2, 0 for tips below node1
        final double[] c = new double[N];
        for (NodeRef tip : tipSet2) {
            int i = tip.getNumber();
            c[i] = 1;
        }

        // Re-express the divergences as if the whole root branch sat on node1's side
        // (c = 0 tips gain length2, c = 1 tips lose length2), so that x measures the
        // fraction moved to node2's side.
        final double[] y = getRootToTipDistances(tree);
        for (int j = 0; j < y.length; j++) {
            y[j] = y[j] + (1 - c[j]) * (sumLength - length1) - c[j] * (sumLength - length1);
        }

        double sum_tt = 0.0;
        double sum_t = 0.0;
        double sum_y = 0.0;
        double sum_ty = 0.0;
        double sum_tc = 0.0;
        double Nd = N;
        double nd = n; // doubles so the divisions below are floating point

        for (int i = 0; i < N; i++) {
            sum_tt += t[i] * t[i];
            sum_t += t[i];
            sum_y += y[i];
            sum_ty += t[i] * y[i];
            sum_tc += t[i] * c[i];
        }
        double y_bar = sum_y / Nd;
        double t_bar = sum_t / Nd;

        // sum of squared deviations of the dates
        double C = sum_tt - (sum_t * sum_t / Nd);
        double sumAB = 0.0;
        double sumAA = 0.0;

        for (int i = 0; i < N; i++) {
            // residual of (2c - 1) regressed on t
            double Ai = 2 * c[i] -
                    ((2 * nd - Nd) / Nd) +
                    (2 * (t_bar - t[i]) / (C * Nd)) * (Nd * sum_tc - nd * sum_t) - 1;
            // residual of y regressed on t
            double Bi = (y[i] - y_bar)
                    + ((t_bar - t[i]) / (C * Nd)) * ((Nd * sum_ty) - (sum_t * sum_y));

            sumAB += Ai * Bi;
            sumAA += Ai * Ai;
        }

        double x = -sumAB / (sumLength * sumAA);

        // x is NaN on 0/0, e.g. when every A[i] is zero (so sumAA and sumAB are both 0).
        // The residuals then do not depend on the root position, so keep the current
        // split rather than writing NaN branch lengths. Math.max would propagate NaN.
        if (!Double.isNaN(x)) {
            x = Math.min(Math.max(x, 0.0), 1.0);

            double l1 = (1.0 - x) * sumLength;
            double l2 = x * sumLength;

            tree.setBranchLength(node1, l1);
            tree.setBranchLength(node2, l2);
        }

        Regression r = new Regression(t, getRootToTipDistances(tree));

        return r.getResidualMeanSquared();
    }

    /** Sums branch lengths from {@code node} up to and including the root. */
    public double getRootToTipDistance(Tree tree, NodeRef node) {
        double distance = 0;
        while (node != null) {
            distance += tree.getBranchLength(node);
            node = tree.getParent(node);
        }
        return distance;
    }

    /** Counts the nodes from {@code node} up to and including the root. */
    public double getNodeDensity(Tree tree, NodeRef node) {
        double density = 0;
        while (node != null) {
            density++;
            node = tree.getParent(node);
        }
        return density;
    }

    /**
     * Returns a copy of {@code source} with tip heights taken from the taxon dates and
     * internal heights adjusted to lie above their children and within any matching clade
     * bounds.
     *
     * @param cladeHeights map from a clade's taxon-id set to {lower, upper} height bounds; may be null
     * @throws IllegalArgumentException if a clade's bounds conflict with its children's heights
     */
    public Tree adjustTreeToConstraints(Tree source, Map<Set<String>, double[]> cladeHeights) {

        FlexibleTree tree = new FlexibleTree(source);
        setHeightsFromDates(tree);

        adjustTreeToConstraints(tree, tree.getRoot(), null, cladeHeights);

        return tree;
    }

    /** Returns the index of the node whose branch findRoot is currently trying. */
    public int getCurrentRootBranch() {
        return currentRootBranch;
    }

    /** Returns the number of nodes findRoot iterates over. */
    public int getTotalRootBranches() {
        return totalRootBranches;
    }

    /**
     * Post-order pass that collects the taxon ids below {@code node} into {@code leaves}
     * (when non-null) and fixes each internal node's height to respect its children and
     * any matching clade bounds. Returns the node's final height.
     */
    private double adjustTreeToConstraints(FlexibleTree tree, NodeRef node,
                                           Set<String> leaves,
                                           Map<Set<String>, double[]> cladeHeights) {

        if (!tree.isExternal(node)) {
            Set<String> l = new HashSet<String>();
            double maxChildHeight = 0.0;

            for (int i = 0; i < tree.getChildCount(node); i++) {
                NodeRef child = tree.getChild(node, i);
                double h = adjustTreeToConstraints(tree, child, l, cladeHeights);
                if (h > maxChildHeight) {
                    maxChildHeight = h;
                }
            }

            double height = tree.getNodeHeight(node);

            double lower = maxChildHeight;
            double upper = Double.POSITIVE_INFINITY;

            if (cladeHeights != null) {
                for (Set<String> clade : cladeHeights.keySet()) {
                    if (clade.equals(l)) {
                        double[] bounds = cladeHeights.get(clade);
                        lower = Math.max(bounds[0], maxChildHeight);
                        upper = bounds[1];
                    }
                }
            }

            if (lower > upper) {
                throw new IllegalArgumentException("incompatible constraints");
            }

            if (height < lower) {
                height = lower + 1E-6;
            } else if (height > upper) {
                height = (upper + lower) / 2;
            }
            tree.setNodeHeight(node, height);

            if (leaves != null) {
                leaves.addAll(l);
            }
        } else if (leaves != null) {
            // leaves is null only when the traversal starts at an external node
            leaves.add(tree.getNodeTaxon(node).getId());
        }
        return tree.getNodeHeight(node);
    }

    /**
     * Sets each tip's height from its taxon date, measured back from the most recent date
     * among the constructor's taxa. Undated tips get height 0.0. Does nothing if no taxon
     * has a date.
     */
    private void setHeightsFromDates(FlexibleTree tree) {

        dr.evolution.util.Date mostRecent = null;

        for (int i = 0; i < taxa.getTaxonCount(); i++) {

            Date date = taxa.getTaxon(i).getDate();
            if ((date != null) && (mostRecent == null || date.after(mostRecent))) {
                mostRecent = date;
            }
        }

        if (mostRecent != null) {
            TimeScale timeScale = new TimeScale(mostRecent.getUnits(), true, mostRecent.getAbsoluteTimeValue());
            double time0 = timeScale.convertTime(mostRecent.getTimeValue(), mostRecent);

            for (int i = 0; i < tree.getExternalNodeCount(); i++) {
                NodeRef tip = tree.getExternalNode(i);
                Date date = tree.getNodeTaxon(tip).getDate();
                if (date != null) {
                    tree.setNodeHeight(tip, timeScale.convertTime(date.getTimeValue(), date) - time0);
                } else {
                    tree.setNodeHeight(tip, 0.0);
                }
            }
        }
    }
}