/* * PrecisionTestTreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.oldevomodel.treelikelihood; import dr.evolution.alignment.PatternList; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.oldevomodel.sitemodel.SiteModel; import dr.oldevomodel.substmodel.FrequencyModel; import dr.evomodel.tree.TreeModel; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.xml.*; import java.math.BigDecimal; import java.util.logging.Logger; /** * TreeLikelihoodModel - implements a Likelihood Function for sequences on a tree. * * @author Andrew Rambaut * @author Alexei Drummond * @version $Id: TreeLikelihood.java,v 1.31 2006/08/30 16:02:42 rambaut Exp $ */ @Deprecated // Switching to BEAGLE public class PrecisionTestTreeLikelihood extends AbstractTreeLikelihood { public static final String TREE_LIKELIHOOD = "precisionTestTreeLikelihood"; public static final String USE_AMBIGUITIES = "useAmbiguities"; public static final String ALLOW_MISSING_TAXA = "allowMissingTaxa"; public static final String STORE_PARTIALS = "storePartials"; public static final String USE_SCALING = "useScaling"; public static final String FORCE_JAVA_CORE = "forceJavaCore"; /** * Constructor. */ public PrecisionTestTreeLikelihood(PatternList patternList, TreeModel treeModel, SiteModel siteModel, BranchRateModel branchRateModel, boolean useAmbiguities, boolean allowMissingTaxa, boolean storePartials) { super(TREE_LIKELIHOOD, patternList, treeModel); this.storePartials = storePartials; try { this.siteModel = siteModel; addModel(siteModel); this.frequencyModel = siteModel.getFrequencyModel(); addModel(frequencyModel); this.categoryCount = siteModel.getCategoryCount(); final Logger logger = Logger.getLogger("dr.evomodel"); String coreName = "Java general"; likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); precisionLikelihoodCore = new ArbitraryPrecisionLikelihoodCore(patternList.getStateCount(), 20); logger.info("PrecisionTestTreeLikelihood using " + coreName + " likelihood core"); logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); if (branchRateModel != null) { this.branchRateModel = branchRateModel; logger.info("Branch rate model used: " + branchRateModel.getModelName()); } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(this.branchRateModel); probabilities = new double[stateCount * stateCount]; likelihoodCore.initialize(nodeCount, patternCount, categoryCount, true); precisionLikelihoodCore.initialize(nodeCount, patternCount, categoryCount, true); int extNodeCount = treeModel.getExternalNodeCount(); int intNodeCount = treeModel.getInternalNodeCount(); for (int i = 0; i < extNodeCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); int index = patternList.getTaxonIndex(id); if (index == -1) { if (!allowMissingTaxa) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } if (useAmbiguities) { setMissingPartials(likelihoodCore, i); setMissingPartials(precisionLikelihoodCore, i); } else { setMissingStates(likelihoodCore, i); setMissingStates(precisionLikelihoodCore, i); } } else { if (useAmbiguities) { setPartials(likelihoodCore, patternList, categoryCount, index, i); setPartials(precisionLikelihoodCore, patternList, categoryCount, index, i); } else { setStates(likelihoodCore, patternList, index, i); setStates(precisionLikelihoodCore, patternList, index, i); } } } for (int i = 0; i < intNodeCount; i++) { likelihoodCore.createNodePartials(extNodeCount + i); precisionLikelihoodCore.createNodePartials(extNodeCount + i); } } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** /** * Handles model changed events from the submodels. */ protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent) object).isNodeChanged()) { // If a node event occurs the node and its two child nodes // are flagged for updating (this will result in everything // above being updated as well. Node events occur when a node // is added to a branch, removed from a branch or its height or // rate changes. updateNodeAndChildren(((TreeModel.TreeChangedEvent) object).getNode()); } else if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) { // Full tree events result in a complete updating of the tree likelihood // Currently this event type is not used. System.err.println("Full tree update event - these events currently aren't used\n" + "so either this is in error or a new feature is using them so remove this message."); updateAllNodes(); } else { // Other event types are ignored (probably trait changes). //System.err.println("Another tree event has occured (possibly a trait change)."); } } } else if (model == branchRateModel) { if (index == -1) { updateAllNodes(); } else { updateNode(treeModel.getNode(index)); } } else if (model == frequencyModel) { updateAllNodes(); } else if (model instanceof SiteModel) { updateAllNodes(); } else { throw new RuntimeException("Unknown componentChangedEvent"); } super.handleModelChangedEvent(model, object, index); } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the additional state other than model components */ protected void storeState() { if (storePartials) { likelihoodCore.storeState(); precisionLikelihoodCore.storeState(); } super.storeState(); } /** * Restore the additional stored state */ protected void restoreState() { if (storePartials) { likelihoodCore.restoreState(); precisionLikelihoodCore.restoreState(); } else { updateAllNodes(); } super.restoreState(); } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** /** * Calculate the log likelihood of the current state. * * @return the log likelihood. */ protected double calculateLogLikelihood() { if (patternLogLikelihoods == null) { patternLogLikelihoods = new double[patternCount]; } if (precisionPatternLogLikelihoods == null) { precisionPatternLogLikelihoods = new double[patternCount]; } final NodeRef root = treeModel.getRoot(); traverse(treeModel, root); double logL = 0.0; for (int i = 0; i < patternCount; i++) { logL += patternLogLikelihoods[i] * patternWeights[i]; } double precisionLogL = 0.0; for (int i = 0; i < patternCount; i++) { precisionLogL += precisionPatternLogLikelihoods[i] * patternWeights[i]; } if (Math.abs(logL - precisionLogL) > 1.0E-5) { System.out.println("logL = " + logL + " precision logL = " + precisionLogL); } if (logL == Double.NEGATIVE_INFINITY) { // We probably had an underflow... turn on scaling likelihoodCore.setUseScaling(true); // and try again... updateAllNodes(); updateAllPatterns(); traverse(treeModel, root); logL = 0.0; for (int i = 0; i < patternCount; i++) { logL += patternLogLikelihoods[i] * patternWeights[i]; } if (Math.abs(logL - precisionLogL) > 1.0E-5) { System.out.println("scaled logL = " + logL + " precision logL = " + precisionLogL); } } //******************************************************************** // after traverse all nodes and patterns have been updated -- //so change flags to reflect this. for (int i = 0; i < nodeCount; i++) { updateNode[i] = false; } //******************************************************************** return logL; } /** * Traverse the tree calculating partial likelihoods. * * @return whether the partials for this node were recalculated. */ private boolean traverse(Tree tree, NodeRef node) { boolean update = false; int nodeNum = node.getNumber(); NodeRef parent = tree.getParent(node); // First update the transition probability matrix(ices) for this branch if (parent != null && updateNode[nodeNum]) { final double branchRate = branchRateModel.getBranchRate(tree, node); // Get the operational time of the branch final double branchTime = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node)); if (branchTime < 0.0) { throw new RuntimeException("Negative branch length: " + branchTime); } likelihoodCore.setNodeMatrixForUpdate(nodeNum); precisionLikelihoodCore.setNodeMatrixForUpdate(nodeNum); for (int i = 0; i < categoryCount; i++) { double branchLength = siteModel.getRateForCategory(i) * branchTime; siteModel.getSubstitutionModel().getTransitionProbabilities(branchLength, probabilities); likelihoodCore.setNodeMatrix(nodeNum, i, probabilities); precisionLikelihoodCore.setNodeMatrix(nodeNum, i, probabilities); } update = true; } // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); final boolean update1 = traverse(tree, child1); NodeRef child2 = tree.getChild(node, 1); final boolean update2 = traverse(tree, child2); // If either child node was updated then update this node too if (update1 || update2) { final int childNum1 = child1.getNumber(); final int childNum2 = child2.getNumber(); likelihoodCore.setNodePartialsForUpdate(nodeNum); precisionLikelihoodCore.setNodePartialsForUpdate(nodeNum); likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum); precisionLikelihoodCore.calculatePartials(childNum1, childNum2, nodeNum); if (parent == null) { // No parent this is the root of the tree - // calculate the pattern likelihoods double[] frequencies = frequencyModel.getFrequencies(); double[] partials = getRootPartials(); likelihoodCore.calculateLogLikelihoods(partials, frequencies, patternLogLikelihoods); BigDecimal[] precisionPartials = getPrecisionRootPartials(); precisionLikelihoodCore.calculateLogLikelihoods(precisionPartials, frequencies, precisionPatternLogLikelihoods); } update = true; } } return update; } public final double[] getRootPartials() { if (rootPartials == null) { rootPartials = new double[patternCount * stateCount]; } int nodeNum = treeModel.getRoot().getNumber(); // moved this call to here, because non-integrating siteModels don't need to support it - AD double[] proportions = siteModel.getCategoryProportions(); likelihoodCore.integratePartials(nodeNum, proportions, rootPartials); return rootPartials; } private double[] rootPartials = null; public final BigDecimal[] getPrecisionRootPartials() { if (precisionRootPartials == null) { precisionRootPartials = new BigDecimal[patternCount * stateCount]; } int nodeNum = treeModel.getRoot().getNumber(); // moved this call to here, because non-integrating siteModels don't need to support it - AD double[] proportions = siteModel.getCategoryProportions(); precisionLikelihoodCore.integratePartials(nodeNum, proportions, precisionRootPartials); return precisionRootPartials; } private BigDecimal[] precisionRootPartials = null; /** * The XML parser */ public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return TREE_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean useAmbiguities = xo.getAttribute(USE_AMBIGUITIES, false); boolean allowMissingTaxa = xo.getAttribute(ALLOW_MISSING_TAXA, false); boolean storePartials = xo.getAttribute(STORE_PARTIALS, true); PatternList patternList = (PatternList) xo.getChild(PatternList.class); TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class); SiteModel siteModel = (SiteModel) xo.getChild(SiteModel.class); BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class); return new PrecisionTestTreeLikelihood( patternList, treeModel, siteModel, branchRateModel, useAmbiguities, allowMissingTaxa, storePartials); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents the likelihood of a patternlist on a tree given the site model."; } public Class getReturnType() { return Likelihood.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{ AttributeRule.newBooleanRule(USE_AMBIGUITIES, true), AttributeRule.newBooleanRule(ALLOW_MISSING_TAXA, true), AttributeRule.newBooleanRule(STORE_PARTIALS, true), new ElementRule(PatternList.class), new ElementRule(TreeModel.class), new ElementRule(SiteModel.class), new ElementRule(BranchRateModel.class, true), }; }; // ************************************************************** // INSTANCE VARIABLES // ************************************************************** /** * the frequency model for these sites */ protected final FrequencyModel frequencyModel; /** * the site model for these sites */ protected final SiteModel siteModel; /** * the branch rate model */ protected final BranchRateModel branchRateModel; private final boolean storePartials; /** * the categories for each site */ protected int[] siteCategories = null; /** * the pattern likelihoods */ protected double[] patternLogLikelihoods = null; /** * the pattern likelihoods */ protected double[] precisionPatternLogLikelihoods = null; /** * the number of rate categories */ protected int categoryCount; /** * an array used to store transition probabilities */ protected double[] probabilities; /** * the LikelihoodCore */ protected LikelihoodCore likelihoodCore; protected ArbitraryPrecisionLikelihoodCore precisionLikelihoodCore; }