/* * OldLatentLiabilityGibbsOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.operators; import dr.evolution.tree.MultivariateTraitTree; import dr.evolution.tree.NodeRef; import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood; import dr.evomodel.continuous.BinaryLatentLiabilityLikelihood; import dr.evomodel.continuous.IntegratedMultivariateTraitLikelihood; import dr.inference.model.MatrixParameter; import dr.inference.operators.GibbsOperator; import dr.inference.operators.MCMCOperator; import dr.inference.operators.SimpleMCMCOperator; import dr.math.MathUtils; import dr.util.Citable; import dr.xml.*; import java.util.logging.Logger; /** * @author Marc A. Suchard * @author Joe Felsenstein */ public class OldLatentLiabilityGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { public static final String GIBBS_OPERATOR = "oldLatentLiabilityGibbsOperator"; // public static final String INTERNAL_ONLY = "onlyInternalNodes"; // public static final String TIP_WITH_PRIORS_ONLY = "onlyTipsWithPriors"; // public static final String NODE_PRIOR = "nodePrior"; // public static final String NODE_LABEL = "taxon"; // public static final String ROOT_PRIOR = "rootPrior"; private final MultivariateTraitTree treeModel; private final MatrixParameter precisionMatrixParameter; private final IntegratedMultivariateTraitLikelihood traitModel; private final BinaryLatentLiabilityLikelihood liabilityLikelihood; private final int dim; private final String traitName; // private Map<Taxon, GeoSpatialDistribution> nodeGeoSpatialPrior; // private Map<Taxon, MultivariateNormalDistribution> nodeMVNPrior; // private GeoSpatialCollectionModel parameterPrior = null; // private boolean onlyInternalNodes = true; // private boolean onlyTipsWithPriors = true; // private boolean sampleRoot = false; // private double[] rootPriorMean; // private double[][] rootPriorPrecision; private final int maxTries = 10000; public OldLatentLiabilityGibbsOperator(IntegratedMultivariateTraitLikelihood traitModel, BinaryLatentLiabilityLikelihood liabilityLikelihood) { super(); this.traitModel = traitModel; this.liabilityLikelihood = liabilityLikelihood; this.treeModel = traitModel.getTreeModel(); this.precisionMatrixParameter = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter(); this.traitName = traitModel.getTraitName(); this.dim = traitModel.getDimTrait(); StringBuilder sb = new StringBuilder(); sb.append("Using a latent trait Gibbs operator. Please cite:"); sb.append(Citable.Utils.getCitationString(liabilityLikelihood)); Logger.getLogger("dr.evomodel.continuous").info(sb.toString()); } // public void setRootPrior(MultivariateNormalDistribution rootPrior) { // rootPriorMean = rootPrior.getMean(); // rootPriorPrecision = rootPrior.getScaleMatrix(); // sampleRoot = true; // } // public void setTaxonPrior(Taxon taxon, MultivariateDistribution distribution) { // // if (distribution instanceof GeoSpatialDistribution) { // if (nodeGeoSpatialPrior == null) { // nodeGeoSpatialPrior = new HashMap<Taxon, GeoSpatialDistribution>(); // } // nodeGeoSpatialPrior.put(taxon, (GeoSpatialDistribution)distribution); // // } else if (distribution instanceof MultivariateNormalDistribution) { // if (nodeMVNPrior == null) { // nodeMVNPrior = new HashMap<Taxon, MultivariateNormalDistribution>(); // } // nodeMVNPrior.put(taxon, (MultivariateNormalDistribution)distribution); // } else { // throw new RuntimeException("Only flat/truncated geospatial and multivariate normal distributions allowed"); // } // } // // public void setParameterPrior(GeoSpatialCollectionModel distribution) { // parameterPrior = distribution; // } public int getStepCount() { return 1; } // private boolean nodeGeoSpatialPriorExists(NodeRef node) { // return nodeGeoSpatialPrior != null && nodeGeoSpatialPrior.containsKey(treeModel.getNodeTaxon(node)); // } // // private boolean nodeMVNPriorExists(NodeRef node) { // return nodeMVNPrior != null && nodeMVNPrior.containsKey(treeModel.getNodeTaxon(node)); // } public double doOperation() { traitModel.redrawAncestralStates(); NodeRef node = treeModel.getNode(MathUtils.nextInt(treeModel.getExternalNodeCount())); int tip = node.getNumber(); // Draw truncated MVN using rejection sampling do { // Nothing } while (!liabilityLikelihood.validTraitForTip(tip)); // NodeRef node = null; // final NodeRef root = treeModel.getRoot(); // while (node == null) { // if (onlyInternalNodes) // node = treeModel.getInternalNode(MathUtils.nextInt( // treeModel.getInternalNodeCount())); // else { // node = treeModel.getNode(MathUtils.nextInt( // treeModel.getNodeCount())); // if (onlyTipsWithPriors && // (treeModel.getChildCount(node) == 0) && // Is a tip // !nodeGeoSpatialPriorExists(node)) { // Does not have a prior // node = null; // } // } // if (!sampleRoot && node == root) // node = null; // } // select any internal (or internal/external) node // // final double[] initialValue = treeModel.getMultivariateNodeTrait(node,traitName); // // MeanPrecision mp; // // if (node != root) // mp = operateNotRoot(node); // else // mp = operateRoot(node); // // final Taxon taxon = treeModel.getNodeTaxon(node); // // final boolean nodePriorExists = nodeGeoSpatialPriorExists(node); // // int count = 0; // // final boolean parameterPriorExists = parameterPrior != null; // // double[] draw; // // do { // do { // if (count > maxTries) { // treeModel.setMultivariateTrait(node,traitName,initialValue); // throw new OperatorFailedException("Truncated Gibbs is stuck!"); // } // // draw = MultivariateNormalDistribution.nextMultivariateNormalPrecision( // mp.mean, mp.precision); // count++; // // } while (nodePriorExists && // There is a prior for this node // (nodeGeoSpatialPrior.get(taxon)).logPdf(draw) == Double.NEGATIVE_INFINITY); // And draw is invalid under prior // // treeModel.setMultivariateTrait(node, traitName, draw); // // } while (parameterPriorExists && // (parameterPrior.getLogLikelihood() == Double.NEGATIVE_INFINITY)); return 0; } // private MeanPrecision operateNotRoot(NodeRef node) { // // double[][] precision = precisionMatrixParameter.getParameterAsMatrix(); // // NodeRef parent = treeModel.getParent(node); // // double[] mean = new double[dim]; // // double weight = 1.0 / traitModel.getRescaledBranchLength(node); // // double[] trait = treeModel.getMultivariateNodeTrait(parent, traitName); // // for (int i = 0; i < dim; i++) // mean[i] = trait[i] * weight; // // double weightTotal = weight; // for (int j = 0; j < treeModel.getChildCount(node); j++) { // NodeRef child = treeModel.getChild(node, j); // trait = treeModel.getMultivariateNodeTrait(child, traitName); // weight = 1.0 / traitModel.getRescaledBranchLength(child); // // for (int i = 0; i < dim; i++) // mean[i] += trait[i] * weight; // // weightTotal += weight; // } // // for (int i = 0; i < dim; i++) { // mean[i] /= weightTotal; // for (int j = i; j < dim; j++) // precision[j][i] = precision[i][j] *= weightTotal; // } // //// if (nodeMVNPriorExists(node)) { //// throw new RuntimeException("Still trying to implement multivariate normal taxon priors"); //// } // // return new MeanPrecision(mean,precision); // } class MeanPrecision { final double[] mean; final double[][] precision; MeanPrecision(double[] mean, double[][] precision) { this.mean = mean; this.precision = precision; } } // private MeanPrecision operateRoot(NodeRef node) { // // double[] trait; // double weightTotal = 0.0; // // double[] weightedAverage = new double[dim]; // // double[][] precision = precisionMatrixParameter.getParameterAsMatrix(); // // for (int k = 0; k < treeModel.getChildCount(node); k++) { // NodeRef child = treeModel.getChild(node, k); // trait = treeModel.getMultivariateNodeTrait(child, traitName); // final double weight = 1.0 / traitModel.getRescaledBranchLength(child); // // for (int i = 0; i < dim; i++) { // for (int j=0; j<dim; j++) // weightedAverage[i] += precision[i][j] * weight * trait[j]; // } // // weightTotal += weight; // } // // for (int i=0; i<dim; i++) { // for (int j=0; j<dim; j++) { // weightedAverage[i] += rootPriorPrecision[i][j] * rootPriorMean[j]; // precision[i][j] = precision[i][j] * weightTotal + rootPriorPrecision[i][j]; // } // } // // double[][] variance = new SymmetricMatrix(precision).inverse().toComponents(); // // trait = new double[dim]; // for (int i=0; i<dim; i++) { // for (int j=0; j<dim; j++) // trait[i] += variance[i][j] * weightedAverage[j]; // } // // return new MeanPrecision(trait,precision); // } public String getPerformanceSuggestion() { return null; } public String getOperatorName() { return GIBBS_OPERATOR; } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return GIBBS_OPERATOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { double weight = xo.getDoubleAttribute(WEIGHT); // boolean onlyInternalNodes = xo.getAttribute(INTERNAL_ONLY, true); // boolean onlyTipsWithPriors = xo.getAttribute(TIP_WITH_PRIORS_ONLY, true); // boolean onlyInternalNodes = true; // boolean onlyTipsWithPriors = true; IntegratedMultivariateTraitLikelihood traitModel = (IntegratedMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class); BinaryLatentLiabilityLikelihood liabilityLikelihood = (BinaryLatentLiabilityLikelihood) xo.getChild(BinaryLatentLiabilityLikelihood.class); OldLatentLiabilityGibbsOperator operator = new OldLatentLiabilityGibbsOperator(traitModel, liabilityLikelihood); operator.setWeight(weight); return operator; } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element returns a multivariate Gibbs operator on traits for tip nodes under a latent liability model."; } public Class getReturnType() { return MCMCOperator.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(WEIGHT), new ElementRule(IntegratedMultivariateTraitLikelihood.class), new ElementRule(BinaryLatentLiabilityLikelihood.class), }; }; }