/* * RateSmoothingTree.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evolution.tree; import dr.math.ConjugateDirectionSearch; import dr.math.MultivariateFunction; import dr.math.MultivariateMinimum; /** * An extension of simple tree that provides non-parametric rate smoothed * node heights in the manner described by Mike Sanderson. This takes a * tree with branch lengths as data and uses multidimensional optimization * to find the set of node heights that requires the smallest differences * in rates across nodes. * * @version $Id: RateSmoothingTree.java,v 1.8 2005/05/24 20:25:56 rambaut Exp $ * * @author Andrew Rambaut * @author Alexei Drummond */ public class RateSmoothingTree extends SimpleTree { /** * constructor */ public RateSmoothingTree(Tree sourceTree) { super(sourceTree); this.sourceTree = sourceTree; this.mu = 1.0; this.optimizeMu = false; } /** * constructor with a specific mutation rate */ public RateSmoothingTree(Tree sourceTree, double mu) { this.sourceTree = sourceTree; this.mu = mu; this.optimizeMu = false; } public double getMu() { return mu; } public void smoothRates() { nodeCount = getInternalNodeCount(); int argumentCount = nodeCount; if (optimizeMu) { argumentCount++; muIndex = nodeCount; } MultivariateMinimum optimizer = new ConjugateDirectionSearch(); nodeValues = new double[nodeCount]; double[] xvec = new double[argumentCount]; for (int i = 0; i < nodeCount; i++) { xvec[i] = 1.0; } if (optimizeMu) { xvec[muIndex] = mu; } optimizer.optimize(nonParametricRateSmoothing, xvec, 1E-8, 1E-8); } public double getSumOfRates() { double[] score = new double[] { 0.0 }; NodeRef root = getRoot(); if (getChildCount(root) != 2) { throw new IllegalArgumentException("The tree must have a bifurcating root node"); } sumDist = 0; sumTime = 0; double rate1 = sumScoreAtNode(getChild(root, 0), score); double rate2 = sumScoreAtNode(getChild(root, 1), score); mu = sumDist / sumTime; double diff = rate2 - rate1; score[0] += diff * diff; return score[0]; } // // Private stuff // private double sumScoreAtNode(NodeRef node, double[] score) { double rate0 = getRateAtNode(node); if (!isExternal(node)) { for (int i = 0; i < getChildCount(node); i++) { double rate1 = sumScoreAtNode(getChild(node, i), score); double diff = rate0 - rate1; score[0] += diff * diff; } } return rate0; } /** * returns the rate on the branch to the node above. */ private double getRateAtNode(NodeRef node) { double time = getNodeHeight(getParent(node)) - getNodeHeight(node); double dist = sourceTree.getBranchLength(sourceTree.getNode(node.getNumber())); double rate; if (time == 0.0) { if (dist == 0.0) { rate = 1.0; } else { rate = Double.MIN_VALUE; } } else { rate = (dist / time); } sumDist += dist; sumTime += time; setNodeRate(node, rate); return rate; } private double setNodeHeightsFromValues(NodeRef node) { if (!isExternal(node)) { double maxHeight = setNodeHeightsFromValues(getChild(node, 0)); for (int i = 1; i < getChildCount(node); i++) { double height = setNodeHeightsFromValues(getChild(node, i)); if (height > maxHeight) maxHeight = height; } setNodeHeight(node, maxHeight + nodeValues[node.getNumber() - getExternalNodeCount()]); } return getNodeHeight(node); } private MultivariateFunction nonParametricRateSmoothing = new MultivariateFunction() { public double evaluate(double[] argument) { for (int i = 0; i < getInternalNodeCount(); i++) { nodeValues[i] = argument[i]; } setNodeHeightsFromValues(getRoot()); if (optimizeMu) { mu = argument[muIndex]; } double score = getSumOfRates(); return score; } public int getNumArguments() { if (optimizeMu) { return getInternalNodeCount() + 1; } else { return getInternalNodeCount(); } } public double getLowerBound(int n) { if (optimizeMu && n == muIndex) { return Double.MIN_VALUE; } else { return 0.0; } } public double getUpperBound(int n) { if (optimizeMu && n == muIndex) { return Double.MAX_VALUE; } else { return Double.MAX_VALUE; } } }; private int nodeCount; private double[] nodeValues; private Tree sourceTree; private double mu, sumDist, sumTime; private boolean optimizeMu; private int muIndex; }