/*
* GaussianProcessFromTree.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.tree.NodeRef;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.KroneckerOperation;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
/**
 * Exposes a {@link FullyConjugateMultivariateTraitLikelihood} as a
 * {@link GaussianProcessRandomGenerator}: tip trait values are simulated by
 * walking the tree from the root towards the tips, drawing each node's value
 * from a multivariate normal centred on the value drawn for its parent.
 *
 * @author Marc A. Suchard
 * @author Max R. Tolkoff
 */
public class GaussianProcessFromTree implements GaussianProcessRandomGenerator {

    /**
     * When {@code true}, sampling uses {@code nextRandomFast2}, which writes every node's
     * draw into one shared buffer; otherwise the recursive sampler that allocates a fresh
     * array per node is used.
     */
    private static final boolean USE_BUFFER = true;

    private final FullyConjugateMultivariateTraitLikelihood traitModel;

    /**
     * @param traitModel trait likelihood supplying the tree, the diffusion model and the root prior
     */
    public GaussianProcessFromTree(FullyConjugateMultivariateTraitLikelihood traitModel) {
        this.traitModel = traitModel;
    }

    /**
     * @return the wrapped trait likelihood
     */
    @Override
    public Likelihood getLikelihood() {
        return traitModel;
    }

    /**
     * @return number of external nodes times the trait dimension, i.e. the length of one draw
     */
    @Override
    public int getDimension() {
        return traitModel.getTreeModel().getExternalNodeCount() * traitModel.getDimTrait();
    }

    /**
     * Builds the joint precision as the Kronecker product of the inverted tree variance
     * (computed without the root contribution) and the diffusion model's trait precision.
     *
     * @return the joint precision matrix
     */
    @Override
    public double[][] getPrecisionMatrix() {
        final boolean includeRoot = false; // TODO make an option

        final double[][] treeVariance = traitModel.computeTreeVariance2(includeRoot);
        final double[][] traitPrecision = traitModel.getDiffusionModel().getPrecisionmatrix();
        final Matrix treePrecision = new Matrix(treeVariance).inverse();

        // TODO Double-check order
        return KroneckerOperation.product(treePrecision.toComponents(), traitPrecision);
    }

    /**
     * Multiplies every entry of {@code matrix} by {@code scale}, in place.
     * Currently not called from within this class.
     */
    private static void scale(double[][] matrix, double scale) {
        for (int i = 0; i < matrix.length; ++i) {
            for (int j = 0; j < matrix[i].length; ++j) {
                matrix[i][j] *= scale;
            }
        }
    }

    /**
     * @return the log likelihood reported by the wrapped trait likelihood
     */
    public double getLogLikelihood() {
        return traitModel.getLogLikelihood();
    }

    /**
     * Simulates one set of tip trait values.
     *
     * @return a new array of length (external node count x trait dimension); each tip's
     *         values are stored at offset {@code node.getNumber() * traitDimension}
     * @throws IllegalStateException if the Cholesky factor of the trait variance cannot be computed
     */
    public double[] nextRandomFast() {
        final double[] random =
                new double[traitModel.getTreeModel().getExternalNodeCount() * traitModel.getDimTrait()];
        final NodeRef root = traitModel.getTreeModel().getRoot();
        final double[] traitStart = traitModel.getPriorMean();
        final double[][] varianceCholesky = computeVarianceCholesky();

        if (USE_BUFFER) {
            final int length = traitModel.getDimTrait();
            final int nodeCount = traitModel.getTreeModel().getNodeCount();

            // One slot of `length` values per node, plus a trailing slot that holds the
            // prior mean and acts as the root's "parent" value.
            final double[] currentValue = new double[(nodeCount + 1) * length];
            final double[] epsilon = new double[length]; // scratch space handed to the sampler
            final int priorOffset = nodeCount * length;

            System.arraycopy(traitStart, 0, currentValue, priorOffset, length);
            nextRandomFast2(currentValue, priorOffset, root, random, varianceCholesky, epsilon);
        } else {
            nextRandomFast(traitStart, root, random, varianceCholesky);
        }
        return random;
    }

    /**
     * Inverts the diffusion model's precision matrix and returns the {@code L} factor of its
     * Cholesky decomposition.
     *
     * @throws IllegalStateException wrapping the {@link IllegalDimension} raised by the
     *         decomposition; failing here avoids continuing with a {@code null} factor
     */
    private double[][] computeVarianceCholesky() {
        final double[][] traitVariance =
                new SymmetricMatrix(traitModel.getDiffusionModel().getPrecisionmatrix())
                        .inverse()
                        .toComponents();
        try {
            return new CholeskyDecomposition(traitVariance).getL();
        } catch (IllegalDimension illegalDimension) {
            throw new IllegalStateException(
                    "Unable to compute the Cholesky decomposition of the trait variance matrix",
                    illegalDimension);
        }
    }

    /**
     * Recursive sampler that allocates a new array for every node's draw.
     *
     * @param currentValue     mean for this node's draw (the parent's draw, or the prior mean at the root)
     * @param currentNode      node being simulated
     * @param random           output buffer; only external nodes write into it
     * @param varianceCholesky Cholesky factor of the trait variance
     */
    private void nextRandomFast(double[] currentValue, NodeRef currentNode, double[] random,
                                double[][] varianceCholesky) {

        // The root draw is scaled by 1 / priorSampleSize; every other node by the model's
        // rescaled branch length.
        final double rescaledLength = (traitModel.getTreeModel().isRoot(currentNode)) ?
                1.0 / traitModel.getPriorSampleSize() :
                traitModel.getRescaledBranchLengthForPrecision(currentNode);

        final double scale = Math.sqrt(rescaledLength);

        // draw ~ MVN(mean = currentValue, Cholesky factor scaled by `scale`)
        final double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                currentValue, varianceCholesky, scale);

        if (traitModel.getTreeModel().isExternal(currentNode)) {
            // NOTE(review): assumes external node numbers index directly into `random` — confirm
            System.arraycopy(draw, 0, random, currentNode.getNumber() * draw.length, draw.length);
        } else {
            final int childCount = traitModel.getTreeModel().getChildCount(currentNode);
            for (int i = 0; i < childCount; i++) {
                nextRandomFast(draw, traitModel.getTreeModel().getChild(currentNode, i), random, varianceCholesky);
            }
        }
    }

    /**
     * Recursive sampler that stores every node's draw in the shared {@code currentValue}
     * buffer instead of allocating per node.
     *
     * @param currentValue     buffer holding one slot of trait values per node, addressed by
     *                         {@code node.getNumber() * length}; the caller seeds the parent slot
     * @param parentOffset     offset in {@code currentValue} of the mean for this node's draw
     * @param currentNode      node being simulated
     * @param random           output buffer; only external nodes write into it
     * @param varianceCholesky Cholesky factor of the trait variance
     * @param epsilon          scratch array passed through to the sampler
     */
    private void nextRandomFast2(double[] currentValue, int parentOffset, NodeRef currentNode, double[] random,
                                 double[][] varianceCholesky, double[] epsilon) {

        final int length = varianceCholesky.length;

        // Same scaling rule as the allocating sampler above.
        final double rescaledLength = (traitModel.getTreeModel().isRoot(currentNode)) ?
                1.0 / traitModel.getPriorSampleSize() :
                traitModel.getRescaledBranchLengthForPrecision(currentNode);

        final double scale = Math.sqrt(rescaledLength);

        final int currentOffset = currentNode.getNumber() * length;

        // draw ~ MVN(mean = value stored at parentOffset, Cholesky factor scaled by `scale`)
        MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                currentValue, parentOffset, // mean at parent
                varianceCholesky, scale,
                currentValue, currentOffset, // result at current
                epsilon);

        if (traitModel.getTreeModel().isExternal(currentNode)) {
            // NOTE(review): the same offset is used in both buffers, so external node numbers
            // must fit within `random` — confirm against the tree's node numbering
            System.arraycopy(
                    currentValue, currentOffset, // result at tip
                    random, currentOffset, // into final results buffer
                    length);
        } else {
            final int childCount = traitModel.getTreeModel().getChildCount(currentNode);
            for (int i = 0; i < childCount; i++) {
                nextRandomFast2(
                        currentValue, currentOffset,
                        traitModel.getTreeModel().getChild(currentNode, i),
                        random, varianceCholesky, epsilon);
            }
        }
    }

    /**
     * @return a fresh draw from {@link #nextRandomFast()}, as a {@code double[]}
     */
    @Override
    public Object nextRandom() {
        return nextRandomFast();
    }

    /**
     * Copies {@code x} into the model's trait parameter, fires a single change event and
     * returns the resulting log likelihood. The trait parameter is left holding {@code x}.
     *
     * @param x a {@code double[]} of trait values
     * @return the trait model's log likelihood after the update
     */
    @Override
    public double logPdf(Object x) {
        final double[] v = (double[]) x;
        final Parameter variable = traitModel.getTraitParameter();

        // Set quietly so listeners are notified once, after all values are in place.
        for (int i = 0; i < v.length; ++i) {
            variable.setParameterValueQuietly(i, v[i]);
        }
        variable.fireParameterChangedEvent();
        return traitModel.getLogLikelihood();
    }
}