/*
* ARGDiscretizedBranchRates.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.arg.branchratemodel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.*;
import java.util.logging.Logger;
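/**
 * Discretized relaxed-clock branch rate model for an {@link ARGModel}.
 * <p>
 * The rate distribution is split into {@code numRateCategories} equal-probability categories, each
 * represented by the distribution's quantile at the category midpoint. The category index of each
 * non-root branch is read from the tree's node rate; the root has no rate.
 *
 * @author Marc Suchard
 * @author Alexei Drummond
 * @author Andrew Rambaut
 * @version $Id: ARGDiscretizedBranchRates.java,v 1.11 2006/01/09 17:44:30 rambaut Exp $
 */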
public class ARGDiscretizedBranchRates extends AbstractBranchRateModel {
public static final String DISCRETIZED_BRANCH_RATES = "argDiscretizedBranchRates";
public static final String DISTRIBUTION = "distribution";
public static final String NUM_RATE_CATEGORIES = "numRateCategories";
public static final String SINGLE_ROOT_RATE = "singleRootRate";
private ParametricDistributionModel distributionModel;
private ARGModel tree;
// The rate categories of each branch
private Parameter rateCategoryParameter;
// the index of the root node.
private int rootNodeNumber;
private int storedRootNodeNumber;
private final int categoryCount;
private final double step;
private final double[] rates;
private boolean ratesKnown = false;
// private boolean orderKnown = false;
public ARGDiscretizedBranchRates(ARGModel tree,
// Parameter rateCategoryParameter,
int numRateCategories,
ParametricDistributionModel model) {
super(DISCRETIZED_BRANCH_RATES);
this.tree = tree;
// categoryCount = tree.getNodeCount();
categoryCount = numRateCategories;
step = 1.0 / (double) categoryCount;
rates = new double[categoryCount];
this.distributionModel = model;
this.rateCategoryParameter = rateCategoryParameter;
// if (rateCategoryParameter.getDimension() != tree.getNodeCount() -1 ) {
// throw new IllegalArgumentException("The rate category parameter must be of length nodeCount-1");
// }
// if (rateCategoryParameter.getDimension() > tree.getNodeCount() -1 ) {
if (numRateCategories > tree.getNodeCount() - 1) {
throw new IllegalArgumentException("The rate category parameter must be less than the length 2*tipCount-1");
}
// for (int i = 0; i < rateCategoryParameter.getDimension(); i++) {
// rateCategoryParameter.setParameterValue(i, i);
// }
ratesKnown = false;
// orderKnown = false;
addModel(model);
addModel(tree);
// addVariable(rateCategoryParameter);
rootNodeNumber = tree.getRoot().getNumber();
storedRootNodeNumber = rootNodeNumber;
}
public void handleModelChangedEvent(Model model, Object object, int index) {
if (model == distributionModel) {
ratesKnown = false;
} else if (model == tree) {
// orderKnown = false;
}
fireModelChanged();
}
protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
fireModelChanged();
}
protected void storeState() {
storedRootNodeNumber = rootNodeNumber;
}
protected void restoreState() {
ratesKnown = false;
rootNodeNumber = storedRootNodeNumber;
}
protected void acceptState() {
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return DISCRETIZED_BRANCH_RATES;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
ARGModel tree = (ARGModel) xo.getChild(ARGModel.class);
ParametricDistributionModel distributionModel = (ParametricDistributionModel) xo.getChild(DISTRIBUTION);
// Parameter rateCategoryParameter = (Parameter)xo.getSocketChild(RATE_CATEGORIES);
int numRateCategories = xo.getIntegerAttribute(NUM_RATE_CATEGORIES);
Logger.getLogger("dr.evomodel").info("Using discretized relaxed clock model.");
Logger.getLogger("dr.evomodel").info(" parametric model = " + distributionModel.getModelName());
Logger.getLogger("dr.evomodel").info(" rate categories = " + numRateCategories);
if (xo.hasAttribute(SINGLE_ROOT_RATE)) {
//singleRootRate = xo.getBooleanAttribute(SINGLE_ROOT_RATE);
Logger.getLogger("dr.evomodel").warning(" WARNING: single root rate is not implemented!");
}
return new ARGDiscretizedBranchRates(tree, numRateCategories, distributionModel);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return
"This element returns an discretized relaxed clock model." +
"The branch rates are drawn from a discretized parametric distribution.";
}
public Class getReturnType() {
return ARGDiscretizedBranchRates.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
// AttributeRule.newBooleanRule(SINGLE_ROOT_RATE, true, "Whether only a single rate should be used for the two children branches of the root"),
new ElementRule(ARGModel.class),
new ElementRule(DISTRIBUTION, ParametricDistributionModel.class, "The distribution model for rates among branches", false),
// new ElementRule(RATE_CATEGORIES, Parameter.class, "The rate categories parameter", false),
AttributeRule.newIntegerRule(NUM_RATE_CATEGORIES),
};
};
public double getBranchRate(Tree tree, NodeRef node) {
if (tree.isRoot(node)) {
throw new IllegalArgumentException("root node doesn't have a rate!");
}
if (!ratesKnown) {
setupRates();
ratesKnown = true;
}
// if (!orderKnown) {
// shuffleIndices();
// orderKnown = true;
// }
// int nodeNumber = node.getNumber();
// int rateCategory = 0;
// if (nodeNumber < rootNodeNumber) {
// rateCategory = (int)Math.round(rateCategoryParameter.getParameterValue(nodeNumber));
// } else if (nodeNumber > rootNodeNumber) {
// rateCategory = (int)Math.round(rateCategoryParameter.getParameterValue(nodeNumber-1));
// } else {
// throw new IllegalArgumentException("INTERNAL ERROR! node with number " + rootNodeNumber + " should be the root node.");
// }
int rateCategory = (int) tree.getNodeRate(node);
// System.err.println("Node "+nodeNumber+" has rate category "+rateCategory);
// System.err.println("rate = "+rates[rateCategory]+" : "+rateCategory);
return rates[rateCategory];
}
/**
* Calculates the actual rates corresponding to the category indices.
*/
private void setupRates() {
// System.err.println("Setting up rates:");
// System.err.println("catCount = "+categoryCount);
double z = step / 2.0;
for (int i = 0; i < categoryCount; i++) {
rates[i] = distributionModel.quantile(z);
z += step;
// System.err.println(rates[i]);
}
}
private void shuffleIndices() {
int newRootNodeNumber = tree.getRoot().getNumber();
//if (newRootNodeNumber != rootNodeNumber) {
// System.out.println("old root node number =" + rootNodeNumber);
// System.out.println("new root node number =" + newRootNodeNumber);
//}
if (rootNodeNumber > newRootNodeNumber) {
//for (int i = 0; i < rateCategoryParameter.getDimension(); i++) {
// System.out.print((int)Math.round(rateCategoryParameter.getParameterValue(i)) + "\t");
//}
//System.out.println();
int oldRateIndex = (int) Math.round(
rateCategoryParameter.getParameterValue(newRootNodeNumber));
int end = Math.min(rateCategoryParameter.getDimension() - 1, rootNodeNumber);
for (int i = newRootNodeNumber; i < end; i++) {
rateCategoryParameter.setParameterValue(i, rateCategoryParameter.getParameterValue(i + 1));
}
rateCategoryParameter.setParameterValue(end, oldRateIndex);
//for (int i = 0; i < rateCategoryParameter.getDimension(); i++) {
// System.out.print((int)Math.round(rateCategoryParameter.getParameterValue(i)) + "\t");
//}
//System.out.println();
} else if (rootNodeNumber < newRootNodeNumber) {
//System.out.println("old root node number =" + rootNodeNumber);
//System.out.println("new root node number =" + newRootNodeNumber);
//for (int i = 0; i < rateCategoryParameter.getDimension(); i++) {
// System.out.print((int)Math.round(rateCategoryParameter.getParameterValue(i)) + "\t");
//}
//System.out.println();
int end = Math.min(rateCategoryParameter.getDimension() - 1, newRootNodeNumber);
int oldRateIndex = (int) Math.round(
rateCategoryParameter.getParameterValue(end));
for (int i = end; i > rootNodeNumber; i--) {
rateCategoryParameter.setParameterValue(i, rateCategoryParameter.getParameterValue(i - 1));
}
rateCategoryParameter.setParameterValue(rootNodeNumber, oldRateIndex);
//for (int i = 0; i < rateCategoryParameter.getDimension(); i++) {
// System.out.print((int)Math.round(rateCategoryParameter.getParameterValue(i)) + "\t");
//}
//System.out.println();
}
rootNodeNumber = newRootNodeNumber;
}
}