/* * AbstractObservationProcess.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.oldevomodel.MSSD; import dr.evolution.alignment.AscertainedSitePatterns; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.MutationDeathType; import dr.evolution.tree.NodeRef; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.oldevomodel.sitemodel.SiteRateModel; import dr.evomodel.tree.TreeModel; import dr.oldevomodel.treelikelihood.LikelihoodCore; import dr.evomodel.treelikelihood.LikelihoodPartialsProvider; import dr.oldevomodel.treelikelihood.ScaleFactorsHelper; import dr.inference.model.AbstractModel; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.math.GammaFunction; /** * Package: AbstractObservationProcess * Description: * <p/> * <p/> * Created by * Alexander V. Alekseyenko (alexander.alekseyenko@gmail.com) * Date: Feb 19, 2008 * Time: 12:41:01 PM */ abstract public class AbstractObservationProcess extends AbstractModel { protected boolean[] nodePatternInclusion; protected boolean[] storedNodePatternInclusion; protected double[] cumLike; protected double[] nodePartials; protected double[] nodeLikelihoods; protected int nodeCount; protected int patternCount; protected int stateCount; protected TreeModel treeModel; protected PatternList patterns; protected double[] patternWeights; protected Parameter mu; protected Parameter lam; // update control variables protected boolean weightKnown; protected double logTreeWeight; protected double storedLogTreeWeight; private double gammaNorm; private double totalPatterns; protected MutationDeathType dataType; protected int deathState; protected SiteRateModel siteModel; private double logN; protected boolean nodePatternInclusionKnown = false; BranchRateModel branchRateModel; public AbstractObservationProcess(String Name, TreeModel treeModel, PatternList patterns, SiteRateModel siteModel, BranchRateModel branchRateModel, Parameter mu, Parameter lam) { super(Name); this.treeModel = treeModel; this.patterns = patterns; this.mu = mu; this.lam = lam; this.siteModel = siteModel; if (branchRateModel != null) { this.branchRateModel = branchRateModel; } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(treeModel); addModel(siteModel); addModel(this.branchRateModel); addVariable(mu); addVariable(lam); nodeCount = treeModel.getNodeCount(); stateCount = patterns.getDataType().getStateCount(); this.patterns = patterns; patternCount = patterns.getPatternCount(); patternWeights = patterns.getPatternWeights(); totalPatterns = 0; for (int i = 0; i < patternCount; ++i) { totalPatterns += patternWeights[i]; } logN = Math.log(totalPatterns); gammaNorm = -GammaFunction.lnGamma(totalPatterns + 1); dataType = (MutationDeathType) patterns.getDataType(); this.deathState = dataType.DEATHSTATE; setNodePatternInclusion(); cumLike = new double[patternCount]; nodeLikelihoods = new double[patternCount]; weightKnown = false; } // public Parameter getMuParameter() { // return mu; // } // // public Parameter getLamParameter() { // return lam; // } private double calculateSiteLogLikelihood(int site, double[] partials, double[] frequencies) { int v = site * stateCount; double sum = 0.0; for (int i = 0; i < stateCount; i++) { sum += frequencies[i] * partials[v + i]; } return Math.log(sum); } private void calculateNodePatternLikelihood(int nodeIndex, double[] freqs, LikelihoodCore likelihoodCore, double averageRate, double[] cumLike) { // get partials for node nodeIndex likelihoodCore.getPartials(nodeIndex, nodePartials); // MAS /* multiply the partials by equilibrium probs this part could be optimized by first summing and then multiplying by equilibrium probs */ double prob = Math.log(getNodeSurvivalProbability(nodeIndex, averageRate)); for (int j = 0; j < patternCount; ++j) { if (nodePatternInclusion[nodeIndex * patternCount + j]) { cumLike[j] += Math.exp(calculateSiteLogLikelihood(j, nodePartials, freqs) + prob); } } } private double accumulateCorrectedLikelihoods(double[] cumLike, double ascertainmentCorrection, double[] patterWeights) { double logL = 0; for (int j = 0; j < patternCount; ++j) { logL += Math.log(cumLike[j] / ascertainmentCorrection) * patternWeights[j]; } return logL; } public final double nodePatternLikelihood(double[] freqs, LikelihoodPartialsProvider likelihoodCore, ScaleFactorsHelper scaleFactorsHelper) { int i, j; double logL = gammaNorm; double birthRate = lam.getParameterValue(0); double logProb; if (!nodePatternInclusionKnown) setNodePatternInclusion(); if (nodePartials == null) { nodePartials = new double[patternCount * stateCount]; } double averageRate = getAverageRate(); for (j = 0; j < patternCount; ++j) cumLike[j] = 0; for (i = 0; i < nodeCount; ++i) { // get partials for node i likelihoodCore.getPartials(i, nodePartials); scaleFactorsHelper.rescalePartials(i, nodePartials); /* multiply the partials by equilibrium probs this part could be optimized by first summing and then multiplying by equilibrium probs */ // likelihoodCore.calculateLogLikelihoods(nodePartials, freqs, nodeLikelihoods); // MAS Removed logProb = Math.log(getNodeSurvivalProbability(i, averageRate)); for (j = 0; j < patternCount; ++j) { if (nodePatternInclusion[i * patternCount + j]) { // cumLike[j] += Math.exp(nodeLikelihoods[j] + logProb); // MAS Replaced with line below cumLike[j] += Math.exp(calculateSiteLogLikelihood(j, nodePartials, freqs) + logProb); } } } double ascertainmentCorrection = getAscertainmentCorrection(cumLike); // System.err.println("AscertainmentCorrection: "+ascertainmentCorrection); for (j = 0; j < patternCount; ++j) { logL += Math.log(cumLike[j] / ascertainmentCorrection) * patternWeights[j]; } double deathRate = mu.getParameterValue(0); double logTreeWeight = getLogTreeWeight(); if (integrateGainRate) { logL -= gammaNorm + logN + Math.log(-logTreeWeight * deathRate / birthRate) * totalPatterns; } else { logL += logTreeWeight + Math.log(birthRate / deathRate) * totalPatterns; } return logL; } protected double getAscertainmentCorrection(double[] patternProbs) { // This function probably belongs better to the AscertainedSitePatterns double excludeProb = 0, includeProb = 0, returnProb = 1.0; if (this.patterns instanceof AscertainedSitePatterns) { int[] includeIndices = ((AscertainedSitePatterns) patterns).getIncludePatternIndices(); int[] excludeIndices = ((AscertainedSitePatterns) patterns).getExcludePatternIndices(); for (int i = 0; i < ((AscertainedSitePatterns) patterns).getIncludePatternCount(); i++) { int index = includeIndices[i]; includeProb += patternProbs[index]; } for (int j = 0; j < ((AscertainedSitePatterns) patterns).getExcludePatternCount(); j++) { int index = excludeIndices[j]; excludeProb += patternProbs[index]; } if (includeProb == 0.0) { returnProb -= excludeProb; } else if (excludeProb == 0.0) { returnProb = includeProb; } else { returnProb = includeProb - excludeProb; } } return returnProb; } final public double getLogTreeWeight() { if (!weightKnown) { logTreeWeight = calculateLogTreeWeight(); weightKnown = true; } return logTreeWeight; } abstract public double calculateLogTreeWeight(); abstract void setNodePatternInclusion(); final public double getAverageRate() { if (!averageRateKnown) { double avgRate = 0.0; double proportions[] = siteModel.getCategoryProportions(); for (int i = 0; i < siteModel.getCategoryCount(); ++i) { avgRate += proportions[i] * siteModel.getRateForCategory(i); } averageRate = avgRate; averageRateKnown = true; } return averageRate; } public double getNodeSurvivalProbability(int index, double averageRate) { NodeRef node = treeModel.getNode(index); NodeRef parent = treeModel.getParent(node); if (parent == null) return 1.0; final double deathRate = mu.getParameterValue(0) * averageRate; //getAverageRate(); final double branchRate = branchRateModel.getBranchRate(treeModel, node); // Get the operational time of the branch final double branchTime = branchRate * treeModel.getBranchLength(node); return 1.0 - Math.exp(-deathRate * branchTime); } protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == siteModel) { averageRateKnown = false; } if (model == treeModel || model == siteModel || model == branchRateModel) { weightKnown = false; } if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) { nodePatternInclusionKnown = false; } } } } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { if (variable == mu || variable == lam) { weightKnown = false; } else { System.err.println("AbstractObservationProcess: Got unexpected parameter changed event. (Parameter = " + variable + ")"); } } protected void storeState() { // storedAverageRate = averageRate; storedLogTreeWeight = logTreeWeight; System.arraycopy(nodePatternInclusion, 0, storedNodePatternInclusion, 0, storedNodePatternInclusion.length); } protected void restoreState() { // averageRate = storedAverageRate; averageRateKnown = false; logTreeWeight = storedLogTreeWeight; boolean[] tmp = storedNodePatternInclusion; storedNodePatternInclusion = nodePatternInclusion; nodePatternInclusion = tmp; } protected void acceptState() { } public void setIntegrateGainRate(boolean integrateGainRate) { this.integrateGainRate = integrateGainRate; } private boolean integrateGainRate = false; private double storedAverageRate; private double averageRate; private boolean averageRateKnown = false; }