/*
* SericolaLatentStateBranchRateModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.branchratemodel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.inference.markovjumps.MarkovReward;
import dr.inference.markovjumps.TwoStateOccupancyMarkovReward;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.ArrayList;
import java.util.List;
/**
* SericolaLatentStateBranchRateModel
*
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id$
* <p/>
* $HeadURL$
* <p/>
* $LastChangedBy$
* $LastChangedDate$
* $LastChangedRevision$
*/
public class SericolaLatentStateBranchRateModel extends AbstractModelLikelihood implements BranchRateModel {
public static final String LATENT_STATE_BRANCH_RATE_MODEL = "latentStateBranchRateModel";
public static final boolean USE_CACHING = true;
// seed 666, caching off: 204.69 seconds for 20000 states
// state 20000 -5510.2520
// 85.7% 5202 + 6 dr.inference.markovjumps.SericolaSeriesMarkovReward.accumulatePdf
// seed 666, caching on: 119.43 seconds for 20000 states
// state 20000 -5510.2520
// 83.4% 3156 + 4 dr.inference.markovjumps.SericolaSeriesMarkovReward.accumulatePdf
private final TreeModel tree;
private final BranchRateModel nonLatentRateModel;
private final Parameter latentTransitionRateParameter;
private final Parameter latentTransitionFrequencyParameter;
private final TreeParameterModel latentStateProportions;
private final Parameter latentStateProportionParameter;
private final CountableBranchCategoryProvider branchCategoryProvider;
private MarkovReward series;
private MarkovReward storedSeries;
private boolean likelihoodKnown = false;
private boolean storedLikelihoodKnown;
private double logLikelihood;
private double storedLogLikelihood;
private double[] branchLikelihoods;
private double[] storedbranchLikelihoods;
private boolean[] updateBranch;
private boolean[] storedUpdateBranch;
private boolean[] updateCategory;
private boolean[] storedUpdateCategory;
public SericolaLatentStateBranchRateModel(String name,
TreeModel treeModel,
BranchRateModel nonLatentRateModel,
Parameter latentTransitionRateParameter,
Parameter latentTransitionFrequencyParameter,
Parameter latentStateProportionParameter,
CountableBranchCategoryProvider branchCategoryProvider) {
super(name);
this.tree = treeModel;
addModel(tree);
this.nonLatentRateModel = nonLatentRateModel;
addModel(nonLatentRateModel);
this.latentTransitionRateParameter = latentTransitionRateParameter;
addVariable(latentTransitionRateParameter);
this.latentTransitionFrequencyParameter = latentTransitionFrequencyParameter;
addVariable(latentTransitionFrequencyParameter);
if (branchCategoryProvider == null) {
this.latentStateProportions = new TreeParameterModel(tree, latentStateProportionParameter, false, Intent.BRANCH);
addModel(latentStateProportions);
this.latentStateProportionParameter = null;
this.branchCategoryProvider = null;
} else {
this.latentStateProportions = null;
this.branchCategoryProvider = branchCategoryProvider;
this.latentStateProportionParameter = latentStateProportionParameter;
this.latentStateProportionParameter.setDimension(branchCategoryProvider.getCategoryCount());
if (USE_CACHING) {
updateCategory = new boolean[branchCategoryProvider.getCategoryCount()];
storedUpdateCategory = new boolean[branchCategoryProvider.getCategoryCount()];
setUpdateAllCategories();
}
addVariable(latentStateProportionParameter);
}
branchLikelihoods = new double[tree.getNodeCount()];
if (USE_CACHING) {
updateBranch = new boolean[tree.getNodeCount()];
storedUpdateBranch = new boolean[tree.getNodeCount()];
storedbranchLikelihoods = new double[tree.getNodeCount()];
setUpdateAllBranches();
}
}
public SericolaLatentStateBranchRateModel(Parameter rate, Parameter prop) {
super(LATENT_STATE_BRANCH_RATE_MODEL);
tree = null;
nonLatentRateModel = null;
latentTransitionRateParameter = rate;
latentTransitionFrequencyParameter = prop;
latentStateProportions = null;
this.latentStateProportionParameter = null;
this.branchCategoryProvider = null;
}
private double[] createLatentInfinitesimalMatrix() {
final double rate = latentTransitionRateParameter.getParameterValue(0);
final double prop = latentTransitionFrequencyParameter.getParameterValue(0);
double[] mat = new double[]{
-rate * prop, rate * prop,
rate * (1.0 - prop), -rate * (1.0 - prop)
};
return mat;
}
private static double[] createReward() {
return new double[]{0.0, 1.0};
}
private MarkovReward createSeries() {
// MarkovReward series = new SericolaSeriesMarkovReward(createLatentInfinitesimalMatrix(),
// createReward(), 2);
MarkovReward series = new TwoStateOccupancyMarkovReward(createLatentInfinitesimalMatrix());
return series;
}
@Override
public double getBranchRate(Tree tree, NodeRef node) {
double nonLatentRate = nonLatentRateModel.getBranchRate(tree, node);
double latentProportion = getLatentProportion(tree, node);
return calculateBranchRate(nonLatentRate, latentProportion);
}
public double getLatentProportion(Tree tree, NodeRef node) {
if (latentStateProportions != null) {
return latentStateProportions.getNodeValue(tree, node);
} else {
return latentStateProportionParameter.getParameterValue(branchCategoryProvider.getBranchCategory(tree, node));
}
}
private double calculateBranchRate(double nonLatentRate, double latentProportion) {
return nonLatentRate * (1.0 - latentProportion);
}
@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (model == tree) {
likelihoodKnown = false; // node heights change elapsed times on branches, TODO could cache
if (index == -1) {
setUpdateAllBranches();
} else {
setUpdateBranch(index);
}
} else if (model == nonLatentRateModel) {
// rates will change but the latent proportions haven't so the density is unchanged
} else if (model == latentStateProportions) {
likelihoodKnown = false; // argument of density has changed
if (index == -1) {
setUpdateAllBranches();
} else {
setUpdateBranch(index);
}
}
fireModelChanged();
}
@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if (variable == latentTransitionFrequencyParameter || variable == latentTransitionRateParameter) {
// series computations have changed
series = null;
setUpdateAllBranches();
likelihoodKnown = false;
} else if (variable == latentStateProportionParameter) {
if (index == -1) {
setUpdateAllBranches();
} else {
setUpdateBranchCategory(index);
}
likelihoodKnown = false;
fireModelChanged();
}
}
private void setUpdateBranch(int nodeNumber) {
if (USE_CACHING) {
updateBranch[nodeNumber] = true;
}
}
private void setUpdateAllBranches() {
if (USE_CACHING) {
for (int i = 0; i < updateBranch.length; i++) {
updateBranch[i] = true;
}
}
}
private void clearUpdateAllBranches() {
if (USE_CACHING) {
for (int i = 0; i < updateBranch.length; i++) {
updateBranch[i] = false;
}
}
}
private void setUpdateBranchCategory(int category) {
if (USE_CACHING) {
updateCategory[category] = true;
}
}
private void setUpdateAllCategories() {
if (USE_CACHING) {
for (int i = 0; i < updateCategory.length; i++) {
updateCategory[i] = true;
}
}
}
private void clearAllCategories() {
if (USE_CACHING && updateCategory != null) {
for (int i = 0; i < updateCategory.length; i++) {
updateCategory[i] = false;
}
}
}
@Override
protected void storeState() {
storedSeries = series;
storedLogLikelihood = logLikelihood;
storedLikelihoodKnown = likelihoodKnown;
if (USE_CACHING) {
System.arraycopy(branchLikelihoods, 0, storedbranchLikelihoods, 0, branchLikelihoods.length);
System.arraycopy(updateBranch, 0, storedUpdateBranch, 0, updateBranch.length);
if (updateCategory != null) {
System.arraycopy(updateCategory, 0, storedUpdateCategory, 0, updateCategory.length);
}
}
}
@Override
protected void restoreState() {
series = storedSeries;
logLikelihood = storedLogLikelihood;
likelihoodKnown = storedLikelihoodKnown;
if (USE_CACHING) {
double[] tmp = branchLikelihoods;
branchLikelihoods = storedbranchLikelihoods;
storedbranchLikelihoods = tmp;
boolean[] tmp2 = updateBranch;
updateBranch = storedUpdateBranch;
storedUpdateBranch = tmp2;
boolean[] tmp3 = updateCategory;
updateCategory = storedUpdateCategory;
storedUpdateCategory = tmp3;
}
}
@Override
protected void acceptState() {
}
@Override
public Model getModel() {
return this;
}
@Override
public double getLogLikelihood() {
if (!likelihoodKnown) {
logLikelihood = calculateLogLikelihood();
likelihoodKnown = true;
}
return logLikelihood;
}
private double calculateLogLikelihood() {
double logLike = 0.0;
for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
NodeRef node = tree.getNode(i);
if (node != tree.getRoot()) {
if (updateNeededForNode(tree, node)) {
double branchLength = tree.getBranchLength(node);
double latentProportion = getLatentProportion(tree, node);
assert(latentProportion < 1.0);
double reward = branchLength * latentProportion;
double density = getBranchRewardDensity(reward, branchLength);
branchLikelihoods[node.getNumber()] = Math.log(density);
}
logLike += branchLikelihoods[node.getNumber()];
// TODO More importantly, MH proposals on [0,1] may be missing a Jacobian for which we should adjust.
// TODO This is easy to test and we should do it when sampling appears to work.
}
}
clearUpdateAllBranches();
clearAllCategories();
return logLike;
}
private boolean updateNeededForNode(Tree tree, NodeRef node) {
if (USE_CACHING) {
return (updateCategory != null && updateCategory[branchCategoryProvider.getBranchCategory(tree, node)]) || updateBranch[node.getNumber()];
} else {
return true;
}
}
public double getBranchRewardDensity(double reward, double branchLength) {
if (series == null) {
series = createSeries();
}
int state = 0 * 2 + 0; // just start = end = 0 entry
// Reward is [0,1], and we want to track time in latent state (= 1).
// Therefore all nodes are in state 0
// double joint = series.computePdf(reward, branchLength)[state];
double joint = series.computePdf(reward, branchLength, 0, 0);
double marg = series.computeConditionalProbability(branchLength, 0, 0);
// TODO Overhead in creating double[] could be saved by changing signature to computePdf
return joint / marg; // conditional on ending state.
}
@Override
public void makeDirty() {
likelihoodKnown = false;
series = null;
setUpdateAllBranches();
}
@Override
public String getTraitName() {
return BranchRateModel.RATE;
}
@Override
public Intent getIntent() {
return Intent.BRANCH;
}
@Override
public TreeTrait getTreeTrait(final String key) {
if (key.equals(BranchRateModel.RATE)) {
return this;
} else if (latentStateProportions != null && key.equals(latentStateProportions.getTraitName())) {
return latentStateProportions;
} else if (branchCategoryProvider != null && key.equals(branchCategoryProvider.getTraitName())) {
return branchCategoryProvider;
} else {
throw new IllegalArgumentException("Unrecognised Tree Trait key, " + key);
}
}
@Override
public TreeTrait[] getTreeTraits() {
return new TreeTrait[]{this, latentStateProportions, branchCategoryProvider};
}
@Override
public Class getTraitClass() {
return Double.class;
}
@Override
public boolean getLoggable() {
return true;
}
@Override
public Double getTrait(final Tree tree, final NodeRef node) {
return getBranchRate(tree, node);
}
@Override
public String getTraitString(final Tree tree, final NodeRef node) {
return Double.toString(getBranchRate(tree, node));
}
static class Mode {
double pdf;
double reward;
Mode(double pdf, double reward) {
this.pdf = pdf;
this.reward = reward;
}
}
static Mode findMode(List<Double> values, List<Double> rewards) {
Mode find = new Mode(values.get(0), rewards.get(0));
for (int i = 1; i < values.size(); ++i) {
if (values.get(i) > find.pdf) {
find.pdf = values.get(i);
find.reward = rewards.get(i);
}
}
return find;
}
static double calculateExpectation(List<Double> pdfs, List<Double> rewards) {
double weight = 0.0;
double wsum = 0.0;
for (int i = 0; i < pdfs.size(); ++i) {
weight += pdfs.get(i);
wsum += rewards.get(i) * pdfs.get(i);
}
double mean = wsum / weight;
// System.err.println(wsum);
// System.err.println(weight);
// System.err.println(mean);
//
// System.exit(-1);
return wsum;
}
public static void main(String[] args) {
Parameter rate = new Parameter.Default(2.0);
Parameter prop = new Parameter.Default(0.5);
SericolaLatentStateBranchRateModel model = new SericolaLatentStateBranchRateModel(rate, prop);
for (double branchLength = 0.1; branchLength <= 10.0; branchLength += 0.1) {
List<Double> pdfs = new ArrayList<Double>();
List<Double> rewards = new ArrayList<Double>();
for (double reward = 0; reward <= branchLength; reward += 0.01 * branchLength) {
double value = model.getBranchRewardDensity(reward, branchLength);
// System.out.println(reward + "," + model.getBranchRewardDensity(reward, branchLength));
System.out.println();
rewards.add(reward);
pdfs.add(value);
}
Mode mode = findMode(pdfs, rewards);
// System.out.println();
System.out.println(branchLength // ", " + mode.reward //+ " " + mode.pdf
+ " " + (mode.reward / branchLength)
+ " " + (calculateExpectation(pdfs, rewards) / branchLength)
);
//System.out.println(model.getSeries());
}
}
public MarkovReward getSeries() {
if (series == null) {
series = createSeries();
}
return series;
}
}