/* * TraitGibbsOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.operators; import dr.evolution.tree.MultivariateTraitTree; import dr.evolution.tree.NodeRef; import dr.evolution.util.Taxon; import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood; import dr.evomodel.continuous.SampledMultivariateTraitLikelihood; import dr.geo.GeoSpatialCollectionModel; import dr.geo.GeoSpatialDistribution; import dr.inference.distribution.MultivariateDistributionLikelihood; import dr.inference.model.MatrixParameter; import dr.inference.operators.GibbsOperator; import dr.inference.operators.MCMCOperator; import dr.inference.operators.SimpleMCMCOperator; import dr.math.MathUtils; import dr.math.distributions.MultivariateDistribution; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.matrixAlgebra.SymmetricMatrix; import dr.xml.*; import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; /** * @author Marc Suchard */ public class TraitGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { public static final String GIBBS_OPERATOR = "traitGibbsOperator"; public static final String INTERNAL_ONLY = "onlyInternalNodes"; public static final String TIP_WITH_PRIORS_ONLY = "onlyTipsWithPriors"; public static final String NODE_PRIOR = "nodePrior"; public static final String NODE_LABEL = "taxon"; public static final String ROOT_PRIOR = "rootPrior"; private final MultivariateTraitTree treeModel; private final MatrixParameter precisionMatrixParameter; private final SampledMultivariateTraitLikelihood traitModel; private final int dim; private final String traitName; private Map<Taxon, GeoSpatialDistribution> nodeGeoSpatialPrior; private Map<Taxon, MultivariateNormalDistribution> nodeMVNPrior; private GeoSpatialCollectionModel parameterPrior = null; private boolean onlyInternalNodes = true; private boolean onlyTipsWithPriors = true; private boolean sampleRoot = false; private double[] rootPriorMean; private double[][] rootPriorPrecision; private final int maxTries = 10000; public TraitGibbsOperator(SampledMultivariateTraitLikelihood traitModel, boolean onlyInternalNodes, boolean onlyTipsWithPriors) { super(); this.traitModel = traitModel; this.treeModel = traitModel.getTreeModel(); this.precisionMatrixParameter = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter(); this.traitName = traitModel.getTraitName(); this.onlyInternalNodes = onlyInternalNodes; this.onlyTipsWithPriors = onlyTipsWithPriors; this.dim = treeModel.getMultivariateNodeTrait(treeModel.getRoot(), traitName).length; Logger.getLogger("dr.evomodel").info("Using *NEW* trait Gibbs operator"); } public void setRootPrior(MultivariateNormalDistribution rootPrior) { rootPriorMean = rootPrior.getMean(); rootPriorPrecision = rootPrior.getScaleMatrix(); sampleRoot = true; } public void setTaxonPrior(Taxon taxon, MultivariateDistribution distribution) { if (distribution instanceof GeoSpatialDistribution) { if (nodeGeoSpatialPrior == null) { nodeGeoSpatialPrior = new HashMap<Taxon, GeoSpatialDistribution>(); } nodeGeoSpatialPrior.put(taxon, (GeoSpatialDistribution) distribution); } else if (distribution instanceof MultivariateNormalDistribution) { if (nodeMVNPrior == null) { nodeMVNPrior = new HashMap<Taxon, MultivariateNormalDistribution>(); } nodeMVNPrior.put(taxon, (MultivariateNormalDistribution) distribution); } else { throw new RuntimeException("Only flat/truncated geospatial and multivariate normal distributions allowed"); } } public void setParameterPrior(GeoSpatialCollectionModel distribution) { parameterPrior = distribution; } public int getStepCount() { return 1; } private boolean nodeGeoSpatialPriorExists(NodeRef node) { return nodeGeoSpatialPrior != null && nodeGeoSpatialPrior.containsKey(treeModel.getNodeTaxon(node)); } private boolean nodeMVNPriorExists(NodeRef node) { return nodeMVNPrior != null && nodeMVNPrior.containsKey(treeModel.getNodeTaxon(node)); } public double doOperation() { NodeRef node = null; final NodeRef root = treeModel.getRoot(); while (node == null) { if (onlyInternalNodes) node = treeModel.getInternalNode(MathUtils.nextInt( treeModel.getInternalNodeCount())); else { node = treeModel.getNode(MathUtils.nextInt( treeModel.getNodeCount())); if (onlyTipsWithPriors && (treeModel.getChildCount(node) == 0) && // Is a tip !nodeGeoSpatialPriorExists(node)) { // Does not have a prior node = null; } } if (!sampleRoot && node == root) node = null; } // select any internal (or internal/external) node final double[] initialValue = treeModel.getMultivariateNodeTrait(node, traitName); MeanPrecision mp; if (node != root) mp = operateNotRoot(node); else mp = operateRoot(node); final Taxon taxon = treeModel.getNodeTaxon(node); // final boolean nodePriorExists = nodeGeoSpatialPrior != null && nodeGeoSpatialPrior.containsKey(taxon); final boolean nodePriorExists = nodeGeoSpatialPriorExists(node); // if (!onlyInternalNodes) { // final boolean isTip = (treeModel.getChildCount(node) == 0); // if (!nodePriorExists && isTip) // System.err.println("Warning: sampling taxon '"+treeModel.getNodeTaxon(node).getId() // +"' tip trait without a prior!!!"); // } int count = 0; final boolean parameterPriorExists = parameterPrior != null; double[] draw; do { do { if (count > maxTries) { treeModel.setMultivariateTrait(node, traitName, initialValue); // TODO Add to MTT interface throw new RuntimeException("Truncated Gibbs is stuck!"); } draw = MultivariateNormalDistribution.nextMultivariateNormalPrecision( mp.mean, mp.precision); count++; } while (nodePriorExists && // There is a prior for this node (nodeGeoSpatialPrior.get(taxon)).logPdf(draw) == Double.NEGATIVE_INFINITY); // And draw is invalid under prior // TODO Currently only works for flat/truncated priors, make work for MVN treeModel.setMultivariateTrait(node, traitName, draw); } while (parameterPriorExists && (parameterPrior.getLogLikelihood() == Double.NEGATIVE_INFINITY)); return 0; } private MeanPrecision operateNotRoot(NodeRef node) { double[][] precision = precisionMatrixParameter.getParameterAsMatrix(); NodeRef parent = treeModel.getParent(node); double[] mean = new double[dim]; double weight = 1.0 / traitModel.getRescaledBranchLengthForPrecision(node); double[] trait = treeModel.getMultivariateNodeTrait(parent, traitName); for (int i = 0; i < dim; i++) mean[i] = trait[i] * weight; double weightTotal = weight; for (int j = 0; j < treeModel.getChildCount(node); j++) { NodeRef child = treeModel.getChild(node, j); trait = treeModel.getMultivariateNodeTrait(child, traitName); weight = 1.0 / traitModel.getRescaledBranchLengthForPrecision(child); for (int i = 0; i < dim; i++) mean[i] += trait[i] * weight; weightTotal += weight; } for (int i = 0; i < dim; i++) { mean[i] /= weightTotal; for (int j = i; j < dim; j++) precision[j][i] = precision[i][j] *= weightTotal; } if (nodeMVNPriorExists(node)) { throw new RuntimeException("Still trying to implement multivariate normal taxon priors"); } return new MeanPrecision(mean, precision); } class MeanPrecision { final double[] mean; final double[][] precision; MeanPrecision(double[] mean, double[][] precision) { this.mean = mean; this.precision = precision; } } private MeanPrecision operateRoot(NodeRef node) { double[] trait; double weightTotal = 0.0; double[] weightedAverage = new double[dim]; double[][] precision = precisionMatrixParameter.getParameterAsMatrix(); for (int k = 0; k < treeModel.getChildCount(node); k++) { NodeRef child = treeModel.getChild(node, k); trait = treeModel.getMultivariateNodeTrait(child, traitName); final double weight = 1.0 / traitModel.getRescaledBranchLengthForPrecision(child); for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) weightedAverage[i] += precision[i][j] * weight * trait[j]; } weightTotal += weight; } for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { weightedAverage[i] += rootPriorPrecision[i][j] * rootPriorMean[j]; precision[i][j] = precision[i][j] * weightTotal + rootPriorPrecision[i][j]; } } double[][] variance = new SymmetricMatrix(precision).inverse().toComponents(); trait = new double[dim]; for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) trait[i] += variance[i][j] * weightedAverage[j]; } return new MeanPrecision(trait, precision); } public String getPerformanceSuggestion() { return null; } public String getOperatorName() { return GIBBS_OPERATOR; } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return GIBBS_OPERATOR; } private final String[] names = {GIBBS_OPERATOR, "internalTraitGibbsOperator"}; public String[] getParserNames() { return names; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { double weight = xo.getDoubleAttribute(WEIGHT); boolean onlyInternalNodes = xo.getAttribute(INTERNAL_ONLY, true); boolean onlyTipsWithPriors = xo.getAttribute(TIP_WITH_PRIORS_ONLY, true); SampledMultivariateTraitLikelihood traitModel = (SampledMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class); TraitGibbsOperator operator = new TraitGibbsOperator(traitModel, onlyInternalNodes, onlyTipsWithPriors); operator.setWeight(weight); // Get root prior XMLObject cxo = xo.getChild(ROOT_PRIOR); if (cxo != null) { MultivariateDistributionLikelihood rootPrior = (MultivariateDistributionLikelihood) cxo.getChild(MultivariateDistributionLikelihood.class); if (!(rootPrior.getDistribution() instanceof MultivariateDistribution)) throw new XMLParseException("Only multivariate normal priors allowed for Gibbs sampling the root trait"); operator.setRootPrior((MultivariateNormalDistribution) rootPrior.getDistribution()); } // Get node priors for (int i = 0; i < xo.getChildCount(); i++) { if (xo.getChild(i) instanceof MultivariateDistributionLikelihood) { MultivariateDistribution dist = ((MultivariateDistributionLikelihood) xo.getChild(i)).getDistribution(); if (dist instanceof GeoSpatialDistribution) { GeoSpatialDistribution prior = (GeoSpatialDistribution) dist; String nodeLabel = prior.getLabel(); Taxon taxon = getTaxon(traitModel.getTreeModel(), nodeLabel); operator.setTaxonPrior(taxon, prior); System.err.println("Adding truncated prior for taxon '" + taxon + "'"); } } } GeoSpatialCollectionModel collectionModel = (GeoSpatialCollectionModel) xo.getChild(GeoSpatialCollectionModel.class); if (collectionModel != null) { operator.setParameterPrior(collectionModel); System.err.println("Adding truncated prior '" + collectionModel.getId() + "' for parameter '" + collectionModel.getParameter().getId() + "'"); } return operator; } private Taxon getTaxon(MultivariateTraitTree treeModel, String taxonLabel) throws XMLParseException { // Get taxon node from tree int index = treeModel.getTaxonIndex(taxonLabel); if (index == -1) { throw new XMLParseException("Taxon '" + taxonLabel + "' not found for geoSpatialDistribution element in traitGibbsOperator element"); } return treeModel.getTaxon(index); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element returns a multivariate Gibbs operator on traits for possible all nodes."; } public Class getReturnType() { return MCMCOperator.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(WEIGHT), AttributeRule.newBooleanRule(INTERNAL_ONLY, true), AttributeRule.newBooleanRule(TIP_WITH_PRIORS_ONLY, true), new ElementRule(SampledMultivariateTraitLikelihood.class), // new ElementRule(NODE_PRIOR, new XMLSyntaxRule[] { // AttributeRule.newStringRule(NODE_LABEL), // new ElementRule(MultivariateDistributionLikelihood.class), // }), new ElementRule(MultivariateDistributionLikelihood.class, 0, Integer.MAX_VALUE), new ElementRule(ROOT_PRIOR, new XMLSyntaxRule[]{ new ElementRule(MultivariateDistributionLikelihood.class) }, true), new ElementRule(GeoSpatialCollectionModel.class, true), }; }; }