/* * TraitRateGibbsOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.operators; import dr.evolution.tree.MultivariateTraitTree; import dr.evolution.tree.NodeRef; import dr.evomodel.branchratemodel.ArbitraryBranchRates; import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood; import dr.evomodel.continuous.IntegratedMultivariateTraitLikelihood; import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.GammaDistributionModel; import dr.inference.model.MatrixParameter; import dr.inference.operators.GibbsOperator; import dr.inference.operators.MCMCOperator; import dr.inference.operators.SimpleMCMCOperator; import dr.math.distributions.GammaDistribution; import dr.xml.*; import java.util.logging.Logger; /** * Gibbs samples each of AbritraryBranchRates when their prior is a gamma distribution * * @author Marc A. Suchard */ public class TraitRateGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { public static final String GIBBS_OPERATOR = "traitRateGibbsOperator"; private final MultivariateTraitTree treeModel; private final MatrixParameter precisionMatrixParameter; private final AbstractMultivariateTraitLikelihood traitModel; private final GammaDistributionModel ratePriorModel; private final GammaDistribution ratePrior; private final ArbitraryBranchRates branchRateModel; private final int dim; private final String traitName; public TraitRateGibbsOperator(AbstractMultivariateTraitLikelihood traitModel, ArbitraryBranchRates branchRateModel, GammaDistributionModel ratePriorModel, GammaDistribution ratePrior) { super(); this.traitModel = traitModel; this.treeModel = traitModel.getTreeModel(); this.precisionMatrixParameter = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter(); this.traitName = traitModel.getTraitName(); this.branchRateModel = branchRateModel; this.ratePriorModel = ratePriorModel; this.ratePrior = ratePrior; this.dim = treeModel.getMultivariateNodeTrait(treeModel.getRoot(), traitName).length; boolean hasDistributionModel = ratePriorModel == null; boolean hasDistribution = ratePrior == null; if (traitModel instanceof IntegratedMultivariateTraitLikelihood) { throw new RuntimeException("Only implemented for a SampledMultivariateTraitLikelihood"); } if ((hasDistribution && hasDistributionModel) || (!hasDistribution && !hasDistributionModel)) { throw new RuntimeException("Can only provide one prior density in TraitRateGibbsOperation"); } if (!branchRateModel.usingReciprocal()) { throw new RuntimeException("ArbitraryBranchRates in TraitRateGibbsOperatior must use reciprocal rates"); } Logger.getLogger("dr.evomodel").info("Using Gibbs operator and trait rates"); } public int getStepCount() { return 1; } private void sampleRateForNode(NodeRef child, double[][] precision, double priorShape, double priorRate) { NodeRef parent = treeModel.getParent(child); final double[] trait = treeModel.getMultivariateNodeTrait(child, traitName); final double[] parentTrait = treeModel.getMultivariateNodeTrait(parent, traitName); final double precisionScalar = branchRateModel.getBranchRate(treeModel, child) / traitModel.getRescaledBranchLengthForPrecision(child); for (int i = 0; i < dim; i++) { trait[i] -= parentTrait[i]; } double SSE = 0; for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { SSE += trait[i] * precision[i][j] * trait[j]; } } final double gammaShape = priorShape + 0.5 * dim; final double gammaRate = priorRate + 0.5 * SSE * precisionScalar; final double newValue = GammaDistribution.nextGamma(gammaShape, 1.0 / gammaRate); // Store the reciprocal value as the rate (\propto variance) branchRateModel.setBranchRate(treeModel, child, 1.0 / newValue); } public double doOperation() { double[][] precision = precisionMatrixParameter.getParameterAsMatrix(); double priorShape; double priorRate; if (ratePriorModel != null) { priorShape = ratePriorModel.getShape(); priorRate = 1.0 / ratePriorModel.getScale(); } else { priorShape = ratePrior.getShape(); priorRate = 1.0 / ratePrior.getScale(); } for (int i = 0; i < treeModel.getNodeCount(); i++) { NodeRef node = treeModel.getNode(i); if (node != treeModel.getRoot()) { sampleRateForNode(node, precision, priorShape, priorRate); } } return 0; } public String getPerformanceSuggestion() { return null; } public String getOperatorName() { return GIBBS_OPERATOR; } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return GIBBS_OPERATOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { double weight = xo.getDoubleAttribute(WEIGHT); AbstractMultivariateTraitLikelihood traitLikelihood = (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class); ArbitraryBranchRates branchRates = (ArbitraryBranchRates) xo.getChild(ArbitraryBranchRates.class); DistributionLikelihood priorLikelihood = (DistributionLikelihood) xo.getChild(DistributionLikelihood.class); GammaDistributionModel gammaPriorModel = null; GammaDistribution gammaPrior = null; if (priorLikelihood.getDistribution() instanceof GammaDistributionModel) { gammaPriorModel = (GammaDistributionModel) priorLikelihood.getDistribution(); } else if (priorLikelihood.getDistribution() instanceof GammaDistribution) { gammaPrior = (GammaDistribution) priorLikelihood.getDistribution(); } else { throw new XMLParseException("Currently only works with a GammaDistributionModel or GammaDistribution"); } if (!branchRates.usingReciprocal()) { throw new XMLParseException( "Gibbs sampling of rates only works with reciprocal rates under an ArbitraryBranchRates model"); } TraitRateGibbsOperator operator = new TraitRateGibbsOperator(traitLikelihood, branchRates, gammaPriorModel, gammaPrior); operator.setWeight(weight); return operator; } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element returns a multivariate Gibbs operator on traits for possible all nodes."; } public Class getReturnType() { return MCMCOperator.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(WEIGHT), new ElementRule(AbstractMultivariateTraitLikelihood.class), new ElementRule(ArbitraryBranchRates.class), new ElementRule(DistributionLikelihood.class), }; }; }