/*
* CountableMixtureBranchRates.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.branchratemodel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.branchratemodel.CountableMixtureBranchRatesParser;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.List;
/**
* @author Marc Suchard
* @author Philippe Lemey
*/
public class CountableMixtureBranchRates extends AbstractBranchRateModel implements Loggable {
private final Parameter ratesParameter;
private final TreeModel treeModel;
private final List<AbstractBranchRateModel> randomEffectsModels;
private final int categoryCount;
public CountableMixtureBranchRates(CountableBranchCategoryProvider rateCategories,
TreeModel treeModel, Parameter ratesParameter,
List<AbstractBranchRateModel> randomEffects, boolean inLogSpace) {
super(CountableMixtureBranchRatesParser.COUNTABLE_CLOCK_BRANCH_RATES);
this.treeModel = treeModel;
categoryCount = ratesParameter.getDimension();
this.rateCategories = rateCategories;
rateCategories.setCategoryCount(categoryCount);
if (rateCategories instanceof Model) {
addModel((Model)rateCategories);
}
this.ratesParameter = ratesParameter;
addVariable(ratesParameter);
// Handle random effects
this.randomEffectsModels = randomEffects;
if (randomEffectsModels != null) {
for (AbstractBranchRateModel model : randomEffectsModels)
addModel(model);
}
// TODO Check that randomEffectsModel means are zero
modelInLogSpace = inLogSpace;
helper.addTrait(this);
helper.addTrait(new TreeTrait.I() {
@Override
public String getTraitName() {
return getCategoryTraitName();
}
@Override
public Intent getIntent() {
return Intent.BRANCH;
}
@Override
public Integer getTrait(Tree tree, NodeRef node) {
return getBranchCategory(tree, node);
}
});
helper.addTrait(new TreeTrait.D() {
@Override
public String getTraitName() {
return getCategoryEffectTraitName();
}
@Override
public Intent getIntent() {
return Intent.BRANCH;
}
@Override
public Double getTrait(Tree tree, NodeRef node) {
return getBranchCategoryEffect(tree, node);
}
});
helper.addTrait(new TreeTrait.D() {
@Override
public String getTraitName() {
return getCategoryRateTraitName();
}
@Override
public Intent getIntent() {
return Intent.BRANCH;
}
@Override
public Double getTrait(Tree tree, NodeRef node) {
return getBranchCategoryRate(tree, node);
}
});
helper.addTrait(new TreeTrait.D() {
@Override
public String getTraitName() {
return getRandomEffectTraitName();
}
@Override
public Intent getIntent() {
return Intent.BRANCH;
}
@Override
public Double getTrait(Tree tree, NodeRef node) {
return getBranchRandomEffect(tree, node);
}
});
}
private String getCategoryTraitName() {
return getTraitName() + ".category";
}
private String getCategoryEffectTraitName() {
return getTraitName() + ".category.effect";
}
private String getCategoryRateTraitName() {
return getTraitName() + ".category.rate";
}
private String getRandomEffectTraitName() {
return getTraitName() + ".random.effect";
}
private int getBranchCategory(Tree tree, NodeRef node) {
return rateCategories.getBranchCategory(tree, node);
}
private double getBranchCategoryRate(Tree tree, NodeRef node) {
if (modelInLogSpace) {
return ratesParameter.getParameterValue(getBranchCategory(tree, node));
} else {
return Math.exp(ratesParameter.getParameterValue(getBranchCategory(tree, node)));
}
}
private double getBranchCategoryEffect(Tree tree, NodeRef node) {
if (modelInLogSpace) {
return (getBranchCategoryRate(tree, node) - ratesParameter.getParameterValue(0));
} else {
return (getBranchCategoryRate(tree, node) / ratesParameter.getParameterValue(0));
}
}
private double getBranchRandomEffect(Tree tree, NodeRef node) {
double effect;
if (modelInLogSpace) {
effect = 0;
} else {
effect = 1;
}
if (randomEffectsModels != null) {
for (AbstractBranchRateModel model : randomEffectsModels) {
if (modelInLogSpace) {
effect += model.getBranchRate(tree, node);
} else {
effect *= model.getBranchRate(tree, node);
}
}
}
return effect;
}
public TreeTrait[] getTreeTraits() {
return helper.getTreeTraits();
}
public TreeTrait getTreeTrait(String key) {
return helper.getTreeTrait(key);
}
public double getLogLikelihood() {
double logLike = 0.0;
if (randomEffectsModels != null) {
for (AbstractBranchRateModel model : randomEffectsModels) {
logLike += model.getLogLikelihood();
}
}
return logLike;
}
void test() {
getTrait(null, null);
}
public LogColumn[] getColumns() {
LogColumn[] columns = new LogColumn[ratesParameter.getDimension()];
for (int i = 0; i < ratesParameter.getDimension(); ++i) {
columns[i] = new OccupancyColumn(i);
}
return columns;
}
private class OccupancyColumn extends NumberColumn {
private final int index;
public OccupancyColumn(int index) {
super("Occupancy");
this.index = index;
}
public double getDoubleValue() {
int occupancy = 0;
for (NodeRef node : treeModel.getNodes()) {
if (node != treeModel.getRoot()) {
if (rateCategories.getBranchCategory(treeModel, node) == index) {
occupancy++;
}
}
}
return occupancy;
}
}
public void handleModelChangedEvent(Model model, Object object, int index) {
if (model == rateCategories) {
fireModelChanged();
} else {
AbstractBranchRateModel foundModel = findRandomEffectsModel(model);
if (foundModel != null) {
if (object == model) {
fireModelChanged();
} else if (object == null) {
fireModelChanged(null, index);
} else {
throw new IllegalArgumentException("Unknown object component!");
}
} else {
throw new IllegalArgumentException("Unknown model component!");
}
}
}
private AbstractBranchRateModel findRandomEffectsModel(Model model) {
AbstractBranchRateModel found = null;
int index = randomEffectsModels.indexOf(model);
if (index != -1) {
found = randomEffectsModels.get(index);
}
return found;
}
protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
fireModelChanged();
}
protected void storeState() {
// nothing to do
}
protected void restoreState() {
// nothing to do
}
protected void acceptState() {
// nothing to do
}
public double getBranchRate(final Tree tree, final NodeRef node) {
assert !tree.isRoot(node) : "root node doesn't have a rate!";
int rateCategory = rateCategories.getBranchCategory(tree, node);
double effect = ratesParameter.getParameterValue(rateCategory);
if (randomEffectsModels != null) {
for (AbstractBranchRateModel model : randomEffectsModels) {
if (modelInLogSpace) {
effect += model.getBranchRate(tree, node);
} else {
effect *= model.getBranchRate(tree, node);
}
}
}
if (modelInLogSpace) {
effect = Math.exp(effect);
}
return effect;
}
private final Helper helper = new Helper();
private final CountableBranchCategoryProvider rateCategories;
private final boolean modelInLogSpace;
}