/* * MultivariateDistributionLikelihood.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.distribution; import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood; import dr.evomodel.continuous.TreeTraitNormalDistributionModel; import dr.inference.model.*; import dr.inferencexml.distribution.DistributionLikelihoodParser; import dr.math.distributions.*; import dr.util.Attribute; import dr.util.Citable; import dr.util.Transform; import dr.xml.*; import java.util.logging.Logger; /** * @author Marc Suchard * @author Guy Baele */ public class MultivariateDistributionLikelihood extends AbstractDistributionLikelihood { public static final String MVN_PRIOR = "multivariateNormalPrior"; public static final String MVN_MEAN = "meanParameter"; public static final String MVN_PRECISION = "precisionParameter"; public static final String MVN_CV = "coefficientOfVariation"; public static final String WISHART_PRIOR = "multivariateWishartPrior"; public static final String INV_WISHART_PRIOR = "multivariateInverseWishartPrior"; public static final String DIRICHLET_PRIOR = "dirichletParameterPrior"; public static final String SUM_TO_NUMBER_OF_ELEMENTS = "sumToNumberOfElements"; public static final String DF = "df"; public static final String SCALE_MATRIX = "scaleMatrix"; public static final String MVGAMMA_PRIOR = "multivariateGammaPrior"; public static final String MVGAMMA_SHAPE = "shapeParameter"; public static final String MVGAMMA_SCALE = "scaleParameter"; public static final String COUNTS = "countsParameter"; public static final String NON_INFORMATIVE = "nonInformative"; public static final String MULTIVARIATE_LIKELIHOOD = "multivariateDistributionLikelihood"; public static final String DATA_AS_MATRIX = "dataAsMatrix"; // public static final String TREE_TRAIT = "treeTraitNormalDistribution"; public static final String TREE_TRAIT = "treeTraitNormalDistributionLikelihood"; public static final String TREE_TRAIT_NORMAL = "treeTraitNormalDistribution"; public static final String ROOT_VALUE = "rootValue"; public static final String CONDITION = "conditionOnRoot"; public static final String DATA = "data"; private final MultivariateDistribution distribution; private final Transform[] transforms; private Parameter parameter = null; public MultivariateDistributionLikelihood(String name, ParametricMultivariateDistributionModel model) { this(name, model, null); } public MultivariateDistributionLikelihood(String name, ParametricMultivariateDistributionModel model, Transform[] transforms) { super(model); this.distribution = model; this.transforms = transforms; } public MultivariateDistributionLikelihood(String name, MultivariateDistribution distribution) { this(name, distribution, null); } public MultivariateDistributionLikelihood(String name, MultivariateDistribution distribution, Transform[] transforms) { super(new DefaultModel(name)); this.distribution = distribution; this.transforms = transforms; } public MultivariateDistributionLikelihood(MultivariateDistribution distribution) { this(distribution, null); } public MultivariateDistributionLikelihood(MultivariateDistribution distribution, Transform[] transforms) { this(distribution.getType(), distribution, transforms); } public String toString() { return getClass().getName() + "(" + getLogLikelihood() + ")"; } public double calculateLogLikelihood() { double logL = 0.0; for (Attribute<double[]> data : dataList) { double[] x = // (data instanceof Parameter) ? // ((Parameter) data).getParameterValues() : data.getAttributeValue(); if (transforms != null) { double[] y = new double[x.length]; for (int i = 0; i < x.length; ++i) { logL += transforms[i].getLogJacobian(x[i]); y[i] = transforms[i].transform(x[i]); } logL += distribution.logPdf(y); } else { logL += distribution.logPdf(x); } } return logL; } public void addData(Parameter parameter) { this.parameter = parameter; addData((Attribute<double[]>)parameter); } public Parameter getDataParameter() { return parameter; } @Override public void addData(Attribute<double[]> data) { super.addData(data); if (data instanceof Variable && getModel() instanceof DefaultModel) { ((DefaultModel) getModel()).addVariable((Variable) data); } } public MultivariateDistribution getDistribution() { return distribution; } public static Transform[] parseListOfTransforms(XMLObject xo, int maxDim) throws XMLParseException { Transform[] transforms = null; boolean anyTransforms = false; for (int i = 0; i < xo.getChildCount(); ++i) { if (xo.getChild(i) instanceof Transform.ParsedTransform) { Transform.ParsedTransform t = (Transform.ParsedTransform) xo.getChild(i); if (transforms == null) { transforms = Transform.Util.getListOfNoTransforms(maxDim); } t.end = Math.max(t.end, maxDim); if (t.start < 0 || t.end < 0 || t.start > t.end) { throw new XMLParseException("Invalid bounds for transform in " + xo.getId()); } for (int j = t.start; j < t.end; j += t.every) { transforms[j] = t.transform; anyTransforms = true; } } } if (anyTransforms) { StringBuilder sb = new StringBuilder("Using distributional transforms in " + xo.getId() + "\n"); for (int i = 0; i < transforms.length; ++i) { if (transforms[i] != Transform.NONE) { sb.append("\t").append(transforms[i].getTransformName()).append(" on index ") .append(i + 1).append("\n"); } } Logger.getLogger("dr.utils.Transform").info(sb.toString()); } return transforms; } public static XMLObjectParser DIRICHLET_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return DIRICHLET_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean sumConstraint = false; if (xo.hasAttribute(SUM_TO_NUMBER_OF_ELEMENTS)) { sumConstraint = xo.getBooleanAttribute(SUM_TO_NUMBER_OF_ELEMENTS); } XMLObject cxo = xo.getChild(COUNTS); Parameter counts = (Parameter) cxo.getChild(Parameter.class); DirichletDistribution dirichlet = new DirichletDistribution(counts.getParameterValues(), sumConstraint); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( dirichlet); cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { likelihood.addData((Parameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(SUM_TO_NUMBER_OF_ELEMENTS, true), new ElementRule(COUNTS, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, 1, Integer.MAX_VALUE), }; public String getParserDescription() { return "Calculates the likelihood of some data under a Dirichlet distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser INV_WISHART_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return INV_WISHART_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { int df = xo.getIntegerAttribute(DF); XMLObject cxo = xo.getChild(SCALE_MATRIX); MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class); InverseWishartDistribution invWishart = new InverseWishartDistribution(df, scaleMatrix.getParameterAsMatrix()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( invWishart); cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof MatrixParameter) { likelihood.addData((MatrixParameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(DF), new ElementRule(SCALE_MATRIX, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}), }; public String getParserDescription() { return "Calculates the likelihood of some data under an Inverse-Wishart distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser WISHART_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return WISHART_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { MultivariateDistributionLikelihood likelihood; if (xo.hasAttribute(NON_INFORMATIVE) && xo.getBooleanAttribute(NON_INFORMATIVE)) { // Make non-informative settings XMLObject cxo = xo.getChild(DATA); int dim = ((MatrixParameter) cxo.getChild(0)).getColumnDimension(); likelihood = new MultivariateDistributionLikelihood(new WishartDistribution(dim)); } else { if (!xo.hasAttribute(DF) || !xo.hasChildNamed(SCALE_MATRIX)) { throw new XMLParseException("Must specify both a df and scaleMatrix"); } double df = xo.getDoubleAttribute(DF); XMLObject cxo = xo.getChild(SCALE_MATRIX); MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class); likelihood = new MultivariateDistributionLikelihood( new WishartDistribution(df, scaleMatrix.getParameterAsMatrix()) ); } XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof MatrixParameter) { likelihood.addData((MatrixParameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules;{ rules = new XMLSyntaxRule[]{ AttributeRule.newBooleanRule(NON_INFORMATIVE, true), AttributeRule.newDoubleRule(DF, true), new ElementRule(SCALE_MATRIX, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}, true), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class, 1, Integer.MAX_VALUE)} ) }; } public String getParserDescription() { return "Calculates the likelihood of some data under a Wishart distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser MULTIVARIATE_LIKELIHOOD_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MULTIVARIATE_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(DistributionLikelihoodParser.DISTRIBUTION); ParametricMultivariateDistributionModel distribution = (ParametricMultivariateDistributionModel) cxo.getChild(ParametricMultivariateDistributionModel.class); // Parse transforms here int maxDim = distribution.getMean().length; Transform[] transforms = parseListOfTransforms(xo, maxDim); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(xo.getId(), distribution, transforms); boolean dataAsMatrix = xo.getAttribute(DATA_AS_MATRIX, false); cxo = xo.getChild(DATA); if (cxo != null) { for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); if (data instanceof MatrixParameter) { MatrixParameter matrix = (MatrixParameter) data; if (dataAsMatrix) { likelihood.addData(matrix); } else { if (matrix.getParameter(0).getDimension() != distribution.getMean().length) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + matrix.getParameter(0).getDimension() + " is not equal to dim(" + distribution.getType() + ") = " + distribution.getMean().length + " in " + xo.getName() + "element"); for (int i = 0; i < matrix.getParameterCount(); i++) { likelihood.addData(matrix.getParameter(i)); } } } else { if (data.getDimension() != distribution.getMean().length) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + data.getDimension() + " is not equal to dim(" + distribution.getType() + ") = " + distribution.getMean().length + " in " + xo.getName() + "element"); likelihood.addData(data); } } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(DistributionLikelihoodParser.DISTRIBUTION, new XMLSyntaxRule[]{new ElementRule(ParametricMultivariateDistributionModel.class)} ), AttributeRule.newBooleanRule(DATA_AS_MATRIX, true), new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}, true) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser MVN_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MVN_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(MVN_MEAN); Parameter mean = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(MVN_PRECISION); MatrixParameter precision = (MatrixParameter) cxo.getChild(MatrixParameter.class); if (mean.getDimension() != precision.getRowDimension() || mean.getDimension() != precision.getColumnDimension()) throw new XMLParseException("Mean and precision have wrong dimensions in " + xo.getName() + " element"); Transform[] transforms = parseListOfTransforms(xo, mean.getDimension()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( new MultivariateNormalDistribution(mean.getParameterValues(), precision.getParameterAsMatrix()), transforms ); cxo = xo.getChild(DATA); if (cxo != null) { for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); if (data instanceof MatrixParameter) { MatrixParameter matrix = (MatrixParameter) data; if (matrix.getParameter(0).getDimension() != mean.getDimension()) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + matrix.getParameter(0).getDimension() + " is not equal to dim(" + mean.getStatisticName() + ") = " + mean.getDimension() + " in " + xo.getName() + "element"); for (int i = 0; i < matrix.getParameterCount(); i++) { likelihood.addData(matrix.getParameter(i)); } } else { if (data.getDimension() != mean.getDimension()) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + data.getDimension() + " is not equal to dim(" + mean.getStatisticName() + ") = " + mean.getDimension() + " in " + xo.getName() + "element"); likelihood.addData(data); } } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(MVN_MEAN, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(MVN_PRECISION, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}), new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}, true) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-normal distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser MVGAMMA_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MVGAMMA_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { double[] shape; double[] scale; if (xo.hasChildNamed(MVGAMMA_SHAPE)) { XMLObject cxo = xo.getChild(MVGAMMA_SHAPE); shape = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); cxo = xo.getChild(MVGAMMA_SCALE); scale = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); if (shape.length != scale.length) throw new XMLParseException("Shape and scale have wrong dimensions in " + xo.getName() + " element"); } else { XMLObject cxo = xo.getChild(MVN_MEAN); double[] mean = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); cxo = xo.getChild(MVN_CV); double[] cv = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); if (mean.length != cv.length) throw new XMLParseException("Mean and CV have wrong dimensions in " + xo.getName() + " element"); final int dim = mean.length; shape = new double[dim]; scale = new double[dim]; for (int i = 0; i < dim; i++) { double c2 = cv[i] * cv[i]; shape[i] = 1.0 / c2; scale[i] = c2 * mean[i]; } } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( new MultivariateGammaDistribution(shape, scale) ); XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); likelihood.addData(data); if (data.getDimension() != shape.length) throw new XMLParseException("dim(" + data.getStatisticName() + ") != " + shape.length + " in " + xo.getName() + "element"); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new XORRule( new ElementRule(MVGAMMA_SHAPE, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(MVN_MEAN, new XMLSyntaxRule[]{new ElementRule(Parameter.class)})), new XORRule( new ElementRule(MVGAMMA_SCALE, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(MVN_CV, new XMLSyntaxRule[]{new ElementRule(Parameter.class)})), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-gamma distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser TREE_TRAIT_MODEL = new AbstractXMLObjectParser() { public String getParserName() { return TREE_TRAIT_NORMAL; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean conditionOnRoot = xo.getAttribute(CONDITION, false); FullyConjugateMultivariateTraitLikelihood traitModel = (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class); TreeTraitNormalDistributionModel treeTraitModel; if (xo.getChild(ROOT_VALUE) != null) { XMLObject cxo = xo.getChild(ROOT_VALUE); Parameter rootValue = (Parameter) cxo.getChild(Parameter.class); treeTraitModel = new TreeTraitNormalDistributionModel(traitModel, rootValue, conditionOnRoot); } else { treeTraitModel = new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot); } return treeTraitModel; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(CONDITION, true), new ElementRule(FullyConjugateMultivariateTraitLikelihood.class) }; public String getParserDescription() { return "Parses TreeTraitNormalDistributionModel"; } public Class getReturnType() { return TreeTraitNormalDistributionModel.class; } }; public static XMLObjectParser TREE_TRAIT_DISTRIBUTION = new AbstractXMLObjectParser() { public String getParserName() { return TREE_TRAIT; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { /* boolean conditionOnRoot = xo.getAttribute(CONDITION, false); FullyConjugateMultivariateTraitLikelihood traitModel = (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class); */ TreeTraitNormalDistributionModel treeTraitModel = (TreeTraitNormalDistributionModel) xo.getChild(TreeTraitNormalDistributionModel.class); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( // new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot) treeTraitModel ); XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { likelihood.addData((Parameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { // AttributeRule.newBooleanRule(CONDITION, true), // new ElementRule(FullyConjugateMultivariateTraitLikelihood.class), new ElementRule(TreeTraitNormalDistributionModel.class), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-gamma distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; }