/*
 * CountableBranchCategoryProvider.java
 *
 * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.branchratemodel;

import dr.evolution.tree.TreeUtils;
import dr.evomodel.treelikelihood.MarkovJumpsTraitProvider;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.util.TaxonList;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.MathUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * A tree trait that maps every branch of a tree to one of a countable number of
 * integer categories (indices {@code 0 .. getCategoryCount() - 1}).
 *
 * @author Marc A. Suchard
 * @author Philippe Lemey
 * @author Andrew Rambaut
 */
public interface CountableBranchCategoryProvider extends TreeTrait<Double> {

    /**
     * @param tree the tree containing {@code node}
     * @param node the node whose branch is queried
     * @return the category index assigned to the branch of {@code node}
     */
    public int getBranchCategory(final Tree tree, final NodeRef node);

    /**
     * @param count the number of categories that branches may be assigned to
     */
    public void setCategoryCount(final int count);

    /**
     * @return the number of categories that branches may be assigned to
     */
    public int getCategoryCount();

    /**
     * Trivial provider in which every branch belongs to the single category 0.
     */
    public class SingleBranchCategoryModel implements CountableBranchCategoryProvider {

        @Override
        public int getBranchCategory(final Tree tree, final NodeRef node) {
            return 0;
        }

        @Override
        public void setCategoryCount(final int count) {
            // Do nothing: this provider always has exactly one category.
        }

        @Override
        public int getCategoryCount() {
            return 1;
        }

        @Override
        public String getTraitName() {
            return "categories";
        }

        @Override
        public Intent getIntent() {
            return Intent.BRANCH;
        }

        // NOTE(review): reports Integer.class although the trait value type is Double -- confirm intended.
        @Override
        public Class getTraitClass() {
            return Integer.class;
        }

        // NOTE(review): the trait value is 1.0 / "1" while getBranchCategory() returns 0 -- confirm intended.
        @Override
        public Double getTrait(Tree tree, NodeRef node) {
            return 1.0;
        }

        @Override
        public String getTraitString(Tree tree, NodeRef node) {
            return "1";
        }

        @Override
        public boolean getLoggable() {
            return true;
        }
    }

    /**
     * Base class for providers that store the category of each branch as a per-node
     * value of a {@link TreeParameterModel}; the stored value is rounded to the nearest
     * integer to obtain the category.
     */
    public abstract class BranchCategoryModel extends TreeParameterModel implements CountableBranchCategoryProvider {

        protected int categoryCount;
        protected final Parameter categoryParameter;
        protected final TreeModel treeModel;

        /**
         * @param tree      the tree whose branches are categorised
         * @param parameter the parameter holding one category value per node
         */
        public BranchCategoryModel(TreeModel tree, Parameter parameter) {
            super(tree, parameter, false, Intent.BRANCH);

            this.categoryParameter = parameter;
            this.treeModel = tree;
            categoryCount = 1;
        }

        /**
         * Sets the number of categories, adds bounds {@code [0, count - 1]} to the category
         * parameter and clamps any stored value that is at or above {@code count} down to
         * {@code count - 1}.
         *
         * @param count the new number of categories
         */
        @Override
        public void setCategoryCount(final int count) {
            categoryCount = count;

            Parameter.DefaultBounds bound =
                    new Parameter.DefaultBounds(categoryCount - 1, 0, categoryParameter.getDimension());
            categoryParameter.addBounds(bound);

            for (int i = 0; i < categoryParameter.getDimension(); ++i) {
                if (categoryParameter.getParameterValue(i) >= categoryCount) {
                    categoryParameter.setParameterValue(i, categoryCount - 1);
                }
            }
        }

        @Override
        public int getBranchCategory(final Tree tree, final NodeRef node) {
            return (int) Math.round(getNodeValue(tree, node));
        }

        @Override
        public int getCategoryCount() {
            return categoryCount;
        }
    }

    /**
     * Provider in which each branch carries its own, independently assigned category.
     */
    public class IndependentBranchCategoryModel extends BranchCategoryModel {

        public IndependentBranchCategoryModel(TreeModel tree, Parameter parameter) {
            super(tree, parameter);
        }

        /**
         * Assigns to every non-root node a category drawn via
         * {@code MathUtils.nextInt(categoryCount)}.
         */
        public void randomize() {
            for (NodeRef node : treeModel.getNodes()) {
                if (node != treeModel.getRoot()) {
                    int index = MathUtils.nextInt(categoryCount);
                    setNodeValue(treeModel, node, index);
                }
            }
        }
    }

    /**
     * Provider tied to the tree model of a {@link MarkovJumpsTraitProvider}. The hook that
     * should recompute the categories ({@code updateTraitRateCategories}) is currently empty.
     */
    public class MarkovJumpBranchCategoryModel extends BranchCategoryModel {

        // Starts dirty so the (currently empty) update hook runs on the first query.
        private boolean traitsChanged = true;

        public MarkovJumpBranchCategoryModel(MarkovJumpsTraitProvider markovJumpTrait, Parameter parameter) {
            super(markovJumpTrait.getTreeModel(), parameter);
        }

        @Override
        public int getBranchCategory(final Tree tree, final NodeRef node) {
            synchronized (this) {
                if (traitsChanged) {
                    updateTraitRateCategories();
                    traitsChanged = false;
                }
            }
            return super.getBranchCategory(tree, node);
        }

        private void updateTraitRateCategories() {
            // Intentionally empty in the original source.
        }

        /**
         * Marks the categories as stale when the tree model changes.
         *
         * @throws IllegalArgumentException if {@code model} is not this provider's tree model
         */
        public void handleModelChangedEvent(Model model, Object object, int index) {
            if (model == treeModel) {
                traitsChanged = true;
                fireModelChanged();
            } else {
                throw new IllegalArgumentException("Unknown model component!");
            }
        }
    }

    /**
     * Provider in which categories are assigned to clades (defined by sets of leaves) and to
     * "trunk" paths; branches not covered by any definition get category 0. Assignments are
     * recomputed lazily on the next {@link #getBranchCategory} call after a change.
     */
    public class CladeBranchCategoryModel extends BranchCategoryModel {

        private boolean cladesChanged = false;
        private List<CladeContainer> leafSetList = null;
        private List<CladeContainer> trunkSetList = null;

        public CladeBranchCategoryModel(TreeModel tree, Parameter parameter) {
            super(tree, parameter);
        }

        /**
         * Marks the clade assignments as stale when the tree model changes.
         *
         * @throws IllegalArgumentException if {@code model} is not this provider's tree model
         */
        public void handleModelChangedEvent(Model model, Object object, int index) {
            if (model == treeModel) {
                cladesChanged = true;
                fireModelChanged();
            } else {
                throw new IllegalArgumentException("Unknown model component!");
            }
        }

        /**
         * Assigns the clade's category to {@code node} (when {@code include} is set and the
         * node is not the root) and to all of its descendants.
         */
        private void recurseDownClade(final NodeRef node, final TreeModel treeModel, final CladeContainer clade,
                                      boolean include) {

            if (include && !treeModel.isRoot(node)) {
                setNodeValue(treeModel, node, clade.getRateCategory());
            }

            if (!treeModel.isExternal(node)) {
                // Use the tree passed in as a parameter throughout, rather than a differently named outer reference.
                for (int i = 0; i < treeModel.getChildCount(node); i++) {
                    NodeRef child = treeModel.getChild(node, i);
                    recurseDownClade(child, treeModel, clade, true);
                }
            }
        }

        /** Sets the category of every non-root node to 0. */
        private void resetBranchCategories() {
            for (NodeRef node : treeModel.getNodes()) {
                if (node != treeModel.getRoot()) {
                    setNodeValue(treeModel, node, 0.0);
                }
            }
        }

        /**
         * Recomputes all branch categories from the registered clade and trunk definitions.
         * Does nothing when no definition has been registered.
         */
        private void updateCladeRateCategories() {

            if (leafSetList == null && trunkSetList == null) {
                return;
            }

            // Default every branch to category 0 before applying any definition.
            resetBranchCategories();

            if (leafSetList != null) {
                for (CladeContainer clade : leafSetList) {
                    NodeRef node = TreeUtils.getCommonAncestorNode(treeModel, clade.getLeafSet());
                    if (node != treeModel.getRoot()) {
                        if (clade.getIncludeStem()) {
                            setNodeValue(treeModel, node, clade.getRateCategory());
                        }
                    }
                    // Include the clade below
                    if (!clade.getExcludeClade()) {
                        recurseDownClade(node, treeModel, clade, clade.getIncludeStem());
                    }
                }
            }

            if (trunkSetList != null) {
                // currently, specific backbone rates will overwrite branch assignments by clade definitions
                // TODO: think about turning this around. One can imagine setting backbone rates and then
                // additional rates based on clade definitions
                for (CladeContainer trunk : trunkSetList) {
                    for (NodeRef node : treeModel.getNodes()) {
                        // Test for the root first: onAncestralPath may look up the node's parent.
                        if (node != treeModel.getRoot()
                                && onAncestralPath(treeModel, node, trunk.getLeafSet())) {
                            setNodeValue(treeModel, node, trunk.getRateCategory());
                        }
                    }
                }
            }
        }

        /**
         * Decides whether the branch above an internal {@code node} lies on the path ancestral
         * to the leaves in {@code targetSet}. External nodes are never on the path.
         *
         * @param tree      the tree containing {@code node}
         * @param node      the node to test; callers pass non-root nodes
         * @param targetSet the leaves defining the path
         * @return {@code true} if some, but not all, leaves below {@code node} are targets, or if all
         * of them are targets and the parent's subtree also contains non-target leaves
         */
        private boolean onAncestralPath(Tree tree, NodeRef node, Set<?> targetSet) {

            if (tree.isExternal(node)) {
                return false;
            }

            Set<?> leafSet = TreeUtils.getDescendantLeaves(tree, node);
            int size = leafSet.size();

            leafSet.retainAll(targetSet);

            if (leafSet.isEmpty()) {
                return false;
            }

            if (leafSet.size() != size) {
                return true;
            }

            // if all leaves below are in target then check just above.
            Set<?> superLeafSet = TreeUtils.getDescendantLeaves(tree, tree.getParent(node));
            superLeafSet.removeAll(targetSet);

            // the branch is on ancestral path if the super tree has some non-targets in it
            return !superLeafSet.isEmpty();
        }

        /**
         * Registers a clade (or trunk) definition and marks the assignments as stale. The
         * category count is raised when {@code rateCategory} is not yet covered by it.
         *
         * @param taxonList    the taxa defining the clade
         * @param rateCategory the category to assign
         * @param includeStem  whether the stem branch of the clade receives the category
         * @param excludeClade whether the branches below the clade's ancestor are left untouched
         * @param trunk        {@code true} to register a trunk definition instead of a clade definition
         * @throws TreeUtils.MissingTaxonException propagated from {@code TreeUtils.getLeavesForTaxa}
         */
        public void setClade(TaxonList taxonList, int rateCategory, boolean includeStem, boolean excludeClade,
                             boolean trunk) throws TreeUtils.MissingTaxonException {

            Set<String> leafSet = TreeUtils.getLeavesForTaxa(treeModel, taxonList);
            CladeContainer container = new CladeContainer(leafSet, rateCategory, includeStem, excludeClade);

            if (trunk) {
                if (trunkSetList == null) {
                    trunkSetList = new ArrayList<CladeContainer>();
                }
                trunkSetList.add(container);
            } else {
                if (leafSetList == null) {
                    leafSetList = new ArrayList<CladeContainer>();
                }
                leafSetList.add(container);
            }
            cladesChanged = true;

            // NOTE(review): categoryCount is raised here without going through setCategoryCount(),
            // so the parameter bounds are not updated -- confirm intended.
            if (rateCategory >= categoryCount) {
                categoryCount = rateCategory + 1;
            }
        }

        @Override
        public int getBranchCategory(final Tree tree, final NodeRef node) {
            synchronized (this) {
                if (cladesChanged) {
                    updateCladeRateCategories();
                    cladesChanged = false;
                }
            }
            return super.getBranchCategory(tree, node);
        }

        /** Immutable description of one clade/trunk definition. */
        private static final class CladeContainer {

            private final Set<String> leafSet;
            private final int rateCategory;
            private final boolean includeStem;
            private final boolean excludeClade;

            CladeContainer(Set<String> leafSet, int rateCategory, boolean includeStem, boolean excludeClade) {
                this.leafSet = leafSet;
                this.rateCategory = rateCategory;
                this.includeStem = includeStem;
                this.excludeClade = excludeClade;
            }

            public Set<String> getLeafSet() {
                return leafSet;
            }

            public int getRateCategory() {
                return rateCategory;
            }

            public boolean getIncludeStem() {
                return includeStem;
            }

            public boolean getExcludeClade() {
                return excludeClade;
            }
        }
    }
}