/*
* DiscreteTraitBranchRateModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.branchratemodel;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.parsimony.FitchParsimony;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.util.TaxonList;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;

import java.util.Arrays;
/**
 * This Branch Rate Model takes an ancestral state likelihood and
 * gives the rate for each branch of the tree based on the ancestor
 * state and child state or based on the Markov Rewards (i.e., time
 * spent in each state).
 * <p>
 * For every branch the time spent in each discrete state is obtained in one of three ways:
 * <ul>
 * <li>directly from Markov-reward dwell-time traits;</li>
 * <li>approximately, by giving half the branch to the state sampled at each end;</li>
 * <li>approximately, by giving half the branch to the Fitch parsimony state set at each end.</li>
 * </ul>
 * These times are normalized into weights summing to one. The branch rate is the
 * weighted mean of the per-state rates.
 *
 * @author Alexei Drummond
 * @author Marc Suchard
 * @author Andrew Rambaut
 * @author Trevor Bedford
 */
public class DiscreteTraitBranchRateModel extends AbstractBranchRateModel {

    /** When true, branch rates are cached per node and only recomputed after invalidation. */
    private static final boolean CACHING_RATES = true;

    /** How the time spent in each state along a branch is obtained. */
    enum Mode {
        /** States sampled at the nodes; each end of the branch contributes half its length. */
        NODE_STATES,
        /** Per-state dwell times supplied by Markov-reward trait(s). */
        MARKOV_JUMP_PROCESS,
        /** Declared but not selected by any constructor in this class. */
        MARKOV_JUMP_COUNT,
        /** Fitch parsimony state sets; each end of the branch contributes half its length. */
        PARSIMONY
    }

    public static final String DISCRETE_TRAIT_BRANCH_RATE_MODEL = "discreteTraitRateModel";

    /** Single trait giving node states ("states") or per-state dwell times; null when unused. */
    protected TreeTrait trait = null;

    /** Per-state rates, or (with relative rates) a single overall rate at index 0. */
    private final Parameter rateParameter;
    /** Optional per-state multipliers of {@code rateParameter[0]}. */
    private final Parameter relativeRatesParameter;
    /** Optional per-state factors applied on top of the relative rates. */
    private final Parameter indicatorParameter;

    /** Index into the {@code int[]} returned by the "states" trait (NODE_STATES mode). */
    protected int traitIndex;

    // Per-node rate cache (indexed by node number) and its stored copy for MCMC restore.
    private double[] rates;
    private double[] storedRates;
    private boolean[] rateKnown;
    private boolean[] storedRateKnown;

    /** One dwell-time trait per state (Markov-reward constructor only), otherwise null. */
    private TreeTrait[] traits;
    /** Parsimony reconstruction (PARSIMONY mode only), otherwise null. */
    private FitchParsimony fitchParsimony;
    /** True when the parsimony reconstruction must be rebuilt before its next use. */
    private boolean treeChanged = true;
    private Mode mode;
    /** Data type of the trait; null for the parsimony and Markov-reward-array constructors. */
    private DataType dataType;

    /**
     * Creates the (crude) parsimony reconstruction form of this class.
     *
     * @param treeModel      the tree whose branch rates are modelled
     * @param patternList    the discrete patterns to reconstruct; must have the same taxa as the tree
     * @param traitIndex     stored in {@link #traitIndex}; not read in parsimony mode
     * @param ratesParameter one rate per state; resized to the pattern list's state count
     * @throws IllegalArgumentException if the tree and the pattern list have different taxa
     */
    public DiscreteTraitBranchRateModel(TreeModel treeModel, PatternList patternList, int traitIndex, Parameter ratesParameter) {
        this(treeModel, traitIndex, ratesParameter, null, null);
        if (!TaxonList.Utils.getTaxonListIdSet(treeModel).equals(TaxonList.Utils.getTaxonListIdSet(patternList))) {
            throw new IllegalArgumentException("Tree model and pattern list must have the same list of taxa!");
        }
        ratesParameter.setDimension(patternList.getDataType().getStateCount());
        fitchParsimony = new FitchParsimony(patternList, false);
        mode = Mode.PARSIMONY;
    }

    /**
     * Creates a model for a node-sampled discrete trait (or a single dwell-time trait).
     * Each state's rate is {@code rateParameter[0] * relativeRatesParameter[state]}. If an
     * indicator parameter is given, the rate is also multiplied by
     * {@code indicatorParameter[state]}.
     *
     * @param traitProvider          provider of the trait; listened to if it is a {@link Model}
     * @param dataType               data type of the trait; defines the number of states
     * @param treeModel              the tree whose branch rates are modelled
     * @param trait                  "states" trait (node states) or a per-state dwell-time trait
     * @param traitIndex             index into the {@code int[]} node states
     * @param rateParameter          overall rate; only element 0 is read
     * @param relativeRatesParameter per-state relative rates; resized to the state count
     * @param indicatorParameter     optional per-state multipliers (may be null)
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, DataType dataType, TreeModel treeModel,
                                        TreeTrait trait, int traitIndex, Parameter rateParameter,
                                        Parameter relativeRatesParameter, Parameter indicatorParameter) {
        this(treeModel, traitIndex, rateParameter, relativeRatesParameter, indicatorParameter);
        initializeFromTrait(traitProvider, dataType, trait, relativeRatesParameter);
    }

    /**
     * Creates a model for a node-sampled discrete trait (or a single dwell-time trait) with
     * one absolute rate per state.
     *
     * @param traitProvider  provider of the trait; listened to if it is a {@link Model}
     * @param dataType       data type of the trait; defines the number of states
     * @param treeModel      the tree whose branch rates are modelled
     * @param trait          "states" trait (node states) or a per-state dwell-time trait
     * @param traitIndex     index into the {@code int[]} node states
     * @param ratesParameter one rate per state; resized to the state count
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, DataType dataType, TreeModel treeModel,
                                        TreeTrait trait, int traitIndex, Parameter ratesParameter) {
        this(treeModel, traitIndex, ratesParameter, null, null);
        initializeFromTrait(traitProvider, dataType, trait, ratesParameter);
    }

    /**
     * Creates a Markov-reward model from one trait per state. Each trait gives the time spent
     * in its state on the branch above a node; the first element of the trait's value is used.
     *
     * @param traitProvider  provider of the traits; listened to if it is a {@link Model}
     * @param traits         one dwell-time trait per state
     * @param treeModel      the tree whose branch rates are modelled
     * @param ratesParameter one rate per state; resized to {@code traits.length}
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, TreeTrait[] traits, TreeModel treeModel, Parameter ratesParameter) {
        this(treeModel, 0, ratesParameter, null, null);
        this.traits = traits;
        mode = Mode.MARKOV_JUMP_PROCESS;
        ratesParameter.setDimension(traits.length);
        if (traitProvider instanceof Model) {
            addModel((Model) traitProvider);
        }
    }

    /** Shared initialization: registers the tree and parameters and allocates the rate cache. */
    private DiscreteTraitBranchRateModel(TreeModel treeModel, int traitIndex,
                                         Parameter rateParameter, Parameter relativeRatesParameter, Parameter indicatorParameter) {
        super(DISCRETE_TRAIT_BRANCH_RATE_MODEL);
        addModel(treeModel);
        this.traitIndex = traitIndex;

        this.rateParameter = rateParameter;
        addVariable(rateParameter);

        this.relativeRatesParameter = relativeRatesParameter;
        if (relativeRatesParameter != null) {
            addVariable(relativeRatesParameter);
        }

        this.indicatorParameter = indicatorParameter;
        if (indicatorParameter != null) {
            addVariable(indicatorParameter);
        }

        final int nodeCount = treeModel.getNodeCount();
        rates = new double[nodeCount];
        storedRates = new double[nodeCount];
        rateKnown = new boolean[nodeCount];
        storedRateKnown = new boolean[nodeCount];
    }

    /**
     * Sets up the single-trait constructors.
     * <ul>
     * <li>Stores the trait and data type.</li>
     * <li>Chooses the mode from the trait name.</li>
     * <li>Resizes the per-state parameter to the state count.</li>
     * <li>Listens to the provider and the trait when they are models.</li>
     * </ul>
     */
    private void initializeFromTrait(TreeTraitProvider traitProvider, DataType dataType, TreeTrait trait,
                                     Parameter perStateParameter) {
        this.trait = trait;
        this.dataType = dataType;

        if (trait.getTraitName().equals("states")) {
            // The trait is one or more discrete traits reconstructed at the nodes.
            mode = Mode.NODE_STATES;
        } else {
            // The trait itself holds the dwell times of each state on the branch above the node.
            mode = Mode.MARKOV_JUMP_PROCESS;
        }

        perStateParameter.setDimension(dataType.getStateCount());

        if (traitProvider instanceof Model) {
            addModel((Model) traitProvider);
        }
        if (trait instanceof Model) {
            addModel((Model) trait); // MAS: Does this ever occur?
        }
    }

    /** Marks every cached branch rate as needing recomputation. */
    private void invalidateAllRates() {
        Arrays.fill(rateKnown, false);
    }

    /**
     * Responds to any change in a sub-model (the tree, trait provider or trait).
     * It invalidates all rates and the parsimony reconstruction, then notifies listeners.
     */
    public void handleModelChangedEvent(Model model, Object object, int index) {
        invalidateAllRates();
        treeChanged = true;
        fireModelChanged();
    }

    /** Responds to a change in any rate parameter: invalidates all rates and notifies listeners. */
    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        invalidateAllRates();
        fireModelChanged();
    }

    /** Saves the cached rates together with the flags saying which of them are valid. */
    protected void storeState() {
        if (CACHING_RATES) {
            System.arraycopy(rates, 0, storedRates, 0, rates.length);
            System.arraycopy(rateKnown, 0, storedRateKnown, 0, rateKnown.length);
        }
    }

    /** Restores the cached rates saved by {@link #storeState()}. */
    protected void restoreState() {
        if (CACHING_RATES) {
            double[] tmp = rates;
            rates = storedRates;
            storedRates = tmp;
            // Only rates that had actually been computed when the state was stored are valid.
            System.arraycopy(storedRateKnown, 0, rateKnown, 0, rateKnown.length);
        } else {
            invalidateAllRates();
        }
        // The parsimony reconstruction is not part of the stored state. It may have been
        // rebuilt for the rejected tree, so rebuild it before it is next used.
        treeChanged = true;
    }

    protected void acceptState() {
        // nothing to do
    }

    /**
     * Returns the number of discrete states whose rates are combined on each branch.
     *
     * @return the state count for the current mode, or 0 if no mode applies
     */
    protected int getStateCount() {
        if (mode == Mode.NODE_STATES) {
            return dataType.getStateCount();
        } else if (mode == Mode.MARKOV_JUMP_PROCESS) {
            // The per-state-trait constructor has no data type; its state count is the number of traits.
            return traits != null ? traits.length : dataType.getStateCount();
        } else if (mode == Mode.PARSIMONY) {
            return fitchParsimony.getPatterns().getStateCount();
        }
        return 0;
    }

    /**
     * Returns the rate of the branch above {@code node}, caching it per node when
     * {@link #CACHING_RATES} is enabled.
     */
    public double getBranchRate(final Tree tree, final NodeRef node) {
        if (!CACHING_RATES) {
            return getRawBranchRate(tree, node);
        }
        final int nodeNumber = node.getNumber();
        if (!rateKnown[nodeNumber]) {
            rates[nodeNumber] = getRawBranchRate(tree, node);
            rateKnown[nodeNumber] = true;
        }
        return rates[nodeNumber];
    }

    /**
     * Computes the weighted mean rate of the branch above {@code node}. Each state's weight is
     * the fraction of the branch spent in that state. The state's rate is one of:
     * <ul>
     * <li>{@code absRate * relativeRate[i]}, optionally times {@code indicator[i]};</li>
     * <li>{@code rateParameter[i]} when there are no relative rates.</li>
     * </ul>
     */
    protected double getRawBranchRate(final Tree tree, final NodeRef node) {
        final int stateCount = getStateCount();
        final double[] processValues = getProcessValues(tree, node);

        double totalTime = 0.0;
        for (int i = 0; i < stateCount; i++) {
            totalTime += processValues[i];
        }
        // NOTE(review): a zero-length branch gives totalTime == 0 and therefore NaN weights.
        final double[] branchWeights = new double[stateCount];
        for (int i = 0; i < stateCount; i++) {
            branchWeights[i] = processValues[i] / totalTime;
        }

        double rate = 0.0;
        if (relativeRatesParameter != null) {
            final double absRate = rateParameter.getParameterValue(0);
            for (int i = 0; i < stateCount; i++) {
                double stateRate = absRate * relativeRatesParameter.getParameterValue(i) * branchWeights[i];
                if (indicatorParameter != null) {
                    stateRate *= indicatorParameter.getParameterValue(i);
                }
                rate += stateRate;
            }
        } else {
            // One absolute rate per state, averaged with the same weights as above.
            for (int i = 0; i < stateCount; i++) {
                rate += rateParameter.getParameterValue(i) * branchWeights[i];
            }
        }
        return rate;
    }

    /**
     * Computes the time spent in each discrete state on the branch above a node.
     *
     * @param tree the tree
     * @param node the node below the branch
     * @return an array of the total amount of time spent in each of the discrete states along
     *         the branch above the given node
     * @throws IllegalStateException if a single dwell-time trait does not yield a {@code double[]}
     */
    private double[] getProcessValues(final Tree tree, final NodeRef node) {
        final int stateCount = getStateCount();
        final double branchTime = tree.getBranchLength(node);
        final double[] processValues = new double[stateCount];

        if (mode == Mode.MARKOV_JUMP_PROCESS) {
            if (traits != null) {
                // One Markov-reward trait per state, each holding that state's dwell time.
                for (int i = 0; i < stateCount; i++) {
                    processValues[i] = ((TreeTrait.DA) traits[i]).getTrait(tree, node)[0];
                }
            } else {
                // A single trait holding the dwell time of every state on the branch above the node.
                final Object value = trait.getTrait(tree, node);
                if (!(value instanceof double[])) {
                    throw new IllegalStateException("Trait '" + trait.getTraitName()
                            + "' does not provide per-state dwell times (double[]) for node " + node.getNumber());
                }
                System.arraycopy((double[]) value, 0, processValues, 0, stateCount);
            }
        } else if (mode == Mode.PARSIMONY) {
            // An approximation to dwell times using parsimony, assuming the state changes at the
            // midpoint of the branch. Each end gets half the branch, split equally among that
            // end's equally parsimonious states, so the values add up to the branch length.
            if (treeChanged) {
                fitchParsimony.initialize(tree);
                treeChanged = false;
            }
            final int[] states = fitchParsimony.getStates(tree, node);
            final int[] parentStates = fitchParsimony.getStates(tree, tree.getParent(node));
            final double halfBranch = branchTime / 2;
            for (int state : states) {
                processValues[state] += halfBranch / states.length;
            }
            for (int state : parentStates) {
                processValues[state] += halfBranch / parentStates.length;
            }
        } else if (mode == Mode.NODE_STATES) {
            // States are sampled, so there is exactly one state at each end of the branch.
            final int state = ((int[]) trait.getTrait(tree, node))[traitIndex];
            processValues[state] += branchTime / 2;
            final NodeRef parent = tree.getParent(node);
            final int parentState = ((int[]) trait.getTrait(tree, parent))[traitIndex];
            processValues[parentState] += branchTime / 2;
        }
        return processValues;
    }
}