/*
* SemiConjugateMultivariateTraitLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import java.util.List;
/**
* Integrated multivariate trait likelihood that assumes a semi-conjugate prior on the root.
* The semi-conjugate prior is a multivariate normal distribution with an independent precision
*
* @author Marc A. Suchard
*/
public class SemiConjugateMultivariateTraitLikelihood extends IntegratedMultivariateTraitLikelihood {

    /**
     * Constructs an integrated multivariate trait likelihood with a semi-conjugate
     * root prior: a multivariate normal distribution whose precision is independent
     * of the diffusion precision.
     *
     * @param traitName            name of the continuous trait
     * @param treeModel            tree on which the trait diffuses
     * @param diffusionModel       multivariate diffusion process along branches
     * @param traitParameter       observed trait values at the tips
     * @param missingIndices       indices of missing trait observations
     * @param cacheBranches        whether to cache per-branch computations
     * @param scaleByTime          whether branch lengths are scaled by time
     * @param useTreeLength        whether to normalize by total tree length
     * @param rateModel            branch-specific rate model
     * @param samplingDensity      optional sampling density model
     * @param reportAsMultivariate whether to report results as multivariate
     * @param rootPrior            semi-conjugate MVN root prior (own mean and precision)
     * @param reciprocalRates      whether branch rates act reciprocally
     * @param partials             restricted partials, may be null
     */
    public SemiConjugateMultivariateTraitLikelihood(String traitName,
                                                    MultivariateTraitTree treeModel,
                                                    MultivariateDiffusionModel diffusionModel,
                                                    CompoundParameter traitParameter,
                                                    List<Integer> missingIndices,
                                                    boolean cacheBranches,
                                                    boolean scaleByTime,
                                                    boolean useTreeLength,
                                                    BranchRateModel rateModel,
                                                    Model samplingDensity,
                                                    boolean reportAsMultivariate,
                                                    MultivariateNormalDistribution rootPrior,
                                                    boolean reciprocalRates,
                                                    List<RestrictedPartials> partials) {
        super(traitName, treeModel, diffusionModel, traitParameter, null, missingIndices, cacheBranches, scaleByTime,
                useTreeLength, rateModel, null, null, null, samplingDensity, partials, reportAsMultivariate, reciprocalRates);
        setRootPrior(rootPrior); // Semi-conjugate multivariate normal with own mean and precision
    }

    @Override
    public boolean getComputeWishartSufficientStatistics() {
        return false; // No need for outer products, as Gibbs sampling of diffusion matrix is not possible
    }

    /**
     * Not supported for semi-conjugate trait likelihoods.
     *
     * @throws UnsupportedOperationException always; a RuntimeException subclass, so
     *         callers catching the original RuntimeException remain compatible
     */
    protected double calculateAscertainmentCorrection(int taxonIndex) {
        throw new UnsupportedOperationException("Ascertainment correction not yet implemented for semi-conjugate trait likelihoods");
    }

    /**
     * Sums the rescaled (precision-scaled) branch lengths on the path from
     * {@code node} up to the root.
     *
     * @param node starting node; must belong to {@code treeModel}
     * @return total rescaled path length from {@code node} to the root
     */
    public double getRescaledLengthToRoot(NodeRef node) {
        double length = 0;
        final NodeRef root = treeModel.getRoot();
        while (node != root) {
            length += getRescaledBranchLengthForPrecision(node);
            node = treeModel.getParent(node);
        }
        return length;
    }

    /**
     * Adds the prior contribution {@code Bz} into {@code ay} in place (so {@code ay}
     * becomes Ay + Bz; the original value is destroyed) and fills {@code aPlusB}
     * with the combined precision A + B = treePrecision * rootPrecision + rootPriorPrecision.
     * Shared by the root-integration and marginal-root-mean computations so the two
     * remain consistent.
     */
    private void addRootPriorContribution(double[] ay, double[][] aPlusB,
                                          double[][] treePrecision, double rootPrecision) {
        for (int i = 0; i < dimTrait; i++) {
            ay[i] += Bz[i];
            for (int j = 0; j < dimTrait; j++) {
                aPlusB[i][j] = treePrecision[i][j] * rootPrecision + rootPriorPrecision[i][j];
            }
        }
    }

    /**
     * Computes the log-density contribution from analytically integrating out the
     * root trait value against the semi-conjugate MVN prior.
     *
     * @param y             root trait value (unused here; part of the superclass contract)
     * @param Ay            tree-precision-weighted mean; overwritten with Ay + Bz
     * @param AplusB        scratch matrix, filled with the combined precision A + B
     * @param treePrecision diffusion precision matrix
     * @param rootPrecision scalar tree-derived precision at the root
     * @return 0.5 * (log|B| - log|A+B| - z'Bz + (Ay+Bz)'(A+B)^{-1}(Ay+Bz))
     */
    protected double integrateLogLikelihoodAtRoot(double[] y,
                                                  double[] Ay,
                                                  double[][] AplusB,
                                                  double[][] treePrecision, double rootPrecision) {
        double detAplusB = 0;
        double square = 0;
        // square : (Ay + Bz)' (A+B)^{-1} (Ay + Bz)
        if (dimTrait > 1) {
            addRootPriorContribution(Ay, AplusB, treePrecision, rootPrecision);
            Matrix mat = new Matrix(AplusB);
            try {
                detAplusB = mat.determinant();
            } catch (IllegalDimension illegalDimension) {
                // Should be unreachable: AplusB is square by construction. Kept as a
                // log-and-continue per the file's existing error-handling convention;
                // note detAplusB would remain 0 (log -> -Infinity) if it ever fired.
                illegalDimension.printStackTrace();
            }
            double[][] invAplusB = mat.inverse().toComponents();
            for (int i = 0; i < dimTrait; i++) {
                for (int j = 0; j < dimTrait; j++)
                    square += Ay[i] * invAplusB[i][j] * Ay[j];
            }
        } else {
            // 1D is very simple
            detAplusB = treePrecision[0][0] * rootPrecision + rootPriorPrecision[0][0];
            Ay[0] += Bz[0];
            square = Ay[0] * Ay[0] / detAplusB;
        }
        double retValue = 0.5 * (logRootPriorPrecisionDeterminant - Math.log(detAplusB) - zBz + square);
        if (DEBUG) {
            System.err.println("(Ay+Bz)(A+B)^{-1}(Ay+Bz) = " + square);
            System.err.println("density = " + retValue);
            System.err.println("zBz = " + zBz);
        }
        return retValue;
    }

    /**
     * Precomputes the prior sum-of-squares z'Bz and the vector Bz for later use,
     * or zeroes the contribution when the root is not integrated out.
     */
    private void setRootPriorSumOfSquares() {
        if (integrateRoot) {
            Bz = new double[dimTrait];
            // z'Bz -- sum-of-squares root contribution
            zBz = computeWeightedAverageAndSumOfSquares(rootPriorMean, Bz, rootPriorPrecision, dimTrait, 1.0);
        } else {
            zBz = 0;
        }
    }

    /**
     * Caches the mean, precision, and log-determinant of the semi-conjugate root
     * prior, then precomputes the prior sum-of-squares terms.
     */
    private void setRootPrior(MultivariateNormalDistribution rootPrior) {
        rootPriorMean = rootPrior.getMean();
        rootPriorPrecision = rootPrior.getScaleMatrix();
        try {
            logRootPriorPrecisionDeterminant = Math.log(new Matrix(rootPriorPrecision).determinant());
        } catch (IllegalDimension illegalDimension) {
            // Should be unreachable for a square precision matrix; kept as
            // log-and-continue per the file's existing convention.
            illegalDimension.printStackTrace();
        }
        setRootPriorSumOfSquares();
    }

    /**
     * Computes the marginal posterior mean and variance of the root trait:
     * mean = (A+B)^{-1}(Ay + Bz), variance = (A+B)^{-1}.
     *
     * @param rootMean      on entry the tree-derived root mean; overwritten with the marginal mean
     * @param treePrecision diffusion precision matrix
     * @param treeVariance  diffusion variance matrix (unused here; superclass contract)
     * @param rootPrecision scalar tree-derived precision at the root
     * @return the marginal root variance (A+B)^{-1}
     */
    protected double[][] computeMarginalRootMeanAndVariance(double[] rootMean, double[][] treePrecision,
                                                            double[][] treeVariance, double rootPrecision) {
        computeWeightedAverageAndSumOfSquares(rootMean, Ay, treePrecision, dimTrait, rootPrecision); // Fills in Ay
        double[][] AplusB = tmpM;
        addRootPriorContribution(Ay, AplusB, treePrecision, rootPrecision);
        Matrix mat = new Matrix(AplusB);
        double[][] invAplusB = mat.inverse().toComponents();
        // Expected value: (A + B)^{-1}(Ay + Bz)
        for (int i = 0; i < dimTrait; i++) {
            rootMean[i] = 0.0;
            for (int j = 0; j < dimTrait; j++) {
                rootMean[i] += invAplusB[i][j] * Ay[j];
            }
        }
        return invAplusB;
    }

    // Mean vector z of the semi-conjugate root prior
    protected double[] rootPriorMean;
    // Precision matrix B of the semi-conjugate root prior
    protected double[][] rootPriorPrecision;
    // Cached log|B|, computed once in setRootPrior
    protected double logRootPriorPrecisionDeterminant;
    // Cached product B * z, computed in setRootPriorSumOfSquares
    protected double[] Bz;
    private double zBz; // Prior sum-of-squares contribution
}