/*
 * AncestralStateBeagleTreeLikelihood.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.treelikelihood;

import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.UncertainSiteList;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.MathUtils;

import java.util.Map;
import java.util.Set;

/**
 * @author Marc Suchard
 * @author Andrew Rambaut
 */
public class AncestralStateBeagleTreeLikelihood extends BeagleTreeLikelihood
        implements TreeTraitProvider, AncestralStateTraitProvider {

//    public AncestralStateBeagleTreeLikelihood(PatternList patternList, TreeModel treeModel,
//                                              BranchSubstitutionModel branchSubstitutionModel, SiteRateModel siteRateModel,
//                                              BranchRateModel branchRateModel, boolean useAmbiguities,
//                                              PartialsRescalingScheme scalingScheme,
//                                              DataType dataType,
//                                              String tag,
//                                              SubstitutionModel substModel) {
//        this(patternList, treeModel, branchSubstitutionModel, siteRateModel, branchRateModel, useAmbiguities, scalingScheme,
//                dataType, tag, substModel, false, true);
//    }

    public AncestralStateBeagleTreeLikelihood(PatternList patternList, TreeModel treeModel,
                                              BranchModel branchModel,
                                              SiteRateModel siteRateModel,
                                              BranchRateModel branchRateModel,
                                              TipStatesModel tipStatesModel,
                                              boolean useAmbiguities,
                                              PartialsRescalingScheme scalingScheme,
                                              boolean delayRescalingUntilUnderflow,
                                              Map<Set<String>, Parameter> partialsRestrictions,
                                              final DataType dataType,
                                              final String tag,
//                                              SubstitutionModel substModel,
                                              boolean useMAP,
                                              boolean returnML) {
        super(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, useAmbiguities,
                scalingScheme, delayRescalingUntilUnderflow, partialsRestrictions);

        this.dataType = dataType;
//        this.tag = tag;

        probabilities = new double[stateCount * stateCount * categoryCount];
        partials = new double[stateCount * patternCount * categoryCount];
//        rootPartials = new double[stateCount * patternCount];
//        cumulativeScaleBuffers = new int[nodeCount][];
//        scaleBufferIndex = getScaleBufferCount() - 1;

        // Save tip states locally so these do not need to be transferred back
        if (useAmbiguities()) {
            tipPartials = new double[tipCount][];
        } else {
            tipStates = new int[tipCount][];
        }

        for (int i = 0; i < tipCount; i++) {
            // Find the id of tip i in the patternList
            String id = treeModel.getTaxonId(i);
            int index = patternList.getTaxonIndex(id);

            if (useAmbiguities()) {
                tipPartials[i] = getPartials(patternList, index);
            } else {
                tipStates[i] = getStates(patternList, index);
            }
        }

        reconstructedStates = new int[treeModel.getNodeCount()][patternCount];
        storedReconstructedStates = new int[treeModel.getNodeCount()][patternCount];

        this.useMAP = useMAP;
        this.returnMarginalLogLikelihood = returnML;

        treeTraits.addTrait(new TreeTrait.IA() {
            public String getTraitName() {
                return tag;
            }

            public Intent getIntent() {
                return Intent.NODE;
            }

            public Class getTraitClass() {
                return int[].class;
            }

            public int[] getTrait(Tree tree, NodeRef node) {
                return getStatesForNode(tree, node);
            }

            public String getTraitString(Tree tree, NodeRef node) {
                return formattedState(getStatesForNode(tree, node), dataType);
            }
        });
    }
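
    // A minimal usage sketch (not part of the original class): the anonymous TreeTrait
    // registered above exposes the sampled reconstruction for any node. The variable
    // names and the "states" tag value below are assumptions for illustration only.
    //
    //     AncestralStateBeagleTreeLikelihood asl = ...;                 // constructed as above with tag = "states"
    //     TreeTrait trait = asl.getTreeTrait("states");
    //     NodeRef root = treeModel.getRoot();
    //     int[] rootStates = asl.getStatesForNode(treeModel, root);     // one state index per site pattern
    //     String formatted = trait.getTraitString(treeModel, root);     // quoted string, e.g. "\"ACGT...\"" for nucleotides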

    private double[] getPartials(PatternList patternList, int sequenceIndex) {
        double[] partials = new double[patternCount * stateCount];

        boolean[] stateSet;

        int v = 0;
        for (int i = 0; i < patternCount; i++) {

            if (patternList instanceof UncertainSiteList) {
                ((UncertainSiteList) patternList).fillPartials(sequenceIndex, i, partials, v);
                v += stateCount;
                // TODO Add this functionality to SimpleSiteList to avoid if statement here
            } else {

                int state = patternList.getPatternState(sequenceIndex, i);
                stateSet = dataType.getStateSet(state);

                for (int j = 0; j < stateCount; j++) {
                    if (stateSet[j]) {
                        partials[v] = 1.0;
                    } else {
                        partials[v] = 0.0;
                    }
                    v++;
                }
            }
        }

        // TODO Note code duplication with BTL, refactor when debugged
        return partials;
    }

    private int[] getStates(PatternList patternList, int sequenceIndex) {
        int[] states = new int[patternCount];

        for (int i = 0; i < patternCount; i++) {
            states[i] = patternList.getPatternState(sequenceIndex, i);
        }

        return states;
    }

    public BranchModel getBranchModel() {
        return branchModel;
    }

    protected Helper treeTraits = new Helper();

    public TreeTrait[] getTreeTraits() {
        return treeTraits.getTreeTraits();
    }

    public TreeTrait getTreeTrait(String key) {
        return treeTraits.getTreeTrait(key);
    }

    protected void handleModelChangedEvent(Model model, Object object, int index) {
        super.handleModelChangedEvent(model, object, index);
        fireModelChanged(model);
    }

    public int[] getStatesForNode(Tree tree, NodeRef node) {
        if (tree != treeModel) {
            throw new RuntimeException("Can only reconstruct states on treeModel given to constructor");
        }

        if (!likelihoodKnown) {
            calculateLogLikelihood();
            likelihoodKnown = true;
        }

        if (!areStatesRedrawn) {
            redrawAncestralStates();
        }
        return reconstructedStates[node.getNumber()];
    }

    @Override
    protected int getScaleBufferCount() {
        return internalNodeCount + 2;
    }

    private int drawChoice(double[] measure) {
        if (useMAP) {
            double max = measure[0];
            int choice = 0;
            for (int i = 1; i < measure.length; i++) {
                if (measure[i] > max) {
                    max = measure[i];
                    choice = i;
                }
            }
            return choice;
        } else {
            return MathUtils.randomChoicePDF(measure);
        }
    }
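
    // Illustrative note on drawChoice (a sketch, not from the original source): given
    // unnormalised weights such as {0.1, 0.7, 0.2}, the MAP setting (useMAP == true)
    // deterministically returns index 1, whereas the default setting draws an index at
    // random with probability proportional to its weight via MathUtils.randomChoicePDF
    // (assumed here to accept unnormalised weights).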

    public void makeDirty() {
        super.makeDirty();
        areStatesRedrawn = false;
    }

    public void redrawAncestralStates() {
        // Sample states
        jointLogLikelihood = 0;
        traverseSample(treeModel, treeModel.getRoot(), null, null);
        areStatesRedrawn = true;
    }

//    protected double calculateLogLikelihood() {
//        areStatesRedrawn = false;
//        return super.calculateLogLikelihood();
//    }

    protected double calculateLogLikelihood() {
        areStatesRedrawn = false;
        double marginalLogLikelihood = super.calculateLogLikelihood();
        if (returnMarginalLogLikelihood) {
            return marginalLogLikelihood;
        }
        // redraw states and return joint density of drawn states
        redrawAncestralStates();
        return jointLogLikelihood;
    }

    public String formattedState(int[] state) {
        return formattedState(state, dataType);
    }

    private static String formattedState(int[] state, DataType dataType) {
        StringBuffer sb = new StringBuffer();
        sb.append("\"");
        if (dataType instanceof GeneralDataType) {
            boolean first = true;
            for (int i : state) {
                if (!first) {
                    sb.append(" ");
                } else {
                    first = false;
                }
                sb.append(dataType.getCode(i));
            }
        } else {
            for (int i : state) {
                if (dataType.getClass().equals(Codons.class)) {
                    sb.append(dataType.getTriplet(i));
                } else {
                    sb.append(dataType.getChar(i));
                }
            }
        }
        sb.append("\"");
        return sb.toString();
    }

    protected void getMatrix(int branchIndex, double[] probabilities) {
        beagle.getTransitionMatrix(substitutionModelDelegate.getMatrixIndex(branchIndex), probabilities);
        // NB: It may be faster to compute matrices in BEAST via substitutionModel
    }

    public void setTipStates(int tipNum, int[] states) {
        System.arraycopy(states, 0, tipStates[tipNum], 0, states.length);
        beagle.setTipStates(tipNum, states);
        makeDirty();
    }

//    public void getTipPartials(int tipNum, double[] partials) {
//        System.arraycopy(tipPartials[tipNum], 0, partials, 0, partials.length);
//    }

    public void getTipStates(int tipNum, int[] states) {
        // Saved locally to reduce BEAGLE library access
        System.arraycopy(tipStates[tipNum], 0, states, 0, states.length);
    }

//    public int traverseCollectScaleBuffers(TreeModel tree, NodeRef node) {
//
//        if (true) // Currently do nothing
//            return 0;
//
//        return 0;
//
//        int nodeNum = node.getNumber();
//
//        NodeRef child0 = tree.getChild(node, 0);
//        NodeRef child1 = tree.getChild(node, 1);
//
//        int len0 = traverseCollectScaleBuffers(tree, child0);
//        int len1 = traverseCollectScaleBuffers(tree, child1);
//        int thisLen = len0 + len1 + 1;
//        int offset = 0;
//
//        int[] scaleBuffer = new int[thisLen];
//        if (len0 > 0) {
//            System.arraycopy(cumulativeScaleBuffers[child0.getNumber()], 0, scaleBuffer, offset, len0);
//            offset += len0;
//        }
//        if (len1 > 0) {
//            System.arraycopy(cumulativeScaleBuffers[child1.getNumber()], 0, scaleBuffer, offset, len1);
//            offset += len1;
//        }
//        scaleBuffer[offset] = scaleBufferHelper.getOffsetIndex(nodeNum - tipCount);
//        cumulativeScaleBuffers[nodeNum] = scaleBuffer;
//
//        return thisLen;
//    }

    public void storeState() {

        super.storeState();

        if (areStatesRedrawn) {
            for (int i = 0; i < reconstructedStates.length; i++) {
                System.arraycopy(reconstructedStates[i], 0, storedReconstructedStates[i], 0, reconstructedStates[i].length);
            }
        }

        // TODO MAS: I do not understand why these are NOT necessary
        storedAreStatesRedrawn = areStatesRedrawn;
        storedJointLogLikelihood = jointLogLikelihood;
    }

    public void restoreState() {

        super.restoreState();

        int[][] temp = reconstructedStates;
        reconstructedStates = storedReconstructedStates;
        storedReconstructedStates = temp;

        areStatesRedrawn = storedAreStatesRedrawn;
        jointLogLikelihood = storedJointLogLikelihood;
    }
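
    // Outline of the sampling recursion implemented by traverseSample below (a summary
    // comment added for orientation; it paraphrases the code rather than quoting any
    // original documentation):
    //  - at the root, a state is drawn for each pattern with weight
    //        partial[root][state] * rootFrequency[state],
    //    optionally after sampling a rate category in proportion to its prior weight
    //    times the summed partials for that category;
    //  - at every other internal node, conditional on the sampled parent state, a state
    //    is drawn with weight
    //        partial[node][state] * P(parentState -> state) from the branch's transition matrix;
    //  - at the tips, observed states are kept, while ambiguous or uncertain tip states
    //    are resampled conditional on the sampled parent state.
    // When returnMarginalLogLikelihood is false, jointLogLikelihood accumulates the log
    // density of the drawn reconstruction instead of the marginal likelihood.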

    public void traverseSample(TreeModel tree, NodeRef node, int[] parentState, int[] rateCategory) {

        int nodeNum = node.getNumber();

        NodeRef parent = tree.getParent(node);

        // This function assumes that all partial likelihoods have already been calculated
        // If the node is internal, then sample its state given the state of its parent (pre-order traversal).

        double[] conditionalProbabilities = new double[stateCount];
        int[] state = new int[patternCount];

        if (!tree.isExternal(node)) {

            if (parent == null) {

                // This is the root node
                getPartials(nodeNum, partials);

                boolean sampleCategory = categoryCount > 1;

                double[] posteriorWeightedCategory = null;
                double[] priorWeightedCategory = null;

                if (sampleCategory) {
                    rateCategory = new int[patternCount];
                    posteriorWeightedCategory = new double[categoryCount];
                    priorWeightedCategory = siteRateModel.getCategoryProportions();
                }

                for (int j = 0; j < patternCount; j++) {

                    // Sample across-site-rate-variation, if it exists
                    if (sampleCategory) {
                        for (int r = 0; r < categoryCount; r++) {
                            posteriorWeightedCategory[r] = 0;
                            for (int k = 0; k < stateCount; k++) {
                                posteriorWeightedCategory[r] += partials[r * stateCount * patternCount
                                        + j * stateCount + k];
                            }
                            posteriorWeightedCategory[r] *= priorWeightedCategory[r];
                        }
                        rateCategory[j] = drawChoice(posteriorWeightedCategory);
                    }

                    // Sample root character state
                    int partialsIndex = (rateCategory == null ? 0 : rateCategory[j]) * stateCount * patternCount;
                    System.arraycopy(partials, partialsIndex + j * stateCount, conditionalProbabilities, 0, stateCount);

                    double[] frequencies = substitutionModelDelegate.getRootStateFrequencies(); // TODO May have more than one set of frequencies
                    for (int i = 0; i < stateCount; i++) {
                        conditionalProbabilities[i] *= frequencies[i];
                    }

                    try {
                        state[j] = drawChoice(conditionalProbabilities);
                    } catch (Error e) {
                        System.err.println(e.toString());
                        System.err.println("Please report error to Marc");
                        state[j] = 0;
                    }
                    reconstructedStates[nodeNum][j] = state[j];

                    if (!returnMarginalLogLikelihood) {
                        jointLogLikelihood += Math.log(frequencies[state[j]]);
                    }
                }

                if (sampleCategory) {
                    if (this.rateCategory == null) {
                        this.rateCategory = new int[patternCount];
                    }
                    System.arraycopy(rateCategory, 0, this.rateCategory, 0, patternCount);
                }

            } else {

                // This is an internal node, but not the root
                double[] partialLikelihood = new double[stateCount * patternCount * categoryCount];
                getPartials(nodeNum, partialLikelihood);

                // Sibon says that this actually works now
//                if (categoryCount > 1)
//                    throw new RuntimeException("Reconstruction not implemented for multiple categories yet.");

                getMatrix(nodeNum, probabilities);

                for (int j = 0; j < patternCount; j++) {

                    int parentIndex = parentState[j] * stateCount;
                    int childIndex = j * stateCount;

                    int category = rateCategory == null ? 0 : rateCategory[j];
                    int matrixIndex = category * stateCount * stateCount;
                    int partialIndex = category * stateCount * patternCount;

                    for (int i = 0; i < stateCount; i++) {
                        conditionalProbabilities[i] = partialLikelihood[partialIndex + childIndex + i]
                                * probabilities[matrixIndex + parentIndex + i];
                    }

                    state[j] = drawChoice(conditionalProbabilities);
                    reconstructedStates[nodeNum][j] = state[j];

                    if (!returnMarginalLogLikelihood) {
                        double contrib = probabilities[parentIndex + state[j]];
                        jointLogLikelihood += Math.log(contrib);
                    }
                }

                hookCalculation(tree, parent, node, parentState, state, probabilities, rateCategory);
            }

            // Traverse down the two child nodes
            NodeRef child1 = tree.getChild(node, 0);
            traverseSample(tree, child1, state, rateCategory);

            NodeRef child2 = tree.getChild(node, 1);
            traverseSample(tree, child2, state, rateCategory);

        } else {

            // This is an external leaf
            if (useAmbiguities()) {

                getMatrix(nodeNum, probabilities);

                double[] partials = tipPartials[nodeNum];

                for (int j = 0; j < patternCount; j++) {

                    final int parentIndex = parentState[j] * stateCount;
                    int category = rateCategory == null ? 0 : rateCategory[j];
                    int matrixIndex = category * stateCount * stateCount;

                    System.arraycopy(probabilities, parentIndex + matrixIndex, conditionalProbabilities, 0, stateCount);
                    for (int k = 0; k < stateCount; ++k) {
                        conditionalProbabilities[k] *= partials[j * stateCount + k];
                    }

                    reconstructedStates[nodeNum][j] = drawChoice(conditionalProbabilities);

                    if (!returnMarginalLogLikelihood) {
                        double contrib = probabilities[parentIndex + reconstructedStates[nodeNum][j]];
                        jointLogLikelihood += Math.log(contrib);
                    }
                }

            } else {

                getTipStates(nodeNum, reconstructedStates[nodeNum]);

                // Check for ambiguity codes and sample them
                for (int j = 0; j < patternCount; j++) {

                    final int thisState = reconstructedStates[nodeNum][j];

                    if (dataType.isAmbiguousState(thisState)) {

                        final int parentIndex = parentState[j] * stateCount;
                        int category = rateCategory == null ? 0 : rateCategory[j];
                        int matrixIndex = category * stateCount * stateCount;

                        getMatrix(nodeNum, probabilities);
                        System.arraycopy(probabilities, parentIndex + matrixIndex, conditionalProbabilities, 0, stateCount);

                        if (useAmbiguities() && !dataType.isUnknownState(thisState)) { // Not completely unknown
                            boolean[] stateSet = dataType.getStateSet(thisState);

                            for (int k = 0; k < stateCount; k++) {
                                if (!stateSet[k]) {
                                    conditionalProbabilities[k] = 0.0;
                                }
                            }
                        }

                        reconstructedStates[nodeNum][j] = drawChoice(conditionalProbabilities);
                    }

                    if (!returnMarginalLogLikelihood) {
                        final int parentIndex = parentState[j] * stateCount;
                        getMatrix(nodeNum, probabilities);
                        if (!returnMarginalLogLikelihood) {
                            double contrib = probabilities[parentIndex + reconstructedStates[nodeNum][j]];
                            jointLogLikelihood += Math.log(contrib);
                        }
                    }
                }
            }

            hookCalculation(tree, parent, node, parentState, reconstructedStates[nodeNum], null, rateCategory);
        }
    }

    protected void hookCalculation(Tree tree, NodeRef parentNode, NodeRef childNode,
                                   int[] parentStates, int[] childStates,
                                   double[] probabilities, int[] rateCategory) {
        // Do nothing
    }
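
    // hookCalculation is an extension point invoked once per sampled branch. A hypothetical
    // subclass (sketch only; the counter field below is not part of BEAST) could override it
    // to record, for example, how many site patterns change state along each branch:
    //
    //     @Override
    //     protected void hookCalculation(Tree tree, NodeRef parentNode, NodeRef childNode,
    //                                    int[] parentStates, int[] childStates,
    //                                    double[] probabilities, int[] rateCategory) {
    //         for (int j = 0; j < parentStates.length; j++) {
    //             if (parentStates[j] != childStates[j]) {
    //                 changeCount[childNode.getNumber()]++; // hypothetical per-branch counter
    //             }
    //         }
    //     }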

    private final DataType dataType;

    private int[][] reconstructedStates;
    private int[][] storedReconstructedStates;

//    private final String tag;

    protected boolean areStatesRedrawn = false;
    protected boolean storedAreStatesRedrawn = false;

    private boolean useMAP = false;
    private boolean returnMarginalLogLikelihood = true;

    private double jointLogLikelihood;
    private double storedJointLogLikelihood;

    private int[][] tipStates;
    private double[][] tipPartials;

    private double[] probabilities;
    private double[] partials;

    protected int[] rateCategory = null;

//    private double[] rootPartials;
//    private int[][] cumulativeScaleBuffers;
//    private int scaleBufferIndex;
}