/*
* LatentLiabilityGibbs.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
/*
* Gabriela Cybis
*
* Gibbs operator for latent variable in latent liability model
*/
package dr.evomodel.operators;
import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.evomodel.continuous.LatentTruncation;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameter;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.xml.*;
import java.util.logging.Logger;
//import dr.evomodel.continuous.MultinomialLatentLiabilityLikelihood;
//import dr.inference.operators.GibbsOperator;
public class LatentLiabilityGibbs extends SimpleMCMCOperator {
// public class LatentLiabilityGibbs extends SimpleMCMCOperator implements GibbsOperator {
/** Name returned by both the XML parser ({@code getParserName}) and {@code getOperatorName}. */
public static final String LATENT_LIABILITY_GIBBS_OPERATOR = "latentLiabilityGibbsOperator";
/** String constant "treeModel"; not referenced by any other code in this class. */
public static final String TREE_MODEL = "treeModel";
/** Consulted in sampleNode2 (via validTraitForTip) to decide whether a proposed tip value is acceptable. */
private final LatentTruncation latentLiability;
// private final MultinomialLatentLiabilityLikelihood latentLiability;
/** Source of the prior mean/sample size, the precision parameter, the tree and the per-node conditional mean/precision. */
private final FullyConjugateMultivariateTraitLikelihood traitModel;
/** Per-node trait parameters; this class indexes it by NodeRef.getNumber(). */
private final CompoundParameter tipTraitParameter;
/** Copied from traitModel.getPriorMean() at construction; used as the root entry of preMeans. */
protected double[] rootPriorMean;
/** Copied from traitModel.getPriorSampleSize() at construction; used as the root entry of preP. */
protected double rootPriorSampleSize;
/** Precision parameter of the trait model's diffusion model, cast to a MatrixParameter. */
private final MatrixParameter precisionParam;
/** Tree obtained from traitModel.getTreeModel(). */
private final MultivariateTraitTree treeModel;
/** Row dimension of precisionParam (the matrix is assumed to be square). */
private final int dim;
/** Work buffer [nodeCount][dim] written by doPostOrderTraversal. */
public double[][] postMeans;
/** Work buffer [nodeCount][dim] written by doPreOrderTraversal and read by sampleNode. */
public double[][] preMeans;
/** Work buffer [nodeCount] of scalar precisions written by doPreOrderTraversal and read by sampleNode. */
public double[] preP;
/** Work buffer [nodeCount] of scalar precisions written by doPostOrderTraversal. */
public double[] postP;
public LatentLiabilityGibbs(
FullyConjugateMultivariateTraitLikelihood traitModel,
LatentTruncation LatentLiability, CompoundParameter tipTraitParameter,
double weight) {
super();
/// Set up all required variables
this.latentLiability = LatentLiability;
this.traitModel = traitModel;
this.tipTraitParameter = tipTraitParameter;
this.rootPriorMean = traitModel.getPriorMean();
this.rootPriorSampleSize = traitModel.getPriorSampleSize();
this.precisionParam = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();
this.treeModel = traitModel.getTreeModel();
dim = precisionParam.getRowDimension(); // assumed to be square
postMeans = new double[treeModel.getNodeCount()][dim];
preMeans = new double[treeModel.getNodeCount()][dim];
preP = new double[treeModel.getNodeCount()];
postP = new double[treeModel.getNodeCount()];
setWeight(weight);
}
/**
 * @return always 1
 */
public int getStepCount() {
return 1;
}
/**
 * Debug helper: logs the first {@code dim} entries of row 0 of a matrix
 * parameter at INFO level on the "dr.evomodel" logger. Values are appended
 * back to back with no separator.
 *
 * @param par matrix parameter whose first row is logged
 */
private void printInformation(MatrixParameter par) {
    final StringBuilder message = new StringBuilder("\n \n parameter \n");
    int column = 0;
    while (column < dim) {
        message.append(par.getParameterValue(0, column));
        column++;
    }
    Logger.getLogger("dr.evomodel").info(message.toString());
}
/**
 * Debug helper: logs every entry of a vector at INFO level on the
 * "dr.evomodel" logger, one tab-separated value per entry.
 * <p>
 * The loop is bounded by the array's own length rather than by the tree's
 * node count, so a vector of any length is printed in full and a short
 * vector no longer causes an ArrayIndexOutOfBoundsException.
 *
 * @param par vector to log
 */
private void printInformation(double[] par) {
    final StringBuilder sb = new StringBuilder("\n \n double vector \n");
    for (int j = 0; j < par.length; j++) {
        if (j > 0) {
            sb.append('\t'); // keep adjacent numbers from running together
        }
        sb.append(par[j]);
    }
    Logger.getLogger("dr.evomodel").info(sb.toString());
}
/**
 * Debug helper: logs column 0 of a per-node matrix (one value for each of
 * the tree's nodes) at INFO level on the "dr.evomodel" logger. Values are
 * appended back to back with no separator; other columns are not printed.
 *
 * @param par matrix indexed as [node][column]
 */
private void printInformation(double[][] par) {
    final int column = 0; // only the first column is ever reported
    final StringBuilder message = new StringBuilder("\n \n double matrix \n");
    for (int nodeIndex = 0; nodeIndex < treeModel.getNodeCount(); nodeIndex++) {
        message.append(par[nodeIndex][column]);
    }
    Logger.getLogger("dr.evomodel").info(message.toString());
}
/**
 * Debug helper: logs a single double at INFO level on the "dr.evomodel" logger.
 *
 * @param par value to log
 */
private void printInformation(double par) {
    final String message = "\n \n double \n" + par;
    Logger.getLogger("dr.evomodel").info(message);
}
/**
 * Debug helper: logs a label followed by two tabs and a double at INFO
 * level on the "dr.evomodel" logger.
 *
 * @param par  value to log
 * @param lala label printed before the value
 */
private void printInformation(double par, String lala) {
    final String message = "\n \n double \n" + lala + "\t\t" + par;
    Logger.getLogger("dr.evomodel").info(message);
}
/**
 * Picks one external node uniformly at random, resamples its trait vector
 * with sampleNode2 and then fires a change event on the whole tip trait
 * parameter.
 * <p>
 * The post-/pre-order traversals of this class are not invoked here.
 *
 * @return the value returned by sampleNode2 for the chosen tip
 */
public double doOperation() {
    final int tipCount = treeModel.getExternalNodeCount();
    final NodeRef chosenTip = treeModel.getExternalNode(MathUtils.nextInt(tipCount));

    final double logHastingsRatio = sampleNode2(chosenTip);

    tipTraitParameter.fireParameterChangedEvent();
    return logHastingsRatio;
}
//Fill out partial mean and precision values in post order
/**
 * Recursively fills postMeans and postP for the subtree below {@code node},
 * children before parents.
 * <p>
 * Tips receive their current trait values and the reciprocal of their
 * rescaled branch length. Internal non-root nodes combine their two
 * children (children 0 and 1 only). The root's own entries are left
 * untouched.
 *
 * @param node root of the subtree to process
 */
public void doPostOrderTraversal(NodeRef node) { // TODO This is already computed IntegratedMultivariateTraitLikelihood
    final int nodeIndex = node.getNumber();

    if (treeModel.isExternal(node)) {
        // Tip: copy the observed/latent trait and store the branch precision.
        final double[] tipValues = getNodeTrait(node);
        for (int k = 0; k < dim; k++) {
            postMeans[nodeIndex][k] = tipValues[k];
        }
        postP[nodeIndex] = 1 / traitModel.getRescaledBranchLengthForPrecision(node);
        return;
    }

    final NodeRef firstChild = treeModel.getChild(node, 0);
    final NodeRef secondChild = treeModel.getChild(node, 1);
    doPostOrderTraversal(firstChild);
    doPostOrderTraversal(secondChild);

    if (treeModel.isRoot(node)) {
        return; // nothing is stored for the root itself
    }

    final int firstIndex = firstChild.getNumber();
    final int secondIndex = secondChild.getNumber();

    // Scalar precisions of the two children and of this node's own branch.
    final double firstPrecision = postP[firstIndex];
    final double secondPrecision = postP[secondIndex];
    final double branchPrecision = 1 / traitModel.getRescaledBranchLengthForPrecision(node);
    final double childPrecisionSum = firstPrecision + secondPrecision;
    postP[nodeIndex] = childPrecisionSum * branchPrecision / (childPrecisionSum + branchPrecision);

    // Precision-weighted average of the children's partial means.
    for (int k = 0; k < dim; k++) {
        final double weightedSum = firstPrecision * postMeans[firstIndex][k]
                + secondPrecision * postMeans[secondIndex][k];
        postMeans[nodeIndex][k] = weightedSum / childPrecisionSum;
    }
}
/**
 * Reads the full trait vector stored for a node.
 *
 * @param node node whose number indexes into the tip trait parameter
 * @return the values returned by that node's parameter
 */
public double[] getNodeTrait(NodeRef node) {
    return tipTraitParameter.getParameter(node.getNumber()).getParameterValues();
}
/**
 * Reads a single trait entry stored for a node.
 *
 * @param node  node whose number indexes into the tip trait parameter
 * @param entry index of the trait dimension to read
 * @return the stored value of that entry
 */
public double getNodeTrait(NodeRef node, int entry) {
    return tipTraitParameter.getParameter(node.getNumber()).getParameterValue(entry);
}
/**
 * Overwrites the first {@code dim} trait entries of a node and then fires a
 * change event on the trait model's parameter for that node.
 *
 * @param node       node whose number indexes into the tip trait parameter
 * @param traitValue new values; entries 0..dim-1 are written
 */
public void setNodeTrait(NodeRef node, double[] traitValue) {
    final int nodeIndex = node.getNumber();
    int k = 0;
    while (k < dim) {
        tipTraitParameter.getParameter(nodeIndex).setParameterValue(k, traitValue[k]);
        k++;
    }
    traitModel.getTraitParameter().getParameter(nodeIndex).fireParameterChangedEvent();
}
/**
 * Overwrites a single trait entry of a node. Unlike the vector overload,
 * this method does not explicitly fire a change event on the trait
 * model's parameter.
 *
 * @param node       node whose number indexes into the tip trait parameter
 * @param entry      index of the trait dimension to write
 * @param traitValue new value for that entry
 */
public void setNodeTrait(NodeRef node, int entry, double traitValue) {
    final int nodeIndex = node.getNumber();
    tipTraitParameter.getParameter(nodeIndex).setParameterValue(entry, traitValue);
}
//Fill out partial mean and precision values in pre order
/**
 * Recursively fills preMeans and preP for {@code node} and everything below
 * it, parents before children.
 * <p>
 * The root is seeded with the root prior (sample size and mean). Every
 * other node combines its parent's pre-order entry with its sister's
 * post-order entry, so postMeans/postP must already hold the values the
 * caller wants used. Only children 0 and 1 are visited.
 *
 * @param node node at which to start the traversal
 */
public void doPreOrderTraversal(NodeRef node) { // TODO This should be computed IntegratedMultivariateTraitLikelihood
    final int nodeIndex = node.getNumber();

    if (!treeModel.isRoot(node)) {
        final int parentIndex = treeModel.getParent(node).getNumber();
        final int sisterIndex = getSisterNode(node).getNumber();

        // Scalar precisions arriving from above (parent) and from the side (sister).
        final double fromParent = preP[parentIndex];
        final double fromSister = postP[sisterIndex];
        final double branchPrecision = 1 / traitModel.getRescaledBranchLengthForPrecision(node);
        final double incoming = fromParent + fromSister;
        preP[nodeIndex] = incoming * branchPrecision / (incoming + branchPrecision);

        // Precision-weighted average of the parent's and the sister's partial means.
        for (int k = 0; k < dim; k++) {
            final double weightedSum = fromParent * preMeans[parentIndex][k]
                    + fromSister * postMeans[sisterIndex][k];
            preMeans[nodeIndex][k] = weightedSum / incoming;
        }
    } else {
        // Root: start from the prior.
        preP[nodeIndex] = rootPriorSampleSize;
        for (int k = 0; k < dim; k++) {
            preMeans[nodeIndex][k] = rootPriorMean[k];
        }
    }

    if (!treeModel.isExternal(node)) {
        doPreOrderTraversal(treeModel.getChild(node, 0));
        doPreOrderTraversal(treeModel.getChild(node, 1));
    }
}
/**
 * Returns the other child of {@code node}'s parent, considering only
 * children 0 and 1. Nodes are compared by reference.
 *
 * @param node a non-root node
 * @return child 1 of the parent if {@code node} is child 0, otherwise child 0
 */
public NodeRef getSisterNode(NodeRef node) {
    final NodeRef parent = treeModel.getParent(node);
    final NodeRef firstChild = treeModel.getChild(parent, 0);
    final NodeRef secondChild = treeModel.getChild(parent, 1);
    return (firstChild == node) ? secondChild : firstChild;
}
/**
 * Resamples one randomly chosen trait entry of a node from a univariate
 * normal conditioned on the node's other entries.
 * <p>
 * The joint distribution is taken to have mean preMeans[node] and precision
 * preP[node] times the diffusion precision matrix; this method reads
 * whatever those buffers currently contain. No validity check against the
 * latent truncation model is made here.
 *
 * @param node node whose trait is updated
 * @return log density of the old value minus log density of the new value
 *         under the sampling distribution
 */
public double sampleNode(NodeRef node) {
    final int nodeIndex = node.getNumber();
    final double[] currentTrait = getNodeTrait(node);

    // Mean and scaled precision of the node's joint conditional distribution.
    final double[] jointMean = new double[dim];
    System.arraycopy(preMeans[nodeIndex], 0, jointMean, 0, dim);

    final double scale = preP[nodeIndex];
    final double[][] jointPrecision = new double[dim][dim];
    for (int row = 0; row < dim; row++) {
        for (int col = 0; col < dim; col++) {
            jointPrecision[row][col] = scale * precisionParam.getParameterValue(row, col);
        }
    }

    // Gibbs-sample a single, randomly selected coordinate.
    final int entry = MathUtils.nextInt(dim);
    final double conditionalMean = getConditionalMean(entry, jointPrecision, currentTrait, jointMean);
    final double conditionalSd = Math.sqrt(1 / jointPrecision[entry][entry]);

    final double previous = getNodeTrait(node, entry);
    final double proposal = MathUtils.nextGaussian() * conditionalSd + conditionalMean;

    final NormalDistribution conditional = new NormalDistribution(conditionalMean, conditionalSd);
    final double logDensityOld = conditional.logPdf(previous);
    final double logDensityNew = conditional.logPdf(proposal);

    setNodeTrait(node, entry, proposal);
    traitModel.getTraitParameter().getParameter(nodeIndex).fireParameterChangedEvent();

    return logDensityOld - logDensityNew;
}
/** Maximum number of rejection-sampling draws made by one call to sampleNode2. */
private static final int MAX_SAMPLE_ATTEMPTS = 10000;

/**
 * Resamples the whole trait vector of a node by rejection sampling.
 * <p>
 * Draws are taken from a multivariate normal with the conditional mean and
 * precision reported by the trait model for this node. Each draw is written
 * to the node and kept if the latent truncation model accepts it.
 * <p>
 * If no acceptable draw is found within {@link #MAX_SAMPLE_ATTEMPTS}
 * attempts, the node's previous value is written back and 0 is returned, so
 * the tip is never left holding a rejected value.
 *
 * @param node node (tip) whose trait vector is updated
 * @return log density of the old value minus log density of the new value
 *         under the untruncated sampling distribution, or 0 when the old
 *         value was restored
 */
public double sampleNode2(NodeRef node) {
    final int thisNumber = node.getNumber();

    final double[] mean = traitModel.getConditionalMean(thisNumber);
    final double[][] thisP = traitModel.getConditionalPrecision(thisNumber);

    /* Sample it all traits together as one multivariate normal
    */
    final MultivariateNormalDistribution distribution = new MultivariateNormalDistribution(mean, thisP);

    // NOTE(review): assumes getNodeTrait returns a snapshot that later
    // setNodeTrait calls do not modify (the original density computation
    // relied on the same assumption) - confirm against Parameter.
    final double[] oldValue = getNodeTrait(node);
    double[] value = oldValue;

    boolean validTip = false;
    for (int attempt = 0; attempt < MAX_SAMPLE_ATTEMPTS && !validTip; attempt++) {
        value = distribution.nextMultivariateNormal();
        // The value must be written before it can be validated.
        setNodeTrait(node, value);
        validTip = latentLiability.validTraitForTip(thisNumber);
    } // TODO Failure rate should be stored somewhere and polled later for diagnostics

    if (!validTip) {
        // Every draw was rejected: undo the last write instead of leaving an
        // invalid value on the tip. setNodeTrait also fires the change event.
        setNodeTrait(node, oldValue);
        return 0.0;
    }

    final double pOld = distribution.logPdf(oldValue);
    final double pNew = distribution.logPdf(value);
    final double logq = pOld - pNew;

    traitModel.getTraitParameter().getParameter(thisNumber).fireParameterChangedEvent();
    return logq;
}
/**
 * Computes
 * {@code mean[entry] - (sum over i != entry of thisP[entry][i] * (traitValue[i] - mean[i])) / thisP[entry][entry]},
 * with i ranging over 0..dim-1.
 *
 * @param entry      index of the coordinate being conditioned
 * @param thisP      precision matrix
 * @param traitValue current values of all coordinates
 * @param mean       mean vector
 * @return the value of the expression above
 */
private double getConditionalMean(int entry, double[][] thisP, double[] traitValue, double[] mean) {
    final double[] precisionRow = thisP[entry];
    double offDiagonalSum = 0;
    for (int other = 0; other < dim; other++) {
        if (other == entry) {
            continue; // the diagonal term is handled by the division below
        }
        offDiagonalSum += precisionRow[other] * (traitValue[other] - mean[other]);
    }
    return mean[entry] - offDiagonalSum / precisionRow[entry];
}
/**
 * @return always null; this operator offers no performance suggestion
 */
public String getPerformanceSuggestion() {
return null;
}
/**
 * @return the constant LATENT_LIABILITY_GIBBS_OPERATOR
 */
public String getOperatorName() {
return LATENT_LIABILITY_GIBBS_OPERATOR;
}
/**
 * XML parser that builds a LatentLiabilityGibbs from an element carrying a
 * weight attribute and exactly three children: a fully conjugate
 * multivariate trait likelihood, a latent truncation model and a compound
 * parameter.
 */
public static dr.xml.XMLObjectParser PARSER = new dr.xml.AbstractXMLObjectParser() {

    // Syntax accepted by this parser: the weight attribute plus the three child elements.
    private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
            AttributeRule.newDoubleRule(WEIGHT),
            new ElementRule(FullyConjugateMultivariateTraitLikelihood.class, "The model for the latent random variables"),
            new ElementRule(LatentTruncation.class, "The model that links latent and observed variables"),
            new ElementRule(CompoundParameter.class, "The parameter of tip locations from the tree")
    };

    public String getParserName() {
        return LATENT_LIABILITY_GIBBS_OPERATOR;
    }

    public Object parseXMLObject(XMLObject xo) throws XMLParseException {
        final int expectedChildren = 3;
        if (xo.getChildCount() != expectedChildren) {
            final String message = "Element with id = '" + xo.getName() + "' should contain:\n" +
                    "\t 1 conjugate multivariateTraitLikelihood, 1 latentLiabilityLikelihood and one parameter \n";
            throw new XMLParseException(message);
        }

        final double weight = xo.getDoubleAttribute(WEIGHT);

        // Children are looked up by type.
        final FullyConjugateMultivariateTraitLikelihood likelihood =
                (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class);
        final LatentTruncation truncation = (LatentTruncation) xo.getChild(LatentTruncation.class);
        final CompoundParameter tipTraits = (CompoundParameter) xo.getChild(CompoundParameter.class);

        return new LatentLiabilityGibbs(likelihood, truncation, tipTraits, weight);
    }

    //************************************************************************
    // AbstractXMLObjectParser implementation
    //************************************************************************

    public String getParserDescription() {
        return "This element returns a gibbs sampler on tip latent trais for latent liability model.";
    }

    public Class getReturnType() {
        return MCMCOperator.class;
    }

    public XMLSyntaxRule[] getSyntaxRules() {
        return rules;
    }
};
}