/*
* BranchSpecificBranchModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.branchmodel;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.branchratemodel.LocalClockModelParser;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.*;
/**
 * A branch model which allows different clades (defined by the MRCAs of taxon lists) to have
 * different substitution models.
 * <p>
 * The root substitution model is registered at index 0 of {@link #getSubstitutionModels()}.
 * A clade's model is applied to every branch below the clade's MRCA. It can also be applied to a
 * proportion of the stem branch above the MRCA (the stem weight). The remainder of the stem uses
 * the model of the enclosing clade, or index 0 if there is none. When clades are nested, the
 * innermost clade's model is used for its own subtree. External-branch assignments take
 * precedence over clade assignments. Branches with no assignment get {@code BranchModel.DEFAULT}.
 *
 * @author Andrew Rambaut
 * @version $Id$
 */
public class BranchSpecificBranchModel extends AbstractModel implements BranchModel {

    private final TreeModel treeModel;

    /** Clades keyed by the bit set of their tip node numbers. */
    protected Map<BitSet, Clade> clades = new HashMap<BitSet, Clade>();

    /** Set when the clade-derived node mappings must be rebuilt before the next lookup. */
    private boolean updateNodeMaps = true;

    /** Clade-derived mappings; cleared and rebuilt by {@link #setupNodeMaps()}. */
    private final Map<NodeRef, Mapping> nodeMap = new HashMap<NodeRef, Mapping>();

    /** External-branch mappings; filled by {@link #addExternalBranches} and never rebuilt. */
    private final Map<NodeRef, Mapping> externalNodeMap = new HashMap<NodeRef, Mapping>();

    private final SubstitutionModel rootSubstitutionModel;

    /** All distinct substitution models; a mapping's order entries index into this list. */
    private final List<SubstitutionModel> substitutionModels = new ArrayList<SubstitutionModel>();

    /** True once any clade splits its stem strictly between two models. */
    private boolean requiresMatrixConvolution = false;

    /**
     * @param treeModel             the tree whose branches are mapped to substitution models
     * @param rootSubstitutionModel the root model; registered at index 0 and used as the ancestral
     *                              model for clade stems that have no enclosing clade
     */
    public BranchSpecificBranchModel(TreeModel treeModel, SubstitutionModel rootSubstitutionModel) {
        // NOTE(review): the model name is borrowed from LocalClockModelParser, which looks like a
        // copy-paste leftover; kept unchanged in case anything depends on it.
        super(LocalClockModelParser.LOCAL_CLOCK_MODEL);
        this.treeModel = treeModel;
        addModel(treeModel);
        this.rootSubstitutionModel = rootSubstitutionModel;
        addModel(rootSubstitutionModel);
        substitutionModels.add(rootSubstitutionModel);
    }

    /**
     * Adds a substitution model specific to a clade.
     *
     * @param taxonList         a list of taxa whose MRCA defines the clade
     * @param substitutionModel the substitution model for the branches below the MRCA
     * @param stemWeight        the proportion of the stem branch to assign to this model, in [0, 1]:
     *                          0 leaves the whole stem to the enclosing model, 1 gives it all to this one
     * @throws TreeUtils.MissingTaxonException if a taxon cannot be found in the tree
     * @throws IllegalArgumentException if {@code stemWeight} is outside [0, 1] (or NaN)
     */
    public void addClade(TaxonList taxonList, SubstitutionModel substitutionModel, double stemWeight) throws TreeUtils.MissingTaxonException {
        // Written this way so that NaN is rejected too.
        if (!(stemWeight >= 0.0 && stemWeight <= 1.0)) {
            throw new IllegalArgumentException("stemWeight must be in [0, 1] but was " + stemWeight);
        }

        // Resolve the taxa first so a missing taxon does not leave a dangling registered model.
        BitSet tips = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList);
        int index = getSubstitutionModelIndex(substitutionModel);

        clades.put(tips, new Clade(index, tips, stemWeight));

        // Only a stem split strictly between two models mixes matrices; weights of 0 or 1
        // map the stem to a single model.
        if (stemWeight > 0.0 && stemWeight < 1.0) {
            requiresMatrixConvolution = true;
        }

        // Make sure a clade added after the first lookup is taken into account.
        updateNodeMaps = true;
    }

    /**
     * Assigns a substitution model to external (tip) branches.
     *
     * @param taxonList         the taxa whose external branches use this model; {@code null} applies
     *                          the model to every external branch
     * @param substitutionModel the substitution model for those branches
     * @throws TreeUtils.MissingTaxonException if a taxon cannot be found in the tree
     */
    public void addExternalBranches(TaxonList taxonList, SubstitutionModel substitutionModel) throws TreeUtils.MissingTaxonException {
        final BitSet tips = (taxonList != null) ? TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList) : null;
        final int index = getSubstitutionModelIndex(substitutionModel);

        // The mapping returns fresh arrays on each call, so one instance can be shared by all tips.
        final Mapping mapping = singleModelMapping(index);
        for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
            NodeRef node = treeModel.getExternalNode(i);
            // Tip bit sets are indexed by node number, matching the clade bookkeeping below.
            if (tips == null || tips.get(node.getNumber())) {
                externalNodeMap.put(node, mapping);
            }
        }
    }

    public void addBackbone(TaxonList taxonList, SubstitutionModel substitutionModel) throws TreeUtils.MissingTaxonException {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Returns the mapping for the branch above {@code branch}. External-branch assignments win
     * over clade assignments; unassigned branches get {@code BranchModel.DEFAULT}.
     */
    @Override
    public Mapping getBranchModelMapping(NodeRef branch) {
        if (updateNodeMaps) {
            setupNodeMaps();
        }

        Mapping mapping = externalNodeMap.get(branch);
        if (mapping == null) {
            mapping = nodeMap.get(branch);
        }
        return (mapping != null) ? mapping : BranchModel.DEFAULT;
    }

    // @Override
    public List<SubstitutionModel> getSubstitutionModels() {
        return substitutionModels;
    }

    // @Override
    public SubstitutionModel getRootSubstitutionModel() {
        return rootSubstitutionModel;
    }

    public FrequencyModel getRootFrequencyModel() {
        return getRootSubstitutionModel().getFrequencyModel();
    }

    // @Override
    public boolean requiresMatrixConvolution() {
        return requiresMatrixConvolution;
    }

    public void handleModelChangedEvent(Model model, Object object, int index) {
        // A tree change can move nodes in or out of clades, so the clade maps must be rebuilt.
        if (model == treeModel && clades.size() > 0) {
            updateNodeMaps = true;
        }
        fireModelChanged();
    }

    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
    }

    protected void storeState() {
    }

    protected void restoreState() {
        // The node maps may reflect a rejected tree proposal; rebuild them lazily.
        if (clades.size() > 0) {
            updateNodeMaps = true;
        }
    }

    protected void acceptState() {
    }

    /** Returns the index of {@code substitutionModel}, registering it first if it is new. */
    private int getSubstitutionModelIndex(SubstitutionModel substitutionModel) {
        int index = substitutionModels.indexOf(substitutionModel);
        if (index == -1) {
            index = substitutionModels.size();
            substitutionModels.add(substitutionModel);
            addModel(substitutionModel);
        }
        return index;
    }

    /** Rebuilds the clade-derived node mappings from scratch for the current tree. */
    private void setupNodeMaps() {
        // Drop mappings from any previous tree state: after a topology change a node may
        // no longer lie inside the clade it was mapped to.
        nodeMap.clear();

        if (clades.size() > 0) {
            List<NodeRef> cladeRoots = new ArrayList<NodeRef>();
            List<Clade> cladesFound = new ArrayList<Clade>();
            collectCladeRoots(treeModel, treeModel.getRoot(), cladeRoots, cladesFound);

            // Clade roots were collected in post-order, so every nested clade precedes the clades
            // enclosing it. Applying in reverse maps each enclosing clade first. The nested clade
            // then overrides its own subtree and reads the enclosing model for its stem.
            for (int i = cladeRoots.size() - 1; i >= 0; i--) {
                applyClade(treeModel, cladeRoots.get(i), cladesFound.get(i));
            }
        }
        updateNodeMaps = false;
    }

    /**
     * Post-order walk recording every internal node whose tip set defines a registered clade.
     *
     * @return the bit set of tip node numbers below {@code node}
     */
    private BitSet collectCladeRoots(Tree tree, NodeRef node, List<NodeRef> cladeRoots, List<Clade> cladesFound) {
        BitSet tips = new BitSet();
        if (tree.isExternal(node)) {
            tips.set(node.getNumber());
            return tips;
        }

        for (int i = 0; i < tree.getChildCount(node); i++) {
            tips.or(collectCladeRoots(tree, tree.getChild(node, i), cladeRoots, cladesFound));
        }

        Clade clade = clades.get(tips);
        if (clade != null) {
            cladeRoots.add(node);
            cladesFound.add(clade);
        }
        return tips;
    }

    /** Maps every branch below {@code cladeRoot} to the clade's model, then maps its stem. */
    private void applyClade(Tree tree, NodeRef cladeRoot, Clade clade) {
        // The stem currently carries the enclosing clade's model, or nothing (index 0).
        Mapping enclosing = nodeMap.get(cladeRoot);
        final int ancestralIndex = (enclosing != null) ? enclosing.getOrder()[0] : 0;

        for (int i = 0; i < tree.getChildCount(cladeRoot); i++) {
            setNodeMap(tree, tree.getChild(cladeRoot, i), clade);
        }

        final int index = clade.getIndex();
        final double weight = clade.getStemWeight();
        if (weight >= 1.0) {
            // The whole stem belongs to this clade; no convolution needed.
            nodeMap.put(cladeRoot, singleModelMapping(index));
        } else if (weight > 0.0) {
            nodeMap.put(cladeRoot, stemMapping(index, ancestralIndex, weight));
        }
        // weight == 0: the stem keeps the enclosing mapping (or the default).
    }

    /** Maps {@code node} and every branch below it to {@code clade}'s model. */
    private void setNodeMap(Tree tree, NodeRef node, final Clade clade) {
        if (!tree.isExternal(node)) {
            for (int i = 0; i < tree.getChildCount(node); i++) {
                setNodeMap(tree, tree.getChild(node, i), clade);
            }
        }
        nodeMap.put(node, singleModelMapping(clade.getIndex()));
    }

    /** A mapping that uses substitution model {@code index} for the whole branch. */
    private static Mapping singleModelMapping(final int index) {
        return new Mapping() {
            // @Override
            public int[] getOrder() {
                return new int[] { index };
            }

            // @Override
            public double[] getWeights() {
                return new double[] { 1.0 };
            }
        };
    }

    /**
     * A stem mapping: proportion {@code weight} of the branch uses model {@code index} and the
     * rest uses {@code ancestralIndex}.
     */
    private static Mapping stemMapping(final int index, final int ancestralIndex, final double weight) {
        return new Mapping() {
            // @Override
            public int[] getOrder() {
                return new int[] { index, ancestralIndex };
            }

            // @Override
            public double[] getWeights() {
                return new double[] { weight, 1.0 - weight };
            }
        };
    }

    /** A clade definition: model index, tip set (by node number) and stem weight. */
    private static class Clade {
        Clade(int index, BitSet tips, double stemWeight) {
            this.index = index;
            this.tips = tips;
            this.stemWeight = stemWeight;
        }

        public int getIndex() {
            return index;
        }

        public BitSet getTips() {
            return tips;
        }

        public double getStemWeight() {
            return stemWeight;
        }

        private final int index;
        private final BitSet tips;
        private final double stemWeight;
    }
}