/*
* ContinuousTraitLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evolution.continuous;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.*;
import dr.matrix.Matrix;
import dr.matrix.MutableMatrix;
import dr.geo.math.SphericalPolarCoordinates;
import java.io.StringReader;
/**
 * Calculates the likelihood of continuous traits on a tree using phylogenetically independent
 * contrasts.
 *
 * @version $Id: ContinuousTraitLikelihood.java,v 1.8 2006/06/18 16:20:58 alexei Exp $
 *
 * @author Alexei Drummond
 */
public class ContinuousTraitLikelihood {

    /**
     * Calculates the likelihood of the traits on the given tree.
     *
     * <p>As a side effect, each node of {@code tree} receives a node attribute per trait (named
     * after the trait). Tips get their observed value. Internal nodes get the variance-weighted
     * mean of their two children.
     *
     * @param tree       the tree on which to calculate the continuous attributes likelihood; it
     *                   must be strictly bifurcating
     * @param attributes names of the taxon attributes holding the trait values, one per trait
     * @param mles       output array; element {@code i} receives the value estimated at the root
     *                   for trait {@code i}. Its length must not exceed {@code attributes.length}
     * @param kappa      exponent applied to each branch length (node height difference) when
     *                   computing branch variances
     * @return the log likelihood of the traits
     * @throws IllegalArgumentException if the tree is not strictly bifurcating, or a tip is
     *                                  missing one of the attributes or holds a value of an
     *                                  unsupported type
     */
    public double calculateLikelihood(MutableTree tree, String[] attributes, Contrastable[] mles, double kappa) {
        ContrastedTraitNode contrastNode = new ContrastedTraitNode(tree, tree.getRoot(), attributes);
        contrastNode.calculateContrasts(kappa);
        for (int i = 0; i < mles.length; i++) {
            mles[i] = contrastNode.getTraitValue(i);
        }
        return calculateTraitsLikelihood(contrastNode);
    }

    /**
     * Dispatches to the single-trait or multiple-trait likelihood calculation.
     *
     * @param contrastNode the root of a tree on which contrasts have already been calculated
     * @return the log likelihood
     */
    private double calculateTraitsLikelihood(ContrastedTraitNode contrastNode) {
        int count = contrastNode.getTraitCount();
        if (count == 1) {
            return calculateSingleTraitLikelihood(contrastNode);
        }
        return calculateMultipleTraitsLikelihood(contrastNode, count);
    }

    /**
     * Calculates the log likelihood of several traits jointly from the contrasts below
     * {@code contrastNode}.
     *
     * <p>The trait covariance matrix W is estimated as the average, over all contrasts, of
     * {@code u u^T / v}, where {@code u} is the contrast vector and {@code v} the contrast
     * variance. Each contrast then contributes a multivariate normal term with covariance
     * {@code v W}. The root contributes the normalising terms for covariance
     * {@code rootVariance * W}.
     *
     * @param contrastNode the root of a tree on which contrasts have already been calculated
     * @param traitCount   the number of traits
     * @return the log likelihood
     * @throws IllegalStateException if the matrix dimensions are inconsistent (an internal error)
     */
    private double calculateMultipleTraitsLikelihood(ContrastedTraitNode contrastNode, int traitCount) {
        SimpleTree contrastTree = new SimpleTree(contrastNode);
        int internalCount = contrastTree.getInternalNodeCount();

        // W[j][k] = mean over contrasts of u_j * u_k / v (symmetric, so fill both halves at once)
        double[][] w = new double[traitCount][traitCount];
        for (int j = 0; j < traitCount; j++) {
            for (int k = j; k < traitCount; k++) {
                double wjk = 0.0;
                for (int i = 0; i < internalCount; i++) {
                    ContrastedTraitNode ctNode = (ContrastedTraitNode) contrastTree.getInternalNode(i);
                    wjk += (ctNode.contrast[j] * ctNode.contrast[k]) / ctNode.contrastVariance;
                }
                wjk /= (double) internalCount;
                w[j][k] = wjk;
                w[k][j] = wjk;
            }
        }

        // The determinant is taken on its own copy of W so that W's copy used for inversion
        // is built exactly as before.
        MutableMatrix wForDet = Matrix.Util.createMutableMatrix(w);
        double detW;
        try {
            detW = Matrix.Util.det(wForDet);
        } catch (Matrix.NotSquareException nse) {
            // w is allocated traitCount x traitCount above, so this means a broken invariant.
            throw new IllegalStateException("trait covariance matrix is not square", nse);
        }

        MutableMatrix invW = Matrix.Util.createMutableMatrix(w);
        try {
            Matrix.Util.invert(invW);
        } catch (Matrix.NotSquareException nse) {
            throw new IllegalStateException("trait covariance matrix is not square", nse);
        }

        // Scratch matrices sized for the products below:
        // W^-1 u is traitCount x 1, and u^T (W^-1 u) is 1 x 1.
        MutableMatrix invWu = Matrix.Util.createMutableMatrix(new double[traitCount][1]);
        MutableMatrix quadForm = Matrix.Util.createMutableMatrix(new double[1][1]);

        double logL = 0.0;
        for (int i = 0; i < internalCount; i++) {
            ContrastedTraitNode ctNode = (ContrastedTraitNode) contrastTree.getInternalNode(i);
            double[] contrasts = ctNode.getTraitContrasts();
            double cv = ctNode.getContrastVariance();
            Matrix uT = Matrix.Util.createRowVector(contrasts);
            Matrix u = Matrix.Util.createColumnVector(contrasts);
            try {
                Matrix.Util.product(invW, u, invWu);
                Matrix.Util.product(uT, invWu, quadForm);
            } catch (Matrix.WrongDimensionException wde) {
                // Swallowing this would leave a stale quadratic term and a wrong likelihood.
                throw new IllegalStateException("contrast vector does not match the trait covariance matrix", wde);
            }
            logL += quadForm.getElement(0, 0) / cv;
            logL += traitCount * Math.log(cv);
        }

        // n = number of contrasts + 1 (the root)
        int n = internalCount + 1;
        logL += traitCount * Math.log(contrastNode.getNodeVariance());
        logL += n * Math.log(detW);
        logL += n * traitCount * Math.log(2 * Math.PI);
        return -logL / 2.0;
    }

    /**
     * Calculates the log likelihood of a single continuous trait from the contrasts below
     * {@code contrastNode}.
     *
     * <p>The rate {@code s2} is estimated as the mean squared standardised contrast
     * ({@code u^2 / v}). The root node's variance is included among the log variances.
     *
     * @param contrastNode the root of a tree on which contrasts have already been calculated
     * @return the log likelihood
     */
    private double calculateSingleTraitLikelihood(ContrastedTraitNode contrastNode) {
        SimpleTree contrastTree = new SimpleTree(contrastNode);
        int internalCount = contrastTree.getInternalNodeCount();

        double sssContrast = 0.0; // sum of u^2 / v over all contrasts
        double slogCV = 0.0;      // sum of log contrast variances, plus log root variance
        for (int i = 0; i < internalCount; i++) {
            ContrastedTraitNode ctNode = (ContrastedTraitNode) contrastTree.getInternalNode(i);
            double contrast = ctNode.getTraitContrasts()[0];
            double cv = ctNode.getContrastVariance();
            sssContrast += (contrast * contrast) / cv;
            slogCV += Math.log(cv);
            if (ctNode.isRoot()) {
                slogCV += Math.log(ctNode.getNodeVariance());
            }
        }

        double s2 = sssContrast / internalCount;
        // n = number of contrasts + 1 (the root)
        int n = internalCount + 1;
        double logL = n * Math.log(2.0 * Math.PI * s2);
        logL += slogCV;
        logL += sssContrast / s2;
        return -logL / 2.0;
    }

    /**
     * A copy of a node of the original tree that carries the trait values, contrasts and
     * variances used by the independent-contrasts calculation.
     */
    class ContrastedTraitNode extends SimpleNode {

        /**
         * Recursively copies the subtree of {@code tree} rooted at {@code node}. Tip trait values
         * are read from the taxon attributes and also stored on {@code tree} as node attributes.
         *
         * @param tree           the original tree; it must be strictly bifurcating
         * @param node           the node of {@code tree} to copy
         * @param attributeNames names of the taxon attributes holding the trait values
         * @throws IllegalArgumentException if a node does not have exactly two children, or a tip
         *                                  is missing an attribute or holds a value of an
         *                                  unsupported type
         */
        public ContrastedTraitNode(MutableTree tree, NodeRef node, String[] attributeNames) {
            init(tree, node, attributeNames.length);
            if (!tree.isExternal(node)) {
                if (tree.getChildCount(node) != 2) {
                    throw new IllegalArgumentException("Tree must be strictly bifurcating!");
                }
                addChild(new ContrastedTraitNode(tree, tree.getChild(node, 0), attributeNames));
                addChild(new ContrastedTraitNode(tree, tree.getChild(node, 1), attributeNames));
            } else {
                for (int i = 0; i < attributeNames.length; i++) {
                    Object obj = tree.getNodeTaxon(node).getAttribute(attributeNames[i]);
                    if (obj == null) {
                        throw new IllegalArgumentException("attribute " + attributeNames[i] + " does not exist in " + tree.getTaxonId(node.getNumber()));
                    }
                    if (obj instanceof Number) {
                        traitValue[i] = new Continuous(((Number) obj).doubleValue());
                    } else if (obj instanceof String) {
                        traitValue[i] = new Continuous(Double.parseDouble((String) obj));
                    } else if (obj instanceof Continuous) {
                        traitValue[i] = (Continuous) obj;
                    } else if (obj instanceof SphericalPolarCoordinates) {
                        traitValue[i] = (SphericalPolarCoordinates) obj;
                    } else {
                        // Fail here rather than leave a null trait value that breaks later.
                        throw new IllegalArgumentException("attribute " + attributeNames[i] + " in " + tree.getTaxonId(node.getNumber())
                                + " has unsupported type " + obj.getClass().getName());
                    }
                    tree.setNodeAttribute(node, attributeNames[i], traitValue[i]);
                }
            }
            this.traitNames = attributeNames;
        }

        /**
         * Copies height, rate, id, number and taxon from {@code node}, and allocates zeroed
         * per-trait storage.
         */
        private void init(MutableTree tree, NodeRef node, int traitCount) {
            setHeight(tree.getNodeHeight(node));
            setRate(tree.getNodeRate(node));
            setId(tree.getTaxonId(node.getNumber()));
            setNumber(node.getNumber());
            setTaxon(tree.getNodeTaxon(node));
            contrast = new double[traitCount];
            contrastVariance = 0.0;
            traitValue = new Contrastable[traitCount];
            nodeVariance = 0.0;
            this.tree = tree;
            this.node = node;
        }

        /** @return the per-trait contrasts at this node (the live array, not a copy) */
        public double[] getTraitContrasts() {
            return contrast;
        }

        /** @return the contrast variance at this node */
        public double getContrastVariance() {
            return contrastVariance;
        }

        /** @return the variance of the trait estimate at this node */
        public double getNodeVariance() {
            return nodeVariance;
        }

        /** @return the (observed or estimated) value of the given trait at this node */
        public Contrastable getTraitValue(int traitIndex) {
            return traitValue[traitIndex];
        }

        /** @return the number of traits carried by this node */
        public int getTraitCount() {
            return traitValue.length;
        }

        /**
         * Recursively calculates the contrast information for the continuous trait nodes.
         * Tips are left unchanged: zero contrasts, zero variances and their observed values.
         *
         * @param kappa exponent applied to each branch length when computing branch variances
         */
        private void calculateContrasts(double kappa) {
            if (isExternal()) {
                return;
            }
            ContrastedTraitNode left = (ContrastedTraitNode) getChild(0);
            ContrastedTraitNode right = (ContrastedTraitNode) getChild(1);
            left.calculateContrasts(kappa);
            right.calculateContrasts(kappa);

            // Each child's variance is its own estimate variance plus (branch length)^kappa.
            double leftNodeBranchVariance = left.nodeVariance + Math.pow(getHeight() - left.getHeight(), kappa);
            double rightNodeBranchVariance = right.nodeVariance + Math.pow(getHeight() - right.getHeight(), kappa);

            contrastVariance = leftNodeBranchVariance + rightNodeBranchVariance;
            // Variance of this node's estimate: the two child variances combined in parallel.
            nodeVariance = (leftNodeBranchVariance * rightNodeBranchVariance) / (leftNodeBranchVariance + rightNodeBranchVariance);

            double invVarLeft = 1.0 / leftNodeBranchVariance;
            double invVarRight = 1.0 / rightNodeBranchVariance;

            for (int i = 0; i < getTraitCount(); i++) {
                // The contrast is the difference between the children (left - right for scalars).
                contrast[i] = left.traitValue[i].getDifference(right.traitValue[i]);
                // This is an inverse-variance weighted mean of the children. For scalar traits
                // it is (invVarLeft * left + invVarRight * right) / (invVarLeft + invVarRight).
                traitValue[i] = left.traitValue[i].getWeightedMean(invVarLeft, left.traitValue[i], invVarRight, right.traitValue[i]);
                tree.setNodeAttribute(node, traitNames[i], traitValue[i]);
            }
        }

        // the contrast for each trait at this node (all zero at tips)
        private double[] contrast;
        // the contrast variance at this node (zero at tips)
        private double contrastVariance;
        // the trait value at this node: observed at tips, weighted mean at internal nodes
        private Contrastable[] traitValue;
        // the variance of the trait estimate at this node (zero at tips)
        private double nodeVariance;
        // the original tree, which receives the per-node trait attributes
        private MutableTree tree;
        // the node of the original tree that this node copies
        private NodeRef node;
        // attribute names, one per trait
        private String[] traitNames;
    }

    /**
     * Small demonstration on a four-taxon tree with two traits, run jointly and then separately.
     */
    public static void main(String[] args) throws Exception {
        String testTree = "((A:1, B:1):1,(C:1, D:1):1);";
        NewickImporter newickImporter = new NewickImporter(new StringReader(testTree));
        MutableTree tree = (MutableTree) newickImporter.importTree(null);

        tree.setTaxonAttribute(0, "U1", new Continuous(1.10));
        tree.setTaxonAttribute(1, "U1", new Continuous(1.95));
        tree.setTaxonAttribute(2, "U1", new Continuous(3.15));
        tree.setTaxonAttribute(3, "U1", new Continuous(4.39));
        tree.setTaxonAttribute(0, "U2", new Continuous(5.2));
        tree.setTaxonAttribute(1, "U2", new Continuous(3.8));
        tree.setTaxonAttribute(2, "U2", new Continuous(3.1));
        tree.setTaxonAttribute(3, "U2", new Continuous(1.95));

        ContinuousTraitLikelihood ctLikelihood = new ContinuousTraitLikelihood();
        Contrastable[] mles = new Contrastable[2];
        double logL = ctLikelihood.calculateLikelihood(tree, new String[] {"U1", "U2"}, mles, 1.0);
        System.out.println("logL = " + logL);
        System.out.println("mle(trait1) = " + mles[0]);
        System.out.println("mle(trait2) = " + mles[1]);

        Contrastable[] mle = new Contrastable[1];
        System.out.println("logL (trait1) = " + ctLikelihood.calculateLikelihood(tree, new String[] {"U1"}, mle, 1.0));
        System.out.println("mle(trait1) = " + mle[0]);
        System.out.println("logL (trait2) = " + ctLikelihood.calculateLikelihood(tree, new String[] {"U2"}, mle, 1.0));
        System.out.println("mle(trait2) = " + mle[0]);
    }
}