/*
* PrecisionMatrixGibbsOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.operators;
import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.SampledMultivariateTraitLikelihood;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.MultivariateNormalDistributionModel;
import dr.inference.distribution.WishartGammalDistributionModel;
import dr.inference.model.MatrixParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.WishartDistribution;
import dr.math.distributions.WishartStatistics;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Attribute;
import dr.xml.*;
import java.util.List;
//import dr.math.matrixAlgebra.Matrix;
/**
* @author Marc Suchard
*/
public class PrecisionMatrixGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {
public static final String VARIANCE_OPERATOR = "precisionGibbsOperator";
// public static final String PRECISION_MATRIX = "precisionMatrix";
public static final String TREE_MODEL = "treeModel";
// public static final String OUTCOME = "outcome";
// public static final String MEAN = "mean";
public static final String DISTRIBUTION = "distribution";
public static final String PRIOR = "prior";
// public static final String TRAIT_MODEL = "traitModel";
private final AbstractMultivariateTraitLikelihood traitModel;
private AbstractMultivariateTraitLikelihood debugModel = null;
private final ConjugateWishartStatisticsProvider conjugateWishartProvider;
private final MultivariateDistributionLikelihood multivariateLikelihood;
private final Parameter meanParam;
private final MatrixParameterInterface precisionParam;
private double priorDf;
private SymmetricMatrix priorInverseScaleMatrix;
private final MultivariateTraitTree treeModel;
private final int dim;
private double numberObservations;
private final String traitName;
private final boolean isSampledTraitLikelihood;
private double pathWeight = 1.0;
private boolean wishartIsModel = false;
private WishartGammalDistributionModel priorModel = null;
public PrecisionMatrixGibbsOperator(
MultivariateDistributionLikelihood likelihood,
WishartStatistics priorDistribution,
double weight) {
super();
// Unnecessary variables
this.traitModel = null;
this.treeModel = null;
this.traitName = null;
this.conjugateWishartProvider = null;
this.isSampledTraitLikelihood = false;
this.multivariateLikelihood = likelihood;
MultivariateNormalDistributionModel density = (MultivariateNormalDistributionModel) likelihood.getDistribution();
this.meanParam = density.getMeanParameter();
this.precisionParam = density.getPrecisionMatrixParameter();
this.dim = meanParam.getDimension();
setupWishartStatistics(priorDistribution);
if (priorDistribution instanceof WishartGammalDistributionModel) {
wishartIsModel = true;
priorModel = (WishartGammalDistributionModel) priorDistribution;
}
setWeight(weight);
}
private void setupWishartStatistics(WishartStatistics priorDistribution) {
this.priorDf = priorDistribution.getDF();
this.priorInverseScaleMatrix = null;
double[][] scale = priorDistribution.getScaleMatrix();
if (scale != null)
this.priorInverseScaleMatrix =
(SymmetricMatrix) (new SymmetricMatrix(scale)).inverse();
}
@Deprecated
public PrecisionMatrixGibbsOperator(
AbstractMultivariateTraitLikelihood traitModel,
WishartStatistics priorDistribution,
double weight) {
super();
this.traitModel = traitModel;
this.conjugateWishartProvider = null;
this.meanParam = null;
this.precisionParam = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();
setupWishartStatistics(priorDistribution);
if (priorDistribution instanceof WishartGammalDistributionModel) {
wishartIsModel = true;
priorModel = (WishartGammalDistributionModel) priorDistribution;
}
setWeight(weight);
this.treeModel = traitModel.getTreeModel();
traitName = traitModel.getTraitName();
dim = precisionParam.getRowDimension(); // assumed to be square
isSampledTraitLikelihood = (traitModel instanceof SampledMultivariateTraitLikelihood);
if (!isSampledTraitLikelihood &&
!(traitModel instanceof ConjugateWishartStatisticsProvider)) {
throw new RuntimeException("Only implemented for a SampledMultivariateTraitLikelihood or " +
"ConjugateWishartStatisticsProvider");
}
multivariateLikelihood = null;
}
public PrecisionMatrixGibbsOperator(
ConjugateWishartStatisticsProvider wishartStatisticsProvider,
WishartStatistics priorDistribution,
double weight,
AbstractMultivariateTraitLikelihood debugModel) {
super();
this.traitModel = null;
this.debugModel = debugModel;
this.conjugateWishartProvider = wishartStatisticsProvider;
this.meanParam = null;
this.precisionParam = conjugateWishartProvider.getPrecisionParamter();
isSampledTraitLikelihood = false;
this.treeModel = null;
this.traitName = null;
setupWishartStatistics(priorDistribution);
if (priorDistribution instanceof WishartGammalDistributionModel) {
wishartIsModel = true;
priorModel = (WishartGammalDistributionModel) priorDistribution;
}
setWeight(weight);
// this.treeModel = traitModel.getTreeModel();
// traitName = traitModel.getTraitName();
dim = precisionParam.getRowDimension(); // assumed to be square
// isSampledTraitLikelihood = (traitModel instanceof SampledMultivariateTraitLikelihood);
// if (!isSampledTraitLikelihood &&
// !(traitModel instanceof ConjugateWishartStatisticsProvider)) {
// throw new RuntimeException("Only implemented for a SampledMultivariateTraitLikelihood or " +
// "ConjugateWishartStatisticsProvider");
// }
multivariateLikelihood = null;
}
public void setPathParameter(double beta) {
if (beta < 0 || beta > 1) {
throw new IllegalArgumentException("Illegal path weight of " + beta);
}
pathWeight = beta;
}
public int getStepCount() {
return 1;
}
// private void incrementScaledSquareMatrix(double[][] out, double[][] in, double scalar, int dim) {
// for (int i = 0; i < dim; i++) {
// for (int j = 0; j < dim; j++) {
// out[i][j] += scalar * in[i][j];
// }
// }
// }
// private void zeroSquareMatrix(double[][] out, int dim) {
// for (int i = 0; i < dim; i++) {
// for (int j = 0; j < dim; j++) {
// out[i][j] = 0.0;
// }
// }
// }
private void incrementOuterProduct(double[][] S,
MultivariateDistributionLikelihood likelihood) {
double[] mean = likelihood.getDistribution().getMean();
numberObservations = 0;
List<Attribute<double[]>> dataList = likelihood.getDataList();
int count = 0;
for (Attribute<double[]> d : dataList) {
double[] data = d.getAttributeValue();
for (int i = 0; i < dim; i++) {
data[i] -= mean[i];
}
for (int i = 0; i < dim; i++) { // symmetric matrix,
for (int j = i; j < dim; j++) {
S[j][i] = S[i][j] += data[i] * data[j];
}
}
numberObservations += 1;
}
}
private void incrementOuterProduct(double[][] S,
ConjugateWishartStatisticsProvider integratedLikelihood) {
final WishartSufficientStatistics sufficientStatistics = integratedLikelihood.getWishartStatistics();
final double[] outerProducts = sufficientStatistics.getScaleMatrix();
final double df = sufficientStatistics.getDf();
if (DEBUG) {
System.err.println("OP df = " + df);
System.err.println("OP = " + new Vector(outerProducts));
}
// System.exit(-1);
if (debugModel != null) {
final WishartSufficientStatistics debug = ((ConjugateWishartStatisticsProvider) debugModel).getWishartStatistics();
System.err.println(df + " ?= " + debug.getDf());
System.err.println(new Vector(outerProducts));
System.err.println("");
System.err.println(new Vector(debug.getScaleMatrix()));
System.exit(-1);
}
// final double df = 2;
// final double df = integratedLikelihood.getTotalTreePrecision();
// System.err.println("OuterProducts = \n" + new Matrix(outerProducts));
// System.err.println("Total tree DF = " + df);
// System.exit(-1);
final int dim = S.length;
for (int i = 0; i < dim; i++) {
System.arraycopy(outerProducts, i * dim, S[i], 0, dim);
}
numberObservations = df;
// checkDiagonals(outerProducts);
}
// private void checkDiagonals(double[][] S) {
// for (int i = 0; i < S.length; ++i) {
// if (S[i][i] < 0.0) {
// System.err.println("ERROR diag(S)\n" + new Matrix(S));
// System.exit(-1);
// }
// }
// }
private void incrementOuterProduct(double[][] S, NodeRef node) {
if (!treeModel.isRoot(node)) {
NodeRef parent = treeModel.getParent(node);
double[] parentTrait = treeModel.getMultivariateNodeTrait(parent, traitName);
double[] childTrait = treeModel.getMultivariateNodeTrait(node, traitName);
double time = traitModel.getRescaledBranchLengthForPrecision(node);
if (time > 0) {
double sqrtTime = Math.sqrt(time);
double[] delta = new double[dim];
for (int i = 0; i < dim; i++)
delta[i] = (childTrait[i] - parentTrait[i]) / sqrtTime;
for (int i = 0; i < dim; i++) { // symmetric matrix,
for (int j = i; j < dim; j++)
S[j][i] = S[i][j] += delta[i] * delta[j];
}
numberObservations += 1; // This assumes a *single* observation per tip
}
}
// recurse down tree
for (int i = 0; i < treeModel.getChildCount(node); i++)
incrementOuterProduct(S, treeModel.getChild(node, i));
}
public double[][] getOperationScaleMatrixAndSetObservationCount() {
// calculate sum-of-the-weighted-squares matrix over tree
double[][] S = new double[dim][dim];
SymmetricMatrix S2;
SymmetricMatrix inverseS2 = null;
numberObservations = 0; // Need to reset, as incrementOuterProduct can be recursive
if (isSampledTraitLikelihood) {
incrementOuterProduct(S, treeModel.getRoot());
} else { // IntegratedTraitLikelihood
if (traitModel != null) { // is a tree
incrementOuterProduct(S, (ConjugateWishartStatisticsProvider) traitModel); // TODO deprecate usage
} else if (conjugateWishartProvider != null) {
incrementOuterProduct(S, conjugateWishartProvider);
} else { // is a normal-normal-wishart model
incrementOuterProduct(S, multivariateLikelihood);
}
}
try {
S2 = new SymmetricMatrix(S);
if (pathWeight != 1.0) {
S2 = (SymmetricMatrix) S2.product(pathWeight);
}
if (priorInverseScaleMatrix != null)
S2 = priorInverseScaleMatrix.add(S2);
inverseS2 = (SymmetricMatrix) S2.inverse();
// if (S[0][0] < 0.0) {
// System.err.println("ERROR A");
// System.err.println(new Matrix(S));
// }
//
// if (S2.component(0, 0) < 0.0) {
// System.err.println("ERROR B");
// System.err.println(S2);
// }
//
// if (inverseS2.component(0, 0) < 0.0) {
// System.err.println("ERROR C");
// System.err.println("S:\n" + new Matrix(S));
// System.err.println("S2:\n" + S2);
// System.err.println(inverseS2);
// }
} catch (IllegalDimension illegalDimension) {
illegalDimension.printStackTrace();
}
assert inverseS2 != null;
return inverseS2.toComponents();
}
public double doOperation() {
if (wishartIsModel) {
setupWishartStatistics(priorModel);
}
final double[][] scaleMatrix = getOperationScaleMatrixAndSetObservationCount();
final double treeDf = numberObservations;
final double df = priorDf + treeDf * pathWeight;
// if (scaleMatrix[0][0] < 0.0) {
// System.err.println("ERROR");
// System.err.println(new Matrix(scaleMatrix));
// }
double[][] draw = WishartDistribution.nextWishart(df, scaleMatrix);
// int tries = 0;
// int limit = 100;
// boolean success = false;
//
// double[][] draw = null;
//
// while (!success && tries < limit) {
if (DEBUG) {
System.err.println("draw = " + new Matrix(draw));
}
//
//
// draw = WishartDistribution.nextWishart(df, scaleMatrix);
//
// Matrix m = new Matrix(draw);
// try {
// double logDet = m.logDeterminant();
// if (Double.isNaN(logDet)) {
// System.err.println("Bad proposal!");
// System.err.println("df = " + df);
//
//// System.err.println(m);
//// System.exit(-1);
// } else {
// success = true;
// }
//
// } catch (IllegalDimension illegalDimension) {
// illegalDimension.printStackTrace(); //To change body of catch statement use File | Settings | File Templates.
// System.exit(-1);
// }
//
// tries++;
// }
//
// if (tries >= limit) {
// System.err.println("Too many attempts!");
// System.exit(-1);
// }
for (int i = 0; i < dim; i++) {
Parameter column = precisionParam.getParameter(i);
for (int j = 0; j < dim; j++)
column.setParameterValueQuietly(j, draw[j][i]);
}
precisionParam.fireParameterChangedEvent();
return 0;
}
public String getPerformanceSuggestion() {
return null;
}
public String getOperatorName() {
return VARIANCE_OPERATOR;
}
public static dr.xml.XMLObjectParser PARSER = new dr.xml.AbstractXMLObjectParser() {
public String getParserName() {
return VARIANCE_OPERATOR;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
double weight = xo.getDoubleAttribute(WEIGHT);
AbstractMultivariateTraitLikelihood traitModel = (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class);
ConjugateWishartStatisticsProvider ws = (ConjugateWishartStatisticsProvider) xo.getChild(ConjugateWishartStatisticsProvider.class);
if (ws == traitModel) {
ws = null;
}
MultivariateDistributionLikelihood prior = null;
MatrixParameterInterface precMatrix = null;
MultivariateDistributionLikelihood likelihood = null;
if (traitModel != null) {
precMatrix = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();
prior = (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);
}
if (ws != null) {
precMatrix = ws.getPrecisionParamter();
prior = (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);
}
if (traitModel == null && ws == null) { // generic likelihood and prior
for (int i = 0; i < xo.getChildCount(); ++i) {
MultivariateDistributionLikelihood density = (MultivariateDistributionLikelihood) xo.getChild(i);
if (density.getDistribution() instanceof WishartStatistics) {
prior = density;
} else if (density.getDistribution() instanceof MultivariateNormalDistributionModel) {
likelihood = density;
precMatrix = ((MultivariateNormalDistributionModel)
density.getDistribution()).getPrecisionMatrixParameter();
}
}
if (prior == null || likelihood == null) {
throw new XMLParseException(
"Must provide a multivariate normal likelihood and Wishart prior in element '" +
xo.getName() + "'\n"
);
}
}
if (!(prior.getDistribution() instanceof WishartStatistics)) {
throw new XMLParseException("Only a Wishart distribution is conjugate for Gibbs sampling");
}
// Make sure precMatrix is square and dim(precMatrix) = dim(parameter)
if (precMatrix.getColumnDimension() != precMatrix.getRowDimension()) {
throw new XMLParseException("The variance matrix is not square or of wrong dimension");
}
if (traitModel != null && ws == null) {
return new PrecisionMatrixGibbsOperator(
traitModel, (WishartStatistics) prior.getDistribution(), weight
);
} else if (ws != null) {
return new PrecisionMatrixGibbsOperator(
ws, (WishartStatistics) prior.getDistribution(), weight, traitModel
);
} else {
return new PrecisionMatrixGibbsOperator(likelihood, (WishartStatistics) prior.getDistribution(), weight);
}
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "This element returns a multivariate normal random walk operator on a given parameter.";
}
public Class getReturnType() {
return MCMCOperator.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
AttributeRule.newDoubleRule(WEIGHT),
new ElementRule(AbstractMultivariateTraitLikelihood.class, true),
new ElementRule(ConjugateWishartStatisticsProvider.class, true),
new ElementRule(MultivariateDistributionLikelihood.class, 1, 2),
};
};
private static final boolean DEBUG = false;
}