/*
 * NonPhylogeneticMultivariateTraitLikelihood.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.continuous;

import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;

import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

/**
 * Integrated multivariate trait likelihood that assumes a fully-conjugate prior on the root and
 * no underlying tree structure.
 *
 * @author Gabriela Cybis
 * @author Marc A. Suchard
 * @author Bridgett vonHoldt
 */
public class NonPhylogeneticMultivariateTraitLikelihood extends FullyConjugateMultivariateTraitLikelihood {

    /**
     * Builds the likelihood, forwards the shared configuration to the superclass, records the
     * index of the first zero-height tip and logs a summary of the model.
     *
     * @param exchangeableTips when {@code true}, every tip is weighted using the length-to-root of
     *                         the first tip found at height {@code 0.0} instead of its own
     * @throws IllegalArgumentException if {@code exchangeableTips} is {@code true} but the tree has
     *                                  no tip at height {@code 0.0}
     */
    public NonPhylogeneticMultivariateTraitLikelihood(String traitName,
                                                      MultivariateTraitTree treeModel,
                                                      MultivariateDiffusionModel diffusionModel,
                                                      CompoundParameter traitParameter,
                                                      Parameter deltaParameter,
                                                      List<Integer> missingIndices,
                                                      boolean cacheBranches,
                                                      boolean scaleByTime,
                                                      boolean useTreeLength,
                                                      BranchRateModel rateModel,
                                                      Model samplingDensity,
                                                      boolean reportAsMultivariate,
                                                      double[] rootPriorMean,
                                                      double rootPriorSampleSize,
                                                      List<RestrictedPartials> partials,
                                                      boolean reciprocalRates,
                                                      boolean exchangeableTips) {
        super(traitName, treeModel, diffusionModel, traitParameter, deltaParameter, missingIndices, cacheBranches,
                scaleByTime, useTreeLength, rateModel, null, null, null, samplingDensity, reportAsMultivariate,
                rootPriorMean, partials, rootPriorSampleSize, reciprocalRates);
        this.exchangeableTips = exchangeableTips;
        this.zeroHeightTip = findZeroHeightTip(treeModel);

        // Fail fast with a readable message: getLengthToRoot() would otherwise ask the tree for
        // external node -1 as soon as printInformtion2() runs below.
        if (exchangeableTips && zeroHeightTip < 0) {
            throw new IllegalArgumentException(
                    "Exchangeable tips require at least one tip with height 0.0 (trait '" + traitName + "')");
        }

        printInformtion2();
    }

    /**
     * Finds the first external node whose height is exactly {@code 0.0}.
     *
     * @param tree tree to scan
     * @return index (as accepted by {@code tree.getExternalNode}) of that tip, or {@code -1} if none exists
     */
    private int findZeroHeightTip(Tree tree) {
        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            NodeRef tip = tree.getExternalNode(i);
            if (tree.getNodeHeight(tip) == 0.0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Intentionally empty; the summary for this class is produced by {@link #printInformtion2()},
     * which the constructor calls once the fields of this class are initialised.
     * NOTE(review): presumably this suppresses a report triggered from a superclass constructor - confirm.
     */
    protected void printInformtion() {
        // Do nothing yet
    }

    /**
     * Logs a description of this model (trait, diffusion process, tip exchangeability, citation,
     * dimensions) to the {@code dr.evomodel} logger at INFO level.
     */
    protected void printInformtion2() {
        StringBuilder sb = new StringBuilder("Creating non-phylogenetic multivariate diffusion model:\n");
        sb.append("\tTrait: ").append(traitName).append("\n");
        sb.append("\tDiffusion process: ").append(diffusionModel.getId()).append("\n");
        sb.append("\tExchangeable tips: ").append((exchangeableTips ? "yes" : "no"));
        if (exchangeableTips) {
            sb.append(" initial inverse-weight = ").append(1.0 / getLengthToRoot(treeModel.getExternalNode(0)));
        }
        sb.append("\n");
        sb.append(extraInfo());
        sb.append("\tPlease cite:\n");
        sb.append(Citable.Utils.getCitationString(this));

        sb.append("\n\tDiffusion dimension : ").append(dimTrait).append("\n");
        sb.append("\tNumber of observations: ").append(numData).append("\n");
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /**
     * Computes the "tree length" used by this model: the sum over all tips of the difference
     * between the root height and the tip height.
     *
     * @return sum of root-to-tip height differences
     */
    protected double getTreeLength() {
        final double rootHeight = treeModel.getNodeHeight(treeModel.getRoot());
        double treeLength = 0;
        for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
            NodeRef node = treeModel.getExternalNode(i);
            treeLength += rootHeight - treeModel.getNodeHeight(node); // Bug was here
        }
        return treeLength;
    }

    /**
     * Immutable bundle of the quantities accumulated over the tips by
     * {@link #computeInnerProductsForTips(double[][], double[])}.
     */
    private static final class SufficientStatistics {
        /** Sum of the weights of all tips that are not completely missing. */
        final double sumWeight;
        /** Product of the weights of all tips that are not completely missing. */
        final double productWeight;
        /** Sum of the per-tip, per-datum values returned by {@code computeWeightedAverageAndSumOfSquares}. */
        final double innerProduct;
        /** Number of tips that contributed a strictly positive weight. */
        final int nonMissingTips;

        SufficientStatistics(double sumWeight, double productWeight, double innerProduct, int nonMissingTips) {
            this.sumWeight = sumWeight;
            this.productWeight = productWeight;
            this.innerProduct = innerProduct;
            this.nonMissingTips = nonMissingTips;
        }
    }

    /**
     * Returns the rescaled length-to-root used to weight a tip. With exchangeable tips the value
     * is always taken from the zero-height tip found at construction, regardless of {@code nodeRef}.
     *
     * @param nodeRef tip whose length is requested (ignored when tips are exchangeable)
     * @return value of {@code getRescaledLengthToRoot} for the selected node
     */
    protected double getLengthToRoot(NodeRef nodeRef) {
        final double height;
        if (exchangeableTips) {
            height = getRescaledLengthToRoot(treeModel.getExternalNode(zeroHeightTip));
        } else {
            height = getRescaledLengthToRoot(nodeRef);
        }
        return height;
    }

    // Useful identity for computing outerproducts for Wishart statistics
    // \sum (y_i - \bar{y}) (y_i - \bar{y})^{t} = \sum y_i y_i^{t} - n \bar{y} \bar{y}^t

    /**
     * Accumulates, over every tip that is not completely missing, the weighted tip values into the
     * root slot of {@code meanCache}, the inner products under {@code traitPrecision}, and
     * (optionally) the Wishart outer products. On return the root slot of {@code meanCache} holds
     * the weighted mean and {@code lowerPrecisionCache[root]} holds the total weight.
     *
     * @param traitPrecision precision matrix passed through to {@code computeWeightedAverageAndSumOfSquares}
     * @param tmpVector      scratch buffer of at least {@code dimTrait} entries; overwritten
     * @return the accumulated weights, inner product and count of contributing tips
     */
    private SufficientStatistics computeInnerProductsForTips(double[][] traitPrecision, double[] tmpVector) {

        // Compute the contribution of each datum at the root
        final int rootIndex = treeModel.getRoot().getNumber();
        final int meanOffset = dim * rootIndex;

        // Zero-out root mean
        for (int d = 0; d < dim; ++d) {
            meanCache[meanOffset + d] = 0;
        }

        double innerProducts = 0.0;

        // Compute the contribution of each datum at the root
        double productWeight = 1.0;
        double sumWeight = 0.0;
        int nonMissingTips = 0;
        for (int i = 0; i < treeModel.getExternalNodeCount(); ++i) {
            NodeRef tipNode = treeModel.getExternalNode(i);
            final int tipNumber = tipNode.getNumber();

            // Stays 0.0 for completely missing tips so they are skipped in the totals below.
            double tipWeight = 0.0;
            if (!missingTraits.isCompletelyMissing(tipNumber)) {

                tipWeight = 1.0 / getLengthToRoot(tipNode);

                int tipOffset = dim * tipNumber;
                int rootOffset = dim * rootIndex;

                for (int datum = 0; datum < numData; ++datum) {

                    // TODO Make faster when dimTrait == 1
                    // Add weighted tip value
                    for (int d = 0; d < dimTrait; ++d) {
                        meanCache[rootOffset + d] += tipWeight * meanCache[tipOffset + d];
                        tmpVector[d] = meanCache[tipOffset + d];
                    }

                    // Compute outer product
                    double yAy = computeWeightedAverageAndSumOfSquares(tmpVector, Ay, traitPrecision, dimTrait,
                            tipWeight);
                    innerProducts += yAy;
                    if (DEBUG_NO_TREE) {
                        System.err.println("OP for " + tipNumber + " = " + yAy);
                        System.err.println("Value = " + new Vector(tmpVector));
                        System.err.print("Prec =\n" + new Matrix(traitPrecision));
                        System.err.println("weight = " + tipWeight + "\n");
                    }

                    tipOffset += dimTrait;
                    rootOffset += dimTrait;
                }

                if (computeWishartStatistics) {
                    incrementOuterProducts(tipNumber, tipWeight);
                }
            }

            if (tipWeight > 0.0) {
                sumWeight += tipWeight;
                productWeight *= tipWeight;
                ++nonMissingTips;
            }
        }

        lowerPrecisionCache[rootIndex] = sumWeight;

        // Turn the weighted sum stored in the root slot into a weighted mean.
        normalize(meanCache, meanOffset, dim, sumWeight);

        if (computeWishartStatistics) {
            // Subtract the n * ybar * ybar^t term of the identity above; the helper bumps the
            // degrees of freedom by one, which is undone immediately.
            incrementOuterProducts(rootIndex, -sumWeight);
            wishartStatistics.incrementDf(-1);
        }

        return new SufficientStatistics(sumWeight, productWeight, innerProducts, nonMissingTips);
    }

    /**
     * Divides {@code dim} consecutive entries of {@code x}, starting at {@code offset}, by {@code weight}.
     */
    private void normalize(double[] x, int offset, int dim, double weight) {
        for (int d = 0; d < dim; ++d) {
            x[offset + d] /= weight;
        }
    }

    /**
     * Adds {@code nodeWeight * y y^t} to the Wishart scale matrix for every datum stored for the
     * given node in {@code meanCache}, then increments the Wishart degrees of freedom by one.
     * <p>
     * The scale matrix is created with dimension {@code dimTrait} (see
     * {@link #calculateLogLikelihood()}) and consecutive data are {@code dimTrait} entries apart,
     * so both loops and the row stride use {@code dimTrait}.
     *
     * @param nodeNumber node whose cached values are used
     * @param nodeWeight multiplier applied to each outer product (may be negative)
     */
    private void incrementOuterProducts(int nodeNumber, double nodeWeight) {
        final double[] outerProduct = wishartStatistics.getScaleMatrix();

        // Start of this node's slot in meanCache; each node occupies dim entries.
        int tipOffset = dim * nodeNumber;

        for (int datum = 0; datum < numData; ++datum) {
            for (int i = 0; i < dimTrait; ++i) {
                final double yi = meanCache[tipOffset + i];
                for (int j = 0; j < dimTrait; ++j) {
                    outerProduct[i * dimTrait + j] += yi * meanCache[tipOffset + j] * nodeWeight;
                }
            }
            tipOffset += dimTrait;
        }

        wishartStatistics.incrementDf(1); // Peeled one node
    }

    /**
     * Always returns {@code false}.
     * NOTE(review): presumably this tells the superclass not to run tree peeling - confirm against the caller.
     */
    protected boolean peel() {
        return false;
    }

    /**
     * Computes the log likelihood of the tip data with the root integrated out against the
     * conjugate prior, treating tips as independent draws (no tree structure). When
     * {@code computeWishartStatistics} is set, fresh Wishart sufficient statistics are also built.
     *
     * @return the log likelihood
     */
    public double calculateLogLikelihood() {

        double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
        double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());

        // tmp2 is used as scratch by the tip pass below and afterwards reused for the marginal root.
        double[] marginalRoot = tmp2;

        if (computeWishartStatistics) {
            wishartStatistics = new WishartSufficientStatistics(dimTrait);
        }

        // Compute the contribution of each datum at the root
        SufficientStatistics stats = computeInnerProductsForTips(traitPrecision, tmp2);

        double conditionalSumWeight = stats.sumWeight;
        double conditionalProductWeight = stats.productWeight;
        double innerProducts = stats.innerProduct;
        int nonMissingTips = stats.nonMissingTips;

        // Add in prior and integrate
        double sumWeight = conditionalSumWeight + rootPriorSampleSize;
        double productWeight = conditionalProductWeight * rootPriorSampleSize;
        double rootPrecision = productWeight / sumWeight;

        final int rootIndex = treeModel.getRoot().getNumber();
        int rootOffset = dim * rootIndex;

        for (int datum = 0; datum < numData; ++datum) {

            // Determine marginal root (scaled) mean
            for (int d = 0; d < dimTrait; ++d) {
                marginalRoot[d] = conditionalSumWeight * meanCache[rootOffset + d]
                        + rootPriorSampleSize * rootPriorMean[d];
            }

            // Compute outer product contribution from prior
            double yAy1 = computeWeightedAverageAndSumOfSquares(rootPriorMean, Ay, traitPrecision, dimTrait,
                    rootPriorSampleSize);
            innerProducts += yAy1; // TODO Only need to compute once

            if (DEBUG_NO_TREE) {
                System.err.println("OP for root");
                System.err.println("Value = " + new Vector(rootPriorMean));
                System.err.print("Prec = \n" + new Matrix(traitPrecision));
                System.err.println("Weight = " + rootPriorSampleSize + "\n");
            }

            // Compute outer product differences to complete square
            double yAy2 = computeWeightedAverageAndSumOfSquares(marginalRoot, Ay, traitPrecision, dimTrait,
                    1.0 / sumWeight);
            innerProducts -= yAy2;

            // Add prior on root contribution
            if (computeWishartStatistics) {

                final double[] outerProducts = wishartStatistics.getScaleMatrix();
                final double weight = conditionalSumWeight * rootPriorSampleSize / sumWeight;

                for (int i = 0; i < dimTrait; i++) {
                    final double diffi = meanCache[rootOffset + i] - rootPriorMean[i];
                    for (int j = 0; j < dimTrait; j++) {
                        outerProducts[i * dimTrait + j] += diffi * weight
                                * (meanCache[rootOffset + j] - rootPriorMean[j]);
                    }
                }
                wishartStatistics.incrementDf(1);
            }

            rootOffset += dimTrait;
        }

        if (DEBUG_NO_TREE) {
            System.err.println("SumWeight : " + sumWeight);
            System.err.println("ProductWeight: " + productWeight);
            System.err.println("Total OP : " + innerProducts);
        }

        // Compute log likelihood
        double logLikelihood = -LOG_SQRT_2_PI * dimTrait * nonMissingTips * numData
                + 0.5 * logDetTraitPrecision * nonMissingTips * numData
                + 0.5 * Math.log(rootPrecision) * dimTrait * numData
                - 0.5 * innerProducts;

        if (DEBUG_NO_TREE) {
            System.err.println("logLikelihood (final) = " + logLikelihood);
            System.err.println("numData = " + numData);
        }

        areStatesRedrawn = false; // Should redraw internal node states when needed

        return logLikelihood;
    }

    /** Whether all tips share the weight derived from the zero-height tip. */
    private final boolean exchangeableTips;

    /** External-node index of the first tip at height 0.0, or -1 if the tree has none. */
    private final int zeroHeightTip;

    /** Enables verbose diagnostics on {@code System.err} when set to {@code true}. */
    private static final boolean DEBUG_NO_TREE = false;

    private static final boolean NO_RESCALING = false;
}