/*
* MultivariateTraitDebugUtilities.java
*
* Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.treedatalikelihood.continuous;
import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.evomodel.continuous.RestrictedPartials;
import dr.math.KroneckerOperation;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
 * Static helpers used when debugging multivariate trait likelihoods on a tree:
 * root-to-node path lengths and the tip-by-tip variance matrix derived from them.
 *
 * @author Marc A. Suchard
 */
public class MultivariateTraitDebugUtilities {

    /**
     * Sums the branch lengths along the path from a node up to the root.
     *
     * @param tree    tree that contains {@code nodeRef}
     * @param nodeRef node at which the path starts
     * @return the summed branch length between {@code nodeRef} and the root;
     *         {@code 0} when {@code nodeRef} is the root itself
     */
    public static double getLengthToRoot(final Tree tree, final NodeRef nodeRef) {
        if (tree.isRoot(nodeRef)) {
            return 0.0;
        }
        final NodeRef parent = tree.getParent(nodeRef);
        // Recursion keeps the summation order (this branch + everything above it).
        return tree.getBranchLength(nodeRef) + getLengthToRoot(tree, parent);
    }

    /**
     * Finds the common ancestor of two tips, identified through their taxon ids.
     *
     * NOTE(review): the indices are passed to {@code tree.getTaxonId(int)}; this
     * presumably relies on taxon indices coinciding with external-node indices — confirm.
     *
     * @param tree tree to search
     * @param iTip index of the first tip
     * @param jTip index of the second tip
     * @return the node returned by {@link TreeUtils#getCommonAncestorNode} for the two taxon ids
     */
    private static NodeRef findMRCA(final Tree tree, final int iTip, final int jTip) {
        final Set<String> leafNames = new HashSet<String>();
        leafNames.add(tree.getTaxonId(iTip));
        leafNames.add(tree.getTaxonId(jTip));
        return TreeUtils.getCommonAncestorNode(tree, leafNames);
    }

    /**
     * Builds the symmetric tip-by-tip variance matrix implied by the tree.
     * <p>
     * Entry {@code (i, i)} is the root-to-tip path length of tip {@code i}; entry
     * {@code (i, j)} is the root-to-MRCA path length of tips {@code i} and {@code j}.
     * Every path length is multiplied by {@code normalization}. When
     * {@code priorSampleSize} is not infinite, {@code 1 / priorSampleSize} is added
     * to every entry of the matrix.
     *
     * @param tree            tree supplying tips and branch lengths
     * @param normalization   factor applied to every path length
     * @param priorSampleSize prior sample size; an infinite value adds nothing
     * @return a newly allocated {@code tipCount x tipCount} matrix
     */
    public static double[][] getTreeVariance(final Tree tree, final double normalization, final double priorSampleSize) {

        final int tipCount = tree.getExternalNodeCount();
        final double[][] variance = new double[tipCount][tipCount];

        for (int i = 0; i < tipCount; i++) {

            // Fill in diagonal
            variance[i][i] = getLengthToRoot(tree, tree.getExternalNode(i)) * normalization;

            // Fill in upper right triangle; shared path lengths must be scaled
            // by the same factor as the diagonal for the matrix to be coherent.
            for (int j = i + 1; j < tipCount; j++) {
                final NodeRef mrca = findMRCA(tree, i, j);
                variance[i][j] = getLengthToRoot(tree, mrca) * normalization;
            }
        }

        // NOTE(review): the disabled block below refers to members that are not defined
        // in this class (nodeToClampMap, treeModel, getRescaledLengthToRoot); it appears
        // to be kept as a reference for clamped-node support.
//        if (DO_CLAMP && nodeToClampMap != null) {
//            List<RestrictedPartials> partialsList = new ArrayList<RestrictedPartials>();
//            for (Map.Entry<NodeRef, RestrictedPartials> keySet : nodeToClampMap.entrySet()) {
//                partialsList.add(keySet.getValue());
//            }
//
//            for (int i = 0; i < partialsList.size(); ++i) {
//                RestrictedPartials partials = partialsList.get(i);
//                NodeRef node = partials.getNode();
//
//                variance[tipCount + i][tipCount + i] = getRescaledLengthToRoot(node) +
//                        1.0 / partials.getPriorSampleSize();
//
//                for (int j = 0; j < tipCount; ++j) {
//                    NodeRef friend = treeModel.getExternalNode(j);
//                    NodeRef mrca = Tree.Utils.getCommonAncestor(treeModel, node, friend);
//                    variance[j][tipCount + i] = getRescaledLengthToRoot(mrca);
//
//                }
//
//                for (int j = 0; j < i; ++j) {
//                    NodeRef friend = partialsList.get(j).getNode();
//                    NodeRef mrca = Tree.Utils.getCommonAncestor(treeModel, node, friend);
//                    variance[tipCount + j][tipCount + i] = getRescaledLengthToRoot(mrca);
//                }
//            }
//        }

        // Make symmetric
        for (int i = 0; i < variance.length; i++) {
            for (int j = i + 1; j < variance.length; j++) {
                variance[j][i] = variance[i][j];
            }
        }

        if (!Double.isInfinite(priorSampleSize)) {
            // Same constant is added to every entry, so compute it once.
            final double priorVariance = 1.0 / priorSampleSize;
            for (int i = 0; i < variance.length; ++i) {
                for (int j = 0; j < variance[i].length; ++j) {
                    variance[i][j] += priorVariance;
                }
            }
        }

        return variance;
    }

    // NOTE(review): disabled debug report; it references members that are not defined
    // in this class (computeTreeVariance, meanCache, nodeToClampMap, getDimTrait,
    // getLogLikelihood, getWishartStatistics, matrixMin/Max/Sum, vectorMin/Max/Sum).
//    public String getDebugInformation(final FullyConjugateMultivariateTraitLikelihood traitLikelihood) {
//
//        StringBuilder sb = new StringBuilder();
////        sb.append(this.g)
////        System.err.println("Hello");
//        sb.append("Tree:\n");
//        sb.append(traitLikelihood.getId()).append("\t");
//
//        final MultivariateTraitTree treeModel = traitLikelihood.getTreeModel();
//        sb.append(treeModel.toString());
//        sb.append("\n\n");
//
//        double[][] treeVariance = computeTreeVariance(true);
//        double[][] traitPrecision = traitLikelihood.getDiffusionModel().getPrecisionmatrix();
//        Matrix traitVariance = new Matrix(traitPrecision).inverse();
//
//        double[][] jointVariance = KroneckerOperation.product(treeVariance, traitVariance.toComponents());
//
//        sb.append("Tree variance:\n");
//        sb.append(new Matrix(treeVariance));
//        sb.append(matrixMin(treeVariance)).append("\t").append(matrixMax(treeVariance)).append("\t").append(matrixSum(treeVariance));
//        sb.append("\n\n");
//        sb.append("Trait variance:\n");
//        sb.append(traitVariance);
//        sb.append("\n\n");
////        sb.append("Joint variance:\n");
////        sb.append(new Matrix(jointVariance));
////        sb.append("\n\n");
//
//        sb.append("Tree dim: " + treeVariance.length + "\n");
//        sb.append("data dim: " + jointVariance.length);
//        sb.append("\n\n");
//
//        double[] data = new double[jointVariance.length];
//        System.arraycopy(meanCache, 0, data, 0, jointVariance.length);
//
//        if (nodeToClampMap != null) {
//            int offset = treeModel.getExternalNodeCount() * getDimTrait();
//            for(Map.Entry<NodeRef, RestrictedPartials> clamps : nodeToClampMap.entrySet()) {
//                double[] partials = clamps.getValue().getPartials();
//                for (int i = 0; i < partials.length; ++i) {
//                    data[offset] = partials[i];
//                    ++offset;
//                }
//            }
//        }
//
//        sb.append("Data:\n");
//        sb.append(new Vector(data)).append("\n");
//        sb.append(data.length).append("\t").append(vectorMin(data)).append("\t").append(vectorMax(data)).append("\t").append(vectorSum(data));
//        sb.append(treeModel.getNodeTaxon(treeModel.getExternalNode(0)).getId());
//        sb.append("\n\n");
//
//        MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(new double[data.length], new Matrix(jointVariance).inverse().toComponents());
//        double logDensity = mvn.logPdf(data);
//        sb.append("logLikelihood: " + getLogLikelihood() + " == " + logDensity + "\n\n");
//
//        final WishartSufficientStatistics sufficientStatistics = getWishartStatistics();
//        final double[] outerProducts = sufficientStatistics.getScaleMatrix();
//
//        sb.append("Outer-products (DP):\n");
//        sb.append(new Vector(outerProducts));
//        sb.append(sufficientStatistics.getDf() + "\n");
//
//        Matrix treePrecision = new Matrix(treeVariance).inverse();
//        final int n = data.length / traitPrecision.length;
//        final int p = traitPrecision.length;
//        double[][] tmp = new double[n][p];
//
//        for (int i = 0; i < n; ++i) {
//            for (int j = 0; j < p; ++j) {
//                tmp[i][j] = data[i * p + j];
//            }
//        }
//        Matrix y = new Matrix(tmp);
//
//        Matrix S = null;
//        try {
//            S = y.transpose().product(treePrecision).product(y); // Using Matrix-Normal form
//        } catch (IllegalDimension illegalDimension) {
//            illegalDimension.printStackTrace();
//        }
//        sb.append("Outer-products (from tree variance:\n");
//        sb.append(S);
//        sb.append("\n\n");
//
//        return sb.toString();
//    }
}