/* * MarkovJumpsBeagleTreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.treelikelihood; import dr.evomodel.branchmodel.BranchModel; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.evomodel.substmodel.UniformizedSubstitutionModel; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.DataType; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.tree.TreeModel; import dr.evomodel.tipstatesmodel.TipStatesModel; import dr.inference.loggers.LogColumn; import dr.inference.loggers.NumberColumn; import dr.inference.markovjumps.MarkovJumpsRegisterAcceptor; import dr.inference.markovjumps.MarkovJumpsType; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.util.Citation; import dr.util.CommonCitations; import java.util.*; /** * @author Marc Suchard * @author Vladimir Minin * <p/> * A base class for implementing Markov chain-induced counting processes (markovjumps) in BEAST using BEAGLE * This work is supported by NSF grant 0856099 * <p/> * Minin VN and Suchard MA (2008) Counting labeled transitions in continous-time Markov models of evolution. * Journal of Mathematical Biology, 56, 391-412. */ public class MarkovJumpsBeagleTreeLikelihood extends AncestralStateBeagleTreeLikelihood implements MarkovJumpsRegisterAcceptor, MarkovJumpsTraitProvider { public MarkovJumpsBeagleTreeLikelihood(PatternList patternList, TreeModel treeModel, BranchModel branchModel, SiteRateModel siteRateModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, PartialsRescalingScheme scalingScheme, boolean delayScaling, Map<Set<String>, Parameter> partialsRestrictions, DataType dataType, String stateTag, boolean useMAP, boolean returnMarginalLikelihood, boolean useUniformization, boolean reportUnconditionedColumns, int nSimulants) { super(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, useAmbiguities, scalingScheme, delayScaling, partialsRestrictions, dataType, stateTag, useMAP, returnMarginalLikelihood); this.useUniformization = useUniformization; this.reportUnconditionedColumns = reportUnconditionedColumns; this.nSimulants = nSimulants; markovjumps = new ArrayList<MarkovJumpsSubstitutionModel>(); branchModelNumber = new ArrayList<Integer>(); registerParameter = new ArrayList<Parameter>(); jumpTag = new ArrayList<String>(); expectedJumps = new ArrayList<double[][]>(); // storedExpectedJumps = new ArrayList<double[][]>(); tmpProbabilities = new double[stateCount * stateCount * categoryCount]; condJumps = new double[categoryCount][stateCount * stateCount]; } public void addRegister(Parameter addRegisterParameter, MarkovJumpsType type, boolean scaleByTime) { if ((type == MarkovJumpsType.COUNTS && addRegisterParameter.getDimension() != stateCount * stateCount) || (type == MarkovJumpsType.REWARDS && addRegisterParameter.getDimension() != stateCount) ) { throw new RuntimeException("Register parameter of wrong dimension"); } addVariable(addRegisterParameter); final String tag = addRegisterParameter.getId(); for (int i = 0; i < substitutionModelDelegate.getSubstitutionModelCount(); ++i) { registerParameter.add(addRegisterParameter); MarkovJumpsSubstitutionModel mjModel; SubstitutionModel substitutionModel = substitutionModelDelegate.getSubstitutionModel(i); if (useUniformization) { mjModel = new UniformizedSubstitutionModel(substitutionModel, type, nSimulants); } else { if (type == MarkovJumpsType.HISTORY) { throw new RuntimeException("Can only report complete history using uniformization"); } mjModel = new MarkovJumpsSubstitutionModel(substitutionModel, type); } markovjumps.add(mjModel); branchModelNumber.add(i); addModel(mjModel); setupRegistration(numRegisters); String traitName; if (substitutionModelDelegate.getSubstitutionModelCount() == 1) { traitName = tag; } else { traitName = tag + i; } jumpTag.add(traitName); expectedJumps.add(new double[treeModel.getNodeCount()][patternCount]); // storedExpectedJumps.add(new double[treeModel.getNodeCount()][patternCount]); boolean[] oldScaleByTime = this.scaleByTime; int oldScaleByTimeLength = (oldScaleByTime == null ? 0 : oldScaleByTime.length); this.scaleByTime = new boolean[oldScaleByTimeLength + 1]; if (oldScaleByTimeLength > 0) { System.arraycopy(oldScaleByTime, 0, this.scaleByTime, 0, oldScaleByTimeLength); } this.scaleByTime[oldScaleByTimeLength] = scaleByTime; if (type != MarkovJumpsType.HISTORY) { TreeTrait.DA da = new TreeTrait.DA() { final int registerNumber = numRegisters; public String getTraitName() { return tag; } public Intent getIntent() { return Intent.BRANCH; } public double[] getTrait(Tree tree, NodeRef node) { return getMarkovJumpsForNodeAndRegister(tree, node, registerNumber); } }; treeTraits.addTrait(traitName + "_base", da); treeTraits.addTrait(addRegisterParameter.getId(), new TreeTrait.SumAcrossArrayD( new TreeTrait.SumOverTreeDA(da))); } else { if (histories == null) { histories = new String[treeModel.getNodeCount()][patternCount]; } else { throw new RuntimeException("Only one complete history per markovJumpTreeLikelihood is allowed"); } if (nSimulants > 1) { throw new RuntimeException("Only one simulant allowed when saving complete history"); } // Add total number of changes over tree trait TreeTrait da = new TreeTrait.DA() { final int registerNumber = numRegisters; public String getTraitName() { return tag; } public Intent getIntent() { return Intent.BRANCH; } public double[] getTrait(Tree tree, NodeRef node) { return getMarkovJumpsForNodeAndRegister(tree, node, registerNumber); } }; treeTraits.addTrait(addRegisterParameter.getId(), new TreeTrait.SumOverTreeDA(da)); historyRegisterNumber = numRegisters; // Record the complete history for this register ((UniformizedSubstitutionModel) mjModel).setSaveCompleteHistory(true); if (useCompactHistory && logHistory) { treeTraits.addTrait(ALL_HISTORY, new TreeTrait.SA() { public String getTraitName() { return ALL_HISTORY; } public Intent getIntent() { return Intent.BRANCH; } public boolean getFormatAsArray() { return true; } public String[] getTrait(Tree tree, NodeRef node) { List<String> events = new ArrayList<String>(); for (int i = 0; i < patternCount; i++) { String eventString = getHistoryForNode(tree, node, i); if (eventString != null && eventString.compareTo("{}") != 0) { eventString = eventString.substring(1, eventString.length() - 1); if (eventString.contains("},{")) { // There are multiple events String[] elements = eventString.split("(?<=\\}),(?=\\{)"); for (String e : elements) { events.add(e); } } else { events.add(eventString); } } } String[] array = new String[events.size()]; events.toArray(array); return array; } public boolean getLoggable() { return true; } }); } for (int site = 0; site < patternCount; ++site) { final String anonName = (patternCount == 1) ? HISTORY : HISTORY + "_" + (site + 1); final int anonSite = site; treeTraits.addTrait(anonName, new TreeTrait.S() { public String getTraitName() { return anonName; } public Intent getIntent() { return Intent.BRANCH; } public String getTrait(Tree tree, NodeRef node) { String history = getHistoryForNode(tree, node, anonSite); return (history.compareTo("{}") != 0) ? history : null; // Return null if empty } public boolean getLoggable() { return logHistory && !useCompactHistory; } }); } } numRegisters++; } // End of loop over branch models } public void setLogHistories(boolean in) { logHistory = in; } public void setUseCompactHistory(boolean in) { useCompactHistory = in; } // public double[] getRewardsForNodeAndPattern(Tree tree, NodeRef node, int pattern) { // double[] rtn = new double[numRegisters]; // for (int r = 0; r < numRegisters; r++) { // double[] mjs = getMarkovJumpsForNodeAndRegister(tree, node, r); // rtn[r] = mjs[pattern]; // } // return rtn; // } public double[] getMarkovJumpsForNodeAndRegister(Tree tree, NodeRef node, int whichRegister) { return getMarkovJumpsForRegister(tree, whichRegister)[node.getNumber()]; } // public double[][] getMarkovJumpsForNode(Tree tree, NodeRef node) { // double[][] rtn = new double[numRegisters][]; // for (int r = 0; r < numRegisters; r++) { // rtn[r] = getMarkovJumpsForNodeAndRegister(tree, node, r); // } // return rtn; // } private void refresh(Tree tree) { if (tree != treeModel) { throw new RuntimeException("Must call with internal tree"); } if (!likelihoodKnown) { calculateLogLikelihood(); likelihoodKnown = true; } if (!areStatesRedrawn) { redrawAncestralStates(); } } public double[][] getMarkovJumpsForRegister(Tree tree, int whichRegister) { refresh(tree); return expectedJumps.get(whichRegister); } public String getHistoryForNode(Tree tree, NodeRef node, int site) { return getHistory(tree)[node.getNumber()][site]; } public String[][] getHistory(Tree tree) { refresh(tree); return histories; } // private static String formattedValue(double[] values) { // double total = 0; // for (double summant : values) { // total += summant; // } // return Double.toString(total); // Currently return the sum across sites // } private void setupRegistration(int whichRegistration) { double[] registration = registerParameter.get(whichRegistration).getParameterValues(); markovjumps.get(whichRegistration).setRegistration(registration); areStatesRedrawn = false; } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { for (int r = 0; r < numRegisters; r++) { if (variable == registerParameter.get(r)) { setupRegistration(r); return; } } super.handleVariableChangedEvent(variable, index, type); } protected void hookCalculation(Tree tree, NodeRef parentNode, NodeRef childNode, int[] parentStates, int[] childStates, double[] inProbabilities, int[] rateCategory) { final int childNum = childNode.getNumber(); double[] probabilities = inProbabilities; if (probabilities == null) { // Leaf will call this hook with a null getMatrix(childNum, tmpProbabilities); probabilities = tmpProbabilities; } final double branchRate = branchRateModel.getBranchRate(tree, childNode); final double parentTime = tree.getNodeHeight(parentNode); final double childTime = tree.getNodeHeight(childNode); final double substTime = parentTime - childTime; for (int r = 0; r < markovjumps.size(); r++) { MarkovJumpsSubstitutionModel thisMarkovJumps = markovjumps.get(r); final int modelNumberFromrRegistry = branchModelNumber.get(r); // int dummy = 0; // final int modelNumberFromTree = branchSubstitutionModel.getBranchIndex(tree, childNode, dummy); // @todo AR - not sure about this - if this is an epoch this is just going to get the most // @todo tipward model for the branch. I think this was what was happening before (in comment, // @todo above). BranchModel.Mapping mapping = branchModel.getBranchModelMapping(childNode); if (modelNumberFromrRegistry == mapping.getOrder()[0]) { if (useUniformization) { computeSampledMarkovJumpsForBranch(((UniformizedSubstitutionModel) thisMarkovJumps), substTime, branchRate, childNum, parentStates, childStates, parentTime, childTime, probabilities, scaleByTime[r], expectedJumps.get(r), rateCategory, r == historyRegisterNumber); } else { computeIntegratedMarkovJumpsForBranch(thisMarkovJumps, substTime, branchRate, childNum, parentStates, childStates, probabilities, condJumps, scaleByTime[r], expectedJumps.get(r), rateCategory); } } else { // Fill with zeros double[] result = expectedJumps.get(r)[childNum]; Arrays.fill(result, 0.0); } } } private void computeSampledMarkovJumpsForBranch(UniformizedSubstitutionModel thisMarkovJumps, double substTime, double branchRate, int childNum, int[] parentStates, int[] childStates, double parentTime, double childTime, double[] probabilities, boolean scaleByTime, double[][] thisExpectedJumps, int[] rateCategory, boolean saveHistory) { // Fill condJumps with sampled values for this branch for each site for (int j = 0; j < patternCount; j++) { final int category = rateCategory == null ? 0 : rateCategory[j]; final double categoryRate = siteRateModel.getRateForCategory(category); final int matrixIndex = category * stateCount * stateCount; double value = thisMarkovJumps.computeCondStatMarkovJumps( parentStates[j], childStates[j], substTime * branchRate * categoryRate, probabilities[matrixIndex + parentStates[j] * stateCount + childStates[j]] ); if (scaleByTime) { value /= branchRate * categoryRate; } thisExpectedJumps[childNum][j] = value; if (saveHistory) { int site = (useCompactHistory) ? j + 1 : -1; histories[childNum][j] = thisMarkovJumps.getCompleteHistory(site, parentTime, childTime); } } } private void computeIntegratedMarkovJumpsForBranch(MarkovJumpsSubstitutionModel thisMarkovJumps, double substTime, double branchRate, int childNum, int[] parentStates, int[] childStates, double[] probabilities, double[][] condJumps, boolean scaleByTime, double[][] thisExpectedJumps, int[] rateCategory) { // Fill condJumps with conditional mean values for this branch for (int i = 0; i < categoryCount; i++) { double rate = siteRateModel.getRateForCategory(i); if (rate > 0) { if (categoryCount == 1) { thisMarkovJumps.computeCondStatMarkovJumps( substTime * branchRate * rate, probabilities, condJumps[i]); } else { System.arraycopy(probabilities, i * stateCount * stateCount, tmpProbabilities, 0, stateCount * stateCount); thisMarkovJumps.computeCondStatMarkovJumps( substTime * branchRate * rate, tmpProbabilities, condJumps[i]); } if (scaleByTime) { double scalar = branchRate * rate; for (int j = 0; j < condJumps[i].length; j++) { condJumps[i][j] /= scalar; } } } else { Arrays.fill(condJumps[i], 0.0); if (thisMarkovJumps.getType() == MarkovJumpsType.REWARDS && scaleByTime) { for (int j = 0; j < stateCount; j++) { condJumps[i][j * stateCount + j] = substTime; } } } } for (int j = 0; j < patternCount; j++) { // Pick out values given parent and child states int category = rateCategory == null ? 0 : rateCategory[j]; thisExpectedJumps[childNum][j] = condJumps[category][parentStates[j] * stateCount + childStates[j]]; } } // public void storeState() { // // super.storeState(); // //// if (areStatesRedrawn) { // for (int i = 0; i < expectedJumps.size(); i++) { // double[][] thisExpectedJumps = expectedJumps.get(i); // double[][] storedThisExpectedJumps = storedExpectedJumps.get(i); // for (int j = 0; j < thisExpectedJumps.length; j++) { // System.arraycopy(thisExpectedJumps[j], 0, storedThisExpectedJumps[j], 0, // thisExpectedJumps[j].length); // } // } //// } // } // // public void restoreState() { // // super.restoreState(); // // List<double[][]> tmp = expectedJumps; // expectedJumps = storedExpectedJumps; // storedExpectedJumps = tmp; // areStatesRedrawn = false; // // } public LogColumn[] getColumns() { int nColumns = patternCount * numRegisters; if (reportUnconditionedColumns) { if (categoryCount == 1) { nColumns += numRegisters; } else { nColumns *= 2; } } int index = 0; LogColumn[] allColumns = new LogColumn[nColumns]; for (int r = 0; r < numRegisters; r++) { for (int j = 0; j < patternCount; j++) { allColumns[index++] = new ConditionedCountColumn(jumpTag.get(r), r, j); if (reportUnconditionedColumns) { if (categoryCount > 1) { allColumns[index++] = new UnconditionedCountColumn(jumpTag.get(r), r, j, rateCategory); } } } if (reportUnconditionedColumns) { if (categoryCount == 1) { allColumns[index++] = new UnconditionedCountColumn(jumpTag.get(r), r); } } } return allColumns; } protected abstract class CountColumn extends NumberColumn { protected int indexRegistration; protected int indexSite; public CountColumn(String label, int r, int j) { super(label + (j >= 0 ? "[" + (j + 1) + "]" : "")); indexRegistration = r; indexSite = j; } public abstract double getDoubleValue(); } protected class ConditionedCountColumn extends CountColumn { public ConditionedCountColumn(String label, int r, int j) { super("c_" + label, r, j); } public double getDoubleValue() { double total = 0; double[][] values = getMarkovJumpsForRegister(treeModel, indexRegistration); for (int i = 0; i < treeModel.getNodeCount(); i++) { total += values[i][indexSite]; } return total; } } protected class UnconditionedCountColumn extends CountColumn { int[] rateCategory; public UnconditionedCountColumn(String label, int r, int j, int[] rateCategory) { super("u_" + label, r, j); this.rateCategory = rateCategory; } public UnconditionedCountColumn(String label, int r) { this(label, r, -1, null); } public double getDoubleValue() { double value = markovjumps.get(indexRegistration).getMarginalRate() * getExpectedTreeLength(); if (rateCategory != null) { value *= siteRateModel.getRateForCategory(rateCategory[indexSite]); } return value; } private double getExpectedTreeLength() { double expectedTreeLength = 0; for (int i = 0; i < treeModel.getNodeCount(); i++) { NodeRef node = treeModel.getNode(i); if (!treeModel.isRoot(node)) { expectedTreeLength += branchRateModel.getBranchRate(treeModel, node) * treeModel.getBranchLength(node); } } return expectedTreeLength; } } @Override public String getDescription() { return super.getDescription() + " (first citation) with MarkovJumps inference techniques (second citation)"; } public List<Citation> getCitations() { List<Citation> citationList = new ArrayList<Citation>(super.getCitations()); citationList.add(CommonCitations.MININ_2008_COUNTING); return citationList; } public static final String ALL_HISTORY = "history_all"; public static final String HISTORY = "history"; public static final String TOTAL_COUNTS = "allTransitions"; private List<MarkovJumpsSubstitutionModel> markovjumps; private List<Integer> branchModelNumber; private List<Parameter> registerParameter; private List<String> jumpTag; private List<double[][]> expectedJumps; // private List<double[][]> storedExpectedJumps; private boolean logHistory = false; private boolean useCompactHistory = false; private String[][] histories = null; private boolean[] scaleByTime; private double[] tmpProbabilities; private double[][] condJumps; private int numRegisters; private int historyRegisterNumber = -1; private final boolean useUniformization; private final int nSimulants; private final boolean reportUnconditionedColumns; }