/* * RandomLocalClockModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.branchratemodel; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; import dr.evomodel.tree.TreeParameterModel; import dr.evomodel.tree.randomlocalmodel.RandomLocalTreeVariable; import dr.evomodelxml.branchratemodel.RandomLocalClockModelParser; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.util.Author; import dr.util.Citable; import dr.util.Citation; import java.util.Collections; import java.util.List; import java.util.logging.Logger; /** * A model of rates evolving along a tree, such that at each node the rate may change or not depending on an * indicator which chooses whether the parent rate is inherited or a new rate begins on the branch above the node. * This model is implemented using stochastic variable selection. * * @author Alexei Drummond * @author Andrew Rambaut * @author Marc A. Suchard * @version $Id: DiscretizedBranchRates.java,v 1.11 2006/01/09 17:44:30 rambaut Exp $ */ public class RandomLocalClockModel extends AbstractBranchRateModel implements RandomLocalTreeVariable, Citable { public RandomLocalClockModel(TreeModel treeModel, Parameter meanRateParameter, Parameter rateIndicatorParameter, Parameter ratesParameter, boolean ratesAreMultipliers, double threshold) { super(RandomLocalClockModelParser.LOCAL_BRANCH_RATES); this.ratesAreMultipliers = ratesAreMultipliers; indicators = new TreeParameterModel(treeModel, rateIndicatorParameter, false); rates = new TreeParameterModel(treeModel, ratesParameter, false); if (Double.isNaN(threshold)) { // NaN == binary values rateIndicatorParameter.addBounds(new Parameter.DefaultBounds(1, 0, rateIndicatorParameter.getDimension())); this.threshold = 0.5; for (int i = 0; i < rateIndicatorParameter.getDimension(); ++i) { rateIndicatorParameter.setParameterValue(i, 0.0); } } else { rateIndicatorParameter.addBounds(new Parameter.DefaultBounds(Double.MAX_VALUE, -Double.MAX_VALUE, rateIndicatorParameter.getDimension())); this.threshold = threshold; } ratesParameter.addBounds(new Parameter.DefaultBounds(Double.MAX_VALUE, 0, ratesParameter.getDimension())); for (int i = 0; i < rateIndicatorParameter.getDimension(); i++) { ratesParameter.setParameterValue(i, 1.0); } this.meanRateParameter = meanRateParameter; addModel(treeModel); this.treeModel = treeModel; addModel(indicators); addModel(rates); if (meanRateParameter != null) addVariable(meanRateParameter); unscaledBranchRates = new double[treeModel.getNodeCount()]; Logger.getLogger("dr.evomodel").info(" indicator parameter name is '" + rateIndicatorParameter.getId() + "' with threshold = " + threshold); recalculateScaleFactor(); } /** * @param tree the tree * @param node the node to retrieve the variable of * @return the raw real-valued variable at this node */ public final double getVariable(Tree tree, NodeRef node) { return rates.getNodeValue(tree, node); } /** * @param tree the tree * @param node the node * @return true of the variable at this node is included in function, thus representing a change in the * function looking down the tree. */ public final boolean isVariableSelected(Tree tree, NodeRef node) { return indicators.getNodeValue(tree, node) > threshold; } public void handleModelChangedEvent(Model model, Object object, int index) { recalculationNeeded = true; fireModelChanged(); } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { recalculationNeeded = true; fireModelChanged(); } protected void storeState() { } protected void restoreState() { recalculationNeeded = true; // recalculateScaleFactor(); // Careful about change here; seems fine ... but who knows without unit-tests } protected void acceptState() { } public double getBranchRate(final Tree tree, final NodeRef node) { if (recalculationNeeded) { recalculateScaleFactor(); recalculationNeeded = false; } return unscaledBranchRates[node.getNumber()] * scaleFactor; } private void calculateUnscaledBranchRates(TreeModel tree) { cubr(tree, tree.getRoot(), 1.0); } /** * This is a recursive function that does the work of * calculating the unscaled branch rates across the tree * taking into account the indicator variables. * * @param tree the tree * @param node the node * @param rate the rate of the parent node */ private void cubr(TreeModel tree, NodeRef node, double rate) { int nodeNumber = node.getNumber(); if (!tree.isRoot(node)) { if (isVariableSelected(tree, node)) { if (ratesAreMultipliers) { rate *= getVariable(tree, node); } else { rate = getVariable(tree, node); } } } unscaledBranchRates[nodeNumber] = rate; int childCount = tree.getChildCount(node); for (int i = 0; i < childCount; i++) { cubr(tree, tree.getChild(node, i), rate); } } private void recalculateScaleFactor() { calculateUnscaledBranchRates(treeModel); double timeTotal = 0.0; double branchTotal = 0.0; for (int i = 0; i < treeModel.getNodeCount(); i++) { NodeRef node = treeModel.getNode(i); if (!treeModel.isRoot(node)) { double branchInTime = treeModel.getNodeHeight(treeModel.getParent(node)) - treeModel.getNodeHeight(node); double branchLength = branchInTime * unscaledBranchRates[node.getNumber()]; timeTotal += branchInTime; branchTotal += branchLength; } } scaleFactor = timeTotal / branchTotal; if (meanRateParameter != null) scaleFactor *= meanRateParameter.getParameterValue(0); } // AR - as TreeParameterModels are now loggable, the indicator parameter should be logged // directly. // private static String[] attributeLabel = {"changed"}; // // public String[] getNodeAttributeLabel() { // return attributeLabel; // } // // public String[] getAttributeForNode(Tree tree, NodeRef node) { // // if (tree.isRoot(node)) { // return new String[]{"false"}; // } // // return new String[]{(isVariableSelected((TreeModel) tree, node) ? "true" : "false")}; // } // the scale factor necessary to maintain the mean rate private double scaleFactor; // the tree model private TreeModel treeModel; // true if the rate variables are treated as relative // to the parent rate rather than absolute rates private boolean ratesAreMultipliers = false; // the unscaled rates of each branch, taking into account the indicators private double[] unscaledBranchRates; // the mean rate across all the tree, if null then mean rate is scaled to 1.0 private Parameter meanRateParameter; private TreeParameterModel indicators; private TreeParameterModel rates; boolean recalculationNeeded = true; private final double threshold; @Override public Citation.Category getCategory() { return Citation.Category.MOLECULAR_CLOCK; } @Override public String getDescription() { return "Local clock model"; } @Override public List<Citation> getCitations() { return Collections.singletonList(CITATION); } public static Citation CITATION = new Citation( new Author[]{ new Author("AJ", "Drummond"), new Author("MA", "Suchard") }, "Bayesian random local clocks, or one rate to rule them all", 2010, "BMC Biology", "8: 114", "10.1186/1741-7007-8-114" ); }