/* * GMRFSkyrideLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.coalescent; import dr.evolution.coalescent.IntervalType; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; import dr.inference.model.MatrixParameter; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.math.MathUtils; import dr.util.Author; import dr.util.Citable; import dr.util.Citation; import no.uib.cipr.matrix.DenseVector; import no.uib.cipr.matrix.NotConvergedException; import no.uib.cipr.matrix.SymmTridiagEVD; import no.uib.cipr.matrix.SymmTridiagMatrix; import java.util.ArrayList; import java.util.Collections; import java.util.List; /** * A likelihood function for a Gaussian Markov random field on a log population size trajectory. * * @author Jen Tom * @author Erik Bloomquist * @author Vladimir Minin * @author Marc Suchard * @version $Id: GMRFSkylineLikelihood.java,v 1.3 2007/03/20 22:40:04 msuchard Exp $ */ public class GMRFSkyrideLikelihood extends OldAbstractCoalescentLikelihood implements CoalescentIntervalProvider, Citable { // PUBLIC STUFF public static final double LOG_TWO_TIMES_PI = 1.837877; public static final boolean TIME_AWARE_IS_ON_BY_DEFAULT = true; // PRIVATE STUFF protected Parameter popSizeParameter; protected Parameter groupSizeParameter; protected Parameter precisionParameter; protected Parameter lambdaParameter; protected Parameter betaParameter; // protected double[] gmrfWeights; protected int fieldLength; protected double[] coalescentIntervals; protected double[] storedCoalescentIntervals; protected double[] sufficientStatistics; protected double[] storedSufficientStatistics; //changed from private to protected protected double logFieldLikelihood; protected double storedLogFieldLikelihood; protected SymmTridiagMatrix weightMatrix; protected SymmTridiagMatrix storedWeightMatrix; protected MatrixParameter dMatrix; protected boolean timeAwareSmoothing = TIME_AWARE_IS_ON_BY_DEFAULT; protected boolean rescaleByRootHeight; public GMRFSkyrideLikelihood() { super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD); } public GMRFSkyrideLikelihood(String name) { super(name); } public GMRFSkyrideLikelihood(Tree tree, Parameter popParameter, Parameter groupParameter, Parameter precParameter, Parameter lambda, Parameter beta, MatrixParameter dMatrix, boolean timeAwareSmoothing, boolean rescaleByRootHeight) { this(wrapTree(tree), popParameter, groupParameter, precParameter, lambda, beta, dMatrix, timeAwareSmoothing, rescaleByRootHeight); } private static List<Tree> wrapTree(Tree tree) { List<Tree> treeList = new ArrayList<Tree>(); treeList.add(tree); return treeList; } public GMRFSkyrideLikelihood(List<Tree> treeList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, Parameter lambda, Parameter beta, MatrixParameter dMatrix, boolean timeAwareSmoothing, boolean rescaleByRootHeight) { super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD); this.popSizeParameter = popParameter; this.groupSizeParameter = groupParameter; this.precisionParameter = precParameter; this.lambdaParameter = lambda; this.betaParameter = beta; this.dMatrix = dMatrix; this.timeAwareSmoothing = timeAwareSmoothing; this.rescaleByRootHeight = rescaleByRootHeight; addVariable(popSizeParameter); addVariable(precisionParameter); addVariable(lambdaParameter); if (betaParameter != null) { addVariable(betaParameter); } setTree(treeList); int correctFieldLength = getCorrectFieldLength(); if (popSizeParameter.getDimension() <= 1) { // popSize dimension hasn't been set yet, set it here: popSizeParameter.setDimension(correctFieldLength); } fieldLength = popSizeParameter.getDimension(); if (correctFieldLength != fieldLength) { throw new IllegalArgumentException("Population size parameter should have length " + correctFieldLength); } // Field length must be set by this point wrapSetupIntervals(); coalescentIntervals = new double[fieldLength]; storedCoalescentIntervals = new double[fieldLength]; sufficientStatistics = new double[fieldLength]; storedSufficientStatistics = new double[fieldLength]; setupGMRFWeights(); addStatistic(new DeltaStatistic()); initializationReport(); /* Force all entries in groupSizeParameter = 1 for compatibility with Tracer */ if (groupSizeParameter != null) { for (int i = 0; i < groupSizeParameter.getDimension(); i++) groupSizeParameter.setParameterValue(i, 1.0); } } protected int getCorrectFieldLength() { return tree.getExternalNodeCount() - 1; } protected void wrapSetupIntervals() { setupIntervals(); } protected void setTree(List<Tree> treeList) { if (treeList.size() != 1) { throw new RuntimeException("GMRFSkyrideLikelihood only implemented for one tree"); } this.tree = treeList.get(0); this.treesSet = null; if (tree instanceof TreeModel) { addModel((TreeModel) tree); } } // public double[] getCopyOfCoalescentIntervals() { // return coalescentIntervals.clone(); // } // // public double[] getCoalescentIntervals() { // return coalescentIntervals; // } public void initializationReport() { System.out.println("Creating a GMRF smoothed skyride model:"); System.out.println("\tPopulation sizes: " + popSizeParameter.getDimension()); System.out.println("\tIf you publish results using this model, please reference: Minin, Bloomquist and Suchard (2008) Molecular Biology and Evolution, 25, 1459-1471."); } public static void checkTree(TreeModel treeModel) { // todo Should only be run if there exists a zero-length interval // TreeModel treeModel = (TreeModel) tree; for (int i = 0; i < treeModel.getInternalNodeCount(); i++) { NodeRef node = treeModel.getInternalNode(i); if (node != treeModel.getRoot()) { double parentHeight = treeModel.getNodeHeight(treeModel.getParent(node)); double childHeight0 = treeModel.getNodeHeight(treeModel.getChild(node, 0)); double childHeight1 = treeModel.getNodeHeight(treeModel.getChild(node, 1)); double maxChild = childHeight0; if (childHeight1 > maxChild) maxChild = childHeight1; double newHeight = maxChild + MathUtils.nextDouble() * (parentHeight - maxChild); treeModel.setNodeHeight(node, newHeight); } } treeModel.pushTreeChangedEvent(); } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** public double getLogLikelihood() { if (!likelihoodKnown) { logLikelihood = calculateLogCoalescentLikelihood(); logFieldLikelihood = calculateLogFieldLikelihood(); likelihoodKnown = true; } return logLikelihood + logFieldLikelihood; } protected double peakLogCoalescentLikelihood() { return logLikelihood; } protected double peakLogFieldLikelihood() { return logFieldLikelihood; } public double[] getSufficientStatistics() { return sufficientStatistics; } public String toString() { return getId() + "(" + Double.toString(getLogLikelihood()) + ")"; } protected void setupSufficientStatistics() { int index = 0; double length = 0; double weight = 0; for (int i = 0; i < getIntervalCount(); i++) { length += getInterval(i); weight += getInterval(i) * getLineageCount(i) * (getLineageCount(i) - 1); if (getIntervalType(i) == CoalescentEventType.COALESCENT) { coalescentIntervals[index] = length; sufficientStatistics[index] = weight / 2.0; index++; length = 0; weight = 0; } } } protected double getFieldScalar() { final double rootHeight; if (rescaleByRootHeight) { rootHeight = tree.getNodeHeight(tree.getRoot()); } else { rootHeight = 1.0; } return rootHeight; } protected void setupGMRFWeights() { setupSufficientStatistics(); //Set up the weight Matrix double[] offdiag = new double[fieldLength - 1]; double[] diag = new double[fieldLength]; //First set up the offdiagonal entries; if (!timeAwareSmoothing) { for (int i = 0; i < fieldLength - 1; i++) { offdiag[i] = -1.0; } } else { for (int i = 0; i < fieldLength - 1; i++) { offdiag[i] = -2.0 / (coalescentIntervals[i] + coalescentIntervals[i + 1]) * getFieldScalar(); } } //Then set up the diagonal entries; for (int i = 1; i < fieldLength - 1; i++) diag[i] = -(offdiag[i] + offdiag[i - 1]); //Take care of the endpoints diag[0] = -offdiag[0]; diag[fieldLength - 1] = -offdiag[fieldLength - 2]; weightMatrix = new SymmTridiagMatrix(diag, offdiag); } public SymmTridiagMatrix getScaledWeightMatrix(double precision) { SymmTridiagMatrix a = weightMatrix.copy(); for (int i = 0; i < a.numRows() - 1; i++) { a.set(i, i, a.get(i, i) * precision); a.set(i + 1, i, a.get(i + 1, i) * precision); } a.set(fieldLength - 1, fieldLength - 1, a.get(fieldLength - 1, fieldLength - 1) * precision); return a; } public SymmTridiagMatrix getStoredScaledWeightMatrix(double precision) { SymmTridiagMatrix a = storedWeightMatrix.copy(); for (int i = 0; i < a.numRows() - 1; i++) { a.set(i, i, a.get(i, i) * precision); a.set(i + 1, i, a.get(i + 1, i) * precision); } a.set(fieldLength - 1, fieldLength - 1, a.get(fieldLength - 1, fieldLength - 1) * precision); return a; } public SymmTridiagMatrix getScaledWeightMatrix(double precision, double lambda) { if (lambda == 1) return getScaledWeightMatrix(precision); SymmTridiagMatrix a = weightMatrix.copy(); for (int i = 0; i < a.numRows() - 1; i++) { a.set(i, i, precision * (1 - lambda + lambda * a.get(i, i))); a.set(i + 1, i, a.get(i + 1, i) * precision * lambda); } a.set(fieldLength - 1, fieldLength - 1, precision * (1 - lambda + lambda * a.get(fieldLength - 1, fieldLength - 1))); return a; } private void makeIntervalsKnown() { if (!intervalsKnown) { wrapSetupIntervals(); setupGMRFWeights(); intervalsKnown = true; } } public int getCoalescentIntervalDimension() { makeIntervalsKnown(); return coalescentIntervals.length; } public double getCoalescentInterval(int i) { makeIntervalsKnown(); return coalescentIntervals[i]; } /*public int getCoalescentIntervalLineageCount(int i) { throw new RuntimeException("Not yet implemented"); } public IntervalType getCoalescentIntervalType(int i) { throw new RuntimeException("getCoalescentIntervalType(int i) in GMRFSkyrideLikelihood not yet implemented"); }*/ public int getNumberOfCoalescentEvents() { return tree.getExternalNodeCount() - 1; } public double getCoalescentEventsStatisticValue(int i) { return sufficientStatistics[i]; } public double[] getCoalescentIntervalHeights() { makeIntervalsKnown(); double[] a = new double[coalescentIntervals.length]; a[0] = coalescentIntervals[0]; for (int i = 1; i < a.length; i++) { a[i] = a[i - 1] + coalescentIntervals[i]; } return a; } public SymmTridiagMatrix getCopyWeightMatrix() { return weightMatrix.copy(); } public SymmTridiagMatrix getStoredScaledWeightMatrix(double precision, double lambda) { if (lambda == 1) return getStoredScaledWeightMatrix(precision); SymmTridiagMatrix a = storedWeightMatrix.copy(); for (int i = 0; i < a.numRows() - 1; i++) { a.set(i, i, precision * (1 - lambda + lambda * a.get(i, i))); a.set(i + 1, i, a.get(i + 1, i) * precision * lambda); } a.set(fieldLength - 1, fieldLength - 1, precision * (1 - lambda + lambda * a.get(fieldLength - 1, fieldLength - 1))); return a; } protected void storeState() { super.storeState(); System.arraycopy(coalescentIntervals, 0, storedCoalescentIntervals, 0, coalescentIntervals.length); System.arraycopy(sufficientStatistics, 0, storedSufficientStatistics, 0, sufficientStatistics.length); storedWeightMatrix = weightMatrix.copy(); storedLogFieldLikelihood = logFieldLikelihood; } protected void restoreState() { super.restoreState(); // TODO Just swap pointers System.arraycopy(storedCoalescentIntervals, 0, coalescentIntervals, 0, storedCoalescentIntervals.length); System.arraycopy(storedSufficientStatistics, 0, sufficientStatistics, 0, storedSufficientStatistics.length); weightMatrix = storedWeightMatrix; logFieldLikelihood = storedLogFieldLikelihood; } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { likelihoodKnown = false; // Parameters (precision and popsizes do not change intervals or GMRF Q matrix } /** * Calculates the log likelihood of this set of coalescent intervals, * given a demographic model. * * @return coalescent part of density */ protected double calculateLogCoalescentLikelihood() { makeIntervalsKnown(); // Matrix operations taken from block update sampler to calculate data likelihood and field prior double currentLike = 0; double[] currentGamma = popSizeParameter.getParameterValues(); for (int i = 0; i < fieldLength; i++) { currentLike += -currentGamma[i] - sufficientStatistics[i] * Math.exp(-currentGamma[i]); } return currentLike;// + LogNormalDistribution.logPdf(Math.exp(popSizeParameter.getParameterValue(coalescentIntervals.length - 1)), mu, sigma); } protected double calculateLogFieldLikelihood() { makeIntervalsKnown(); double currentLike = 0; DenseVector diagonal1 = new DenseVector(fieldLength); DenseVector currentGamma = new DenseVector(popSizeParameter.getParameterValues()); SymmTridiagMatrix currentQ = getScaledWeightMatrix(precisionParameter.getParameterValue(0), lambdaParameter.getParameterValue(0)); currentQ.mult(currentGamma, diagonal1); // currentLike += 0.5 * logGeneralizedDeterminant(currentQ) - 0.5 * currentGamma.dot(diagonal1); currentLike += 0.5 * (fieldLength - 1) * Math.log(precisionParameter.getParameterValue(0)) - 0.5 * currentGamma.dot(diagonal1); if (lambdaParameter.getParameterValue(0) == 1) { currentLike -= (fieldLength - 1) / 2.0 * LOG_TWO_TIMES_PI; } else { currentLike -= fieldLength / 2.0 * LOG_TWO_TIMES_PI; } return currentLike; } // public static double logGeneralizedDeterminant(SymmTridiagMatrix X) { // //Set up the eigenvalue solver // SymmTridiagEVD eigen = new SymmTridiagEVD(X.numRows(), false); // //Solve for the eigenvalues // try { // eigen.factor(X); // } catch (NotConvergedException e) { // throw new RuntimeException("Not converged error in generalized determinate calculation.\n" + e.getMessage()); // } // // //Get the eigenvalues // double[] x = eigen.getEigenvalues(); // // double a = 0; // for (double d : x) { // if (d > 0.00001) // a += Math.log(d); // } // // return a; // } public Parameter getPrecisionParameter() { return precisionParameter; } public Parameter getPopSizeParameter() { return popSizeParameter; } public Parameter getLambdaParameter() { return lambdaParameter; } public SymmTridiagMatrix getWeightMatrix() { return weightMatrix.copy(); } public Parameter getBetaParameter() { return betaParameter; } public MatrixParameter getDesignMatrix() { return dMatrix; } public double calculateWeightedSSE() { double weightedSSE = 0; double currentPopSize = popSizeParameter.getParameterValue(0); double currentInterval = coalescentIntervals[0]; for (int j = 1; j < fieldLength; j++) { double nextPopSize = popSizeParameter.getParameterValue(j); double nextInterval = coalescentIntervals[j]; double delta = nextPopSize - currentPopSize; double weight = (currentInterval + nextInterval) / 2.0; weightedSSE += delta * delta / weight; currentPopSize = nextPopSize; currentInterval = nextInterval; } return weightedSSE; } @Override public Citation.Category getCategory() { return Citation.Category.TREE_PRIORS; } @Override public String getDescription() { return "Skyride coalescent"; } @Override public List<Citation> getCitations() { return Collections.singletonList(CITATION); } public static Citation CITATION = new Citation( new Author[]{ new Author("VN", "Minin"), new Author("EW", "Bloomquist"), new Author("MA", "Suchard") }, "Smooth skyride through a rough skyline: Bayesian coalescent-based inference of population dynamics", 2008, "Mol Biol Evol", 25, 1459, 1471, "10.1093/molbev/msn090" ); } /* WinBUGS code to fixed tree: (A:4.0,(B:2.0,(C:0.5,D:1.0):1.0):2.0) model { stat1 ~ dexp(rate[1]) stat2 ~ dexp(rate[2]) stat3 ~ dexp(rate[3]) rate[1] <- 1 / exp(theta[1]) rate[2] <- 1 / exp(theta[2]) rate[3] <- 1 / exp(theta[3]) theta[1] ~ dnorm(0, 0.001) theta[2] ~ dnorm(theta[1], weight[1]) theta[3] ~ dnorm(theta[2], weight[2]) weight[1] <- tau / 1.0 weight[2] <- tau / 1.5 tau ~ dgamma(1,0.3333) stat1 <- 9 / 2 stat2 <- 6 / 2 stat3 <- 4 / 2 } */