/*
 * DebugableIntegratedMultivariateTraitLikelihood.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.continuous;

import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.math.KroneckerOperation;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * This class contains all of the O(number of tips^2) functions used to debug the dynamic programming
 * in its superclasses.
 *
 * @author Marc A. Suchard
 */
public class DebugableIntegratedMultivariateTraitLikelihood extends SemiConjugateMultivariateTraitLikelihood {

    public DebugableIntegratedMultivariateTraitLikelihood(String traitName,
                                                          MultivariateTraitTree treeModel,
                                                          MultivariateDiffusionModel diffusionModel,
                                                          CompoundParameter traitParameter,
                                                          List<Integer> missingIndices,
                                                          boolean cacheBranches,
                                                          boolean scaleByTime,
                                                          boolean useTreeLength,
                                                          BranchRateModel rateModel,
                                                          Model samplingDensity,
                                                          boolean reportAsMultivariate,
                                                          MultivariateNormalDistribution rootPrior,
                                                          boolean reciprocalRates,
                                                          List<RestrictedPartials> partials) {
        super(traitName, treeModel, diffusionModel, traitParameter, missingIndices, cacheBranches, scaleByTime,
                useTreeLength, rateModel, samplingDensity, reportAsMultivariate, rootPrior, reciprocalRates, partials);
    }

    // Collects the traits of all non-missing tips for the given datum into a flat vector of length
    // dimTrait * (number of non-missing tips)
    protected double[] fillLeafTraits(int datum) {

        final int tipCount = treeModel.getExternalNodeCount();
        final int nonMissingTipCount = countNonMissingTips();

        double[] traits = new double[dimTrait * nonMissingTipCount];
        int index = 0;
        for (int i = 0; i < tipCount; i++) {
            if (!missingTraits.isCompletelyMissing(i)) {
                for (int k = 0; k < dimTrait; k++) {
                    traits[index++] = meanCache[dim * i + datum * dimTrait + k];
                }
            }
        }
        return traits;
    }

    // Removes the rows and columns of the tree variance matrix that correspond to completely missing tips
    protected double[][] removeMissingTipsInTreeVariance(double[][] variance) {

        final int tipCount = treeModel.getExternalNodeCount();
        final int nonMissing = countNonMissingTips();

        if (nonMissing == tipCount) { // Do nothing
            return variance;
        }

        double[][] outVariance = new double[nonMissing][nonMissing];

        int iReal = 0;
        for (int i = 0; i < tipCount; i++) {
            if (!missingTraits.isCompletelyMissing(i)) {
                int jReal = 0;
                for (int j = 0; j < tipCount; j++) {
                    if (!missingTraits.isCompletelyMissing(j)) { // Check the column tip j, not the row tip i
                        outVariance[iReal][jReal] = variance[i][j];
                        jReal++;
                    }
                }
                iReal++;
            }
        }
        return outVariance;
    }
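    /**
     * Builds the full (tipCount * dimTrait) x (tipCount * dimTrait) precision over all tip traits as the
     * Kronecker product of the tree precision with the trait (diffusion) precision. As a rough illustrative
     * sketch only (the numbers below are made up, not anything computed by this class): with two tips, a tree
     * precision of {{2, -1}, {-1, 2}} and a single trait with precision 3, the result is simply the tree
     * precision scaled by 3; with dimTrait > 1, each scalar entry of the tree precision is instead replaced
     * by a dimTrait x dimTrait block equal to that entry times the trait precision.
     */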
    protected double[][] computeTreeTraitPrecision(double[][] traitPrecision) {
        double[][] treePrecision = computeTreePrecision();
        if (dimTrait > 1) {
            treePrecision = KroneckerOperation.product(treePrecision, traitPrecision);
        } else {
            final double precision = traitPrecision[0][0];
            for (int i = 0; i < treePrecision.length; i++) {
                for (int j = 0; j < treePrecision[i].length; j++) {
                    treePrecision[i][j] *= precision;
                }
            }
        }
        return treePrecision;
    }

    public double[][] computeTreePrecision() {
        return new SymmetricMatrix(computeTreeVariance()).inverse().toComponents();
    }

    private NodeRef findMRCA(int iTip, int jTip) {
        Set<String> leafNames = new HashSet<String>();
        leafNames.add(treeModel.getTaxonId(iTip));
        leafNames.add(treeModel.getTaxonId(jTip));
        return TreeUtils.getCommonAncestorNode(treeModel, leafNames);
    }

    public int getNumberOfDatum() {
        return numData * countNonMissingTips();
    }

    protected double integrateLogLikelihoodAtRootFromFullTreeMatrix(double[][] treeTraitPrecisionMatrix,
                                                                     double[] tipTraits) {

        double logLikelihood = 0;
        final int tipCount = countNonMissingTips();

        // 1^t \Sigma^{-1} y + P * rootPriorMean
        double[] mean = Ay;
        for (int i = 0; i < dimTrait; i++) {
            mean[i] = 0;
            for (int j = 0; j < dimTrait; j++) {
                mean[i] += rootPriorPrecision[i][j] * rootPriorMean[j];
            }
            for (int j = 0; j < tipCount; j++) {
                final int rowOffset = j * dimTrait + i;
                for (int k = 0; k < tipCount * dimTrait; k++) {
                    mean[i] += treeTraitPrecisionMatrix[rowOffset][k] * tipTraits[k];
                }
            }
        }

        // 1^t \Sigma^{-1} 1 + P
        double[][] precision = tmpM;
        for (int i = 0; i < dimTrait; i++) {
            for (int j = 0; j < dimTrait; j++) {
                precision[i][j] = rootPriorPrecision[i][j];
                for (int k = 0; k < tipCount; k++) {
                    for (int l = 0; l < tipCount; l++) {
                        precision[i][j] += treeTraitPrecisionMatrix[k * dimTrait + i][l * dimTrait + j];
                    }
                }
            }
        }

        double[] normalizedMean = tmp2;
        double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        for (int i = 0; i < dimTrait; i++) {
            normalizedMean[i] = 0.0;
            for (int j = 0; j < dimTrait; j++) {
                normalizedMean[i] += variance[i][j] * mean[j];
            }
        }
        mean = normalizedMean;

        // We know:  y ~ MVN(x, A) and x ~ N(m, B), with A and B precision matrices
        // Therefore p(x | y) = N( (A + B)^{-1} (A y + B m), A + B )
        // We want:  p( y ) = p( y | x ) p( x ) / p( x | y ) for any value x, say x = 0

        logLikelihood += MultivariateNormalDistribution.logPdf(
                rootPriorMean, new double[rootPriorMean.length], rootPriorPrecision,
                logRootPriorPrecisionDeterminant, 1.0);

        logLikelihood -= MultivariateNormalDistribution.logPdf(
                mean, new double[mean.length], precision,
                Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(precision)), 1.0);

        if (DEBUG) {
            System.err.println("Mean = " + new Vector(mean));
            System.err.println("Prec = " + new Matrix(precision));
            System.err.println("log density = " + logLikelihood);
        }
        return logLikelihood;
    }
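    /**
     * Brute-force construction of the tree variance among tips: the diagonal entry for tip i is its (rescaled)
     * path length to the root, and the off-diagonal entry for tips i and j is the (rescaled) path length from
     * the root to their most recent common ancestor, i.e. the shared history of the two tips. Completely
     * missing tips are pruned from the result. A tiny illustrative case (values made up, not computed here):
     * two tips that diverged half-way along a tree of unit height give {{1.0, 0.5}, {0.5, 1.0}}.
     */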
    public double[][] computeTreeVariance() {

        final int tipCount = treeModel.getExternalNodeCount();
        double[][] variance = new double[tipCount][tipCount];

        for (int i = 0; i < tipCount; i++) {

            // Fill in diagonal
            double marginalTime = getRescaledLengthToRoot(treeModel.getExternalNode(i));
            variance[i][i] = marginalTime;

            // Fill in upper right triangle
            for (int j = i + 1; j < tipCount; j++) {
                NodeRef mrca = findMRCA(i, j);
                variance[i][j] = getRescaledLengthToRoot(mrca);
            }
        }

        // Make symmetric
        for (int i = 0; i < tipCount; i++) {
            for (int j = i + 1; j < tipCount; j++) {
                variance[j][i] = variance[i][j];
            }
        }

        if (DEBUG) {
            System.err.println("");
            System.err.println("New tree conditional variance:\n" + new Matrix(variance));
        }

        variance = removeMissingTipsInTreeVariance(variance); // Automatically prune missing tips

        if (DEBUG) {
            System.err.println("");
            System.err.println("New tree (trimmed) conditional variance:\n" + new Matrix(variance));
        }

        return variance;
    }

    protected int countNonMissingTips() {
        int tipCount = treeModel.getExternalNodeCount();
        for (int i = 0; i < tipCount; i++) {
            if (missingTraits.isCompletelyMissing(i)) {
                tipCount--;
            }
        }
        return tipCount;
    }

    public void checkViaLargeMatrixInversion() {

        // Perform a check based on filling in the (dimTrait * tipCount) x (dimTrait * tipCount) precision matrix
        // and then integrating out the root trait value

        // Form \Sigma^{-1} (precision) = (tree precision) %x% (trait precision)
        double[][] treeTraitPrecisionMatrix = computeTreeTraitPrecision(diffusionModel.getPrecisionmatrix());

        double totalLogDensity = 0;

        for (int datum = 0; datum < numData; datum++) {

            double[] tipTraits = fillLeafTraits(datum);

            System.err.println("Datum #" + datum);
            System.err.println("tipTraits = " + new Vector(tipTraits));
            System.err.println("tipPrecision = \n" + new Matrix(treeTraitPrecisionMatrix));

            double checkLogLikelihood = MultivariateNormalDistribution.logPdf(tipTraits,
                    new double[tipTraits.length], treeTraitPrecisionMatrix,
                    Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(treeTraitPrecisionMatrix)),
                    1.0);

            System.err.println("tipDensity = " + checkLogLikelihood
                    + " (should match final likelihood when root not integrated out and no missing data)");

            // Convolve root prior
            if (integrateRoot) {
                checkLogLikelihood += integrateLogLikelihoodAtRootFromFullTreeMatrix(treeTraitPrecisionMatrix, tipTraits);
            }
            totalLogDensity += checkLogLikelihood;
        }
        System.err.println("Total logLikelihood (via tree) = " + totalLogDensity);
    }

    private double[][] computeTipTraitOuterProduct(int tip0, int tip1) {

        double[][] outerProduct = new double[dimTrait][dimTrait];

        final int offset0 = dim * tip0;
        final int offset1 = dim * tip1;

        // Sum the outer products over all data columns
        for (int i = 0; i < dimTrait; i++) {
            for (int j = 0; j < dimTrait; j++) {
                for (int k = 0; k < numData; k++) {
                    outerProduct[i][j] += meanCache[offset0 + k * dimTrait + i] * meanCache[offset1 + k * dimTrait + j];
                }
            }
        }
        return outerProduct;
    }

    private void computeAllTipTraitOuterProducts() {

        final int nTips = treeModel.getExternalNodeCount();
        if (tipTraitOuterProducts == null) {
            tipTraitOuterProducts = new double[nTips][nTips][][];
        }

        for (int i = 0; i < nTips; i++) {
            if (!missingTraits.isCompletelyMissing(i)) {
                tipTraitOuterProducts[i][i] = computeTipTraitOuterProduct(i, i);
                for (int j = i + 1; j < nTips; j++) {
                    if (!missingTraits.isCompletelyMissing(j)) { // Check the second tip in the pair, not the first again
                        tipTraitOuterProducts[j][i] = tipTraitOuterProducts[i][j] = computeTipTraitOuterProduct(i, j);
                    } else {
                        tipTraitOuterProducts[j][i] = tipTraitOuterProducts[i][j] = null;
                    }
                }
            } else {
                for (int j = 0; j < nTips; j++) {
                    tipTraitOuterProducts[i][j] = null;
                }
            }
        }
    }

    // Returns the outer product of the tip traits for tip0 and tip1,
    // or null if either tip0 or tip1 is completely missing
    public double[][] getTipTraitOuterProduct(int tip0, int tip1) {
        if (updateOuterProducts) {
            computeAllTipTraitOuterProducts();
            updateOuterProducts = false;
        }
        return tipTraitOuterProducts[tip0][tip1];
    }

    protected boolean updateOuterProducts = true;
    protected double[][][][] tipTraitOuterProducts = null;
}
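/*
 * A minimal usage sketch (illustrative only; how the likelihood object is constructed depends on the usual
 * BEAST XML/parser machinery and is not shown here, and the variable names below are hypothetical):
 *
 *     DebugableIntegratedMultivariateTraitLikelihood debugLikelihood = ...; // obtained from the parsed model
 *     double fast = debugLikelihood.getLogLikelihood();  // dynamic-programming value (inherited Likelihood API)
 *     debugLikelihood.checkViaLargeMatrixInversion();    // prints the brute-force O(tips^2) value to stderr
 *
 * When the root is integrated out consistently and missing data are handled the same way, the brute-force
 * total printed by checkViaLargeMatrixInversion() should agree with the dynamic-programming value.
 */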