/*
* AbstractMultivariateTraitLikelihood.java
*
* Copyright (c) 2002-2013 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.tree.*;
import dr.evolution.util.Taxon;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.*;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.stats.DiscreteStatistics;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.xml.*;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import java.util.*;
import java.util.logging.Logger;
/**
* @author Marc Suchard
*/
public abstract class AbstractMultivariateTraitLikelihood extends AbstractModelLikelihood
implements TreeTraitProvider, Citable {
// Names of the XML element/attributes read by PARSER below. Not every constant is
// referenced in this file (e.g. SUBSTITUTIONS); those may be used by other classes.
public static final String TRAIT_LIKELIHOOD = "multivariateTraitLikelihood";
public static final String CONJUGATE_ROOT_PRIOR = "conjugateRootPrior";
public static final String MODEL = "diffusionModel";
public static final String TREE = "tree";
public static final String CACHE_BRANCHES = "cacheBranches";
public static final String REPORT_MULTIVARIATE = "reportAsMultivariate";
public static final String CHECK = "check";
public static final String USE_TREE_LENGTH = "useTreeLength";
public static final String SCALE_BY_TIME = "scaleByTime";
public static final String SUBSTITUTIONS = "substitutions";
public static final String SAMPLING_DENSITY = "samplingDensity";
public static final String INTEGRATE = "integrateInternalTraits";
public static final String STANDARDIZE_TRAITS = "standardizeTraits";
public static final String RECIPROCAL_RATES = "reciprocalRates";
public static final String PRIOR_SAMPLE_SIZE = "priorSampleSize";
public static final String RANDOM_SAMPLE = "randomSample";
public static final String IGNORE_PHYLOGENY = "ignorePhylogeny";
public static final String ASCERTAINMENT = "ascertainedTaxon";
public static final String EXCHANGEABLE_TIPS = "exchangeableTips";
public static final String DRIFT_MODELS = "driftModels";
// Branch rate model passed to the constructor; may be null, in which case
// hasBranchRateModel stays false and branch lengths are not rate-scaled.
private BranchRateModel branchRateModel;
public static final String STRENGTH_OF_SELECTION = "strengthOfSelection";
public static final String OPTIMAL_TRAITS = "optimalTraits";
// public AbstractMultivariateTraitLikelihood(String traitName,
// MultivariateTraitTree treeModel,
// MultivariateDiffusionModel diffusionModel,
// CompoundParameter traitParameter,
// List<Integer> missingIndices,
// boolean cacheBranches,
// boolean scaleByTime,
// boolean useTreeLength,
// BranchRateModel rateModel,
// Model samplingDensity,
// boolean reportAsMultivariate,
// boolean reciprocalRates) {
// this(traitName, treeModel, diffusionModel, traitParameter, null, missingIndices, cacheBranches,
// scaleByTime, useTreeLength, rateModel, null, samplingDensity, reportAsMultivariate, reciprocalRates);
// }
/**
 * Stores the supplied components, registers them as listened-to models/variables,
 * allocates the per-node caches when requested, and derives the trait dimensions.
 * <p>
 * The constructor finishes by calling {@link #recalculateTreeLength()} and
 * {@link #printInformtion()}; the latter is overridable and invokes the abstract
 * {@link #extraInfo()}, so subclass state may not yet be initialised at that point.
 *
 * @param traitName            name of the trait on the tree
 * @param treeModel            tree carrying the trait; registered via addModel (must not be null)
 * @param diffusionModel       diffusion process; its precision matrix size defines dimTrait (must not be null)
 * @param traitParameter       trait values; registered via addVariable when non-null
 * @param deltaParameter       optional value added to the rescaled length of external branches; may be null
 * @param missingIndices       indices of missing trait values (stored only)
 * @param cacheBranches        if true, per-node likelihood caches of size getNodeCount() are allocated
 * @param scaleByTime          if true, branch lengths are divided by the tree length/height
 * @param useTreeLength        with scaleByTime, normalise by total tree length instead of root height
 * @param rateModel            optional branch rate model; may be null
 * @param driftModels          optional per-dimension drift models; may be null
 * @param optimalValues        optional per-dimension optimal-value models; may be null
 * @param strengthOfSelection  optional selection-strength model; may be null
 * @param samplingDensity      optional model that is only registered via addModel; may be null
 * @param reportAsMultivariate whether traits are reported as a single multivariate trait
 * @param reciprocalRates      if true, branch lengths are divided (not multiplied) by the branch rate
 * @throws RuntimeException if the dimension of the first trait parameter is not a multiple of dimTrait
 */
public AbstractMultivariateTraitLikelihood(String traitName,
MultivariateTraitTree treeModel,
MultivariateDiffusionModel diffusionModel,
CompoundParameter traitParameter,
Parameter deltaParameter,
List<Integer> missingIndices,
boolean cacheBranches,
boolean scaleByTime,
boolean useTreeLength,
BranchRateModel rateModel,
List<BranchRateModel> driftModels,
List<BranchRateModel> optimalValues,
BranchRateModel strengthOfSelection,
Model samplingDensity,
boolean reportAsMultivariate,
boolean reciprocalRates) {
super(TRAIT_LIKELIHOOD);
this.traitName = traitName;
this.treeModel = treeModel;
this.branchRateModel = rateModel;
this.driftModels = driftModels;
this.optimalValues = optimalValues;
this.strengthOfSelection = strengthOfSelection;
this.diffusionModel = diffusionModel;
this.traitParameter = traitParameter;
this.missingIndices = missingIndices;
// Register every component whose changes must reach handleModelChangedEvent /
// handleVariableChangedEvent.
addModel(treeModel);
addModel(diffusionModel);
this.deltaParameter = deltaParameter;
if (deltaParameter != null) {
addVariable(deltaParameter);
}
if (rateModel != null) {
hasBranchRateModel = true;
addModel(rateModel);
}
if (driftModels != null) {
for (BranchRateModel drift : driftModels) {
addModel(drift);
}
}
if (optimalValues != null) {
for (BranchRateModel optVal : optimalValues) {
addModel(optVal);
}
}
if (strengthOfSelection != null) {
addModel(strengthOfSelection);
}
if (samplingDensity != null) {
addModel(samplingDensity);
}
if (traitParameter != null)
addVariable(traitParameter);
this.reportAsMultivariate = reportAsMultivariate;
this.cacheBranches = cacheBranches;
// The cache arrays only exist when caching is enabled; they are indexed by node number.
if (cacheBranches) {
cachedLogLikelihoods = new double[treeModel.getNodeCount()];
storedCachedLogLikelihood = new double[treeModel.getNodeCount()];
validLogLikelihoods = new boolean[treeModel.getNodeCount()];
storedValidLogLikelihoods = new boolean[treeModel.getNodeCount()];
}
this.scaleByTime = scaleByTime;
this.useTreeLength = useTreeLength;
this.reciprocalRates = reciprocalRates;
dimTrait = diffusionModel.getPrecisionmatrix().length;
// dim is taken from the first component of traitParameter only (0 when there is no trait parameter).
dim = traitParameter != null ? traitParameter.getParameter(0).getDimension() : 0;
// Integer division: a dimTrait of 0 would raise ArithmeticException here.
numData = dim / dimTrait;
if (dim % dimTrait != 0)
throw new RuntimeException("dim is not divisible by dimTrait");
recalculateTreeLength();
printInformtion();
}
// public AbstractMultivariateTraitLikelihood(String traitName,
// MultivariateTraitTree treeModel,
// MultivariateDiffusionModel diffusionModel,
// CompoundParameter traitParameter,
// Parameter deltaParameter,
// List<Integer> missingIndices,
// boolean cacheBranches,
// boolean scaleByTime,
// boolean useTreeLength,
// BranchRateModel rateModel,
// List<BranchRateModel> optimalValues,
// BranchRateModel strengthOfSelection,
// Model samplingDensity,
// boolean reportAsMultivariate,
// boolean reciprocalRates) {
//
// super(TRAIT_LIKELIHOOD);
//
// this.traitName = traitName;
// this.treeModel = treeModel;
// this.branchRateModel = rateModel;
// this.optimalValues = optimalValues;
// this.strengthOfSelection = strengthOfSelection;
// this.diffusionModel = diffusionModel;
// this.traitParameter = traitParameter;
// this.missingIndices = missingIndices;
// addModel(treeModel);
// addModel(diffusionModel);
//
// this.deltaParameter = deltaParameter;
// if (deltaParameter != null) {
// addVariable(deltaParameter);
// }
//
//
// if (rateModel != null) {
// hasBranchRateModel = true;
// addModel(rateModel);
// }
//
// if (optimalValues != null) {
// for (BranchRateModel optVal : optimalValues) {
// addModel(optVal);
// }
// }
//
// if (strengthOfSelection != null) {
// addModel(strengthOfSelection);
// }
//
// if (samplingDensity != null) {
// addModel(samplingDensity);
// }
//
// if (traitParameter != null)
// addVariable(traitParameter);
//
// this.reportAsMultivariate = reportAsMultivariate;
//
// this.cacheBranches = cacheBranches;
// if (cacheBranches) {
// cachedLogLikelihoods = new double[treeModel.getNodeCount()];
// storedCachedLogLikelihood = new double[treeModel.getNodeCount()];
// validLogLikelihoods = new boolean[treeModel.getNodeCount()];
// storedValidLogLikelihoods = new boolean[treeModel.getNodeCount()];
// }
//
// this.scaleByTime = scaleByTime;
// this.useTreeLength = useTreeLength;
// this.reciprocalRates = reciprocalRates;
//
// dimTrait = diffusionModel.getPrecisionmatrix().length;
// dim = traitParameter != null ? traitParameter.getParameter(0).getDimension() : 0;
// numData = dim / dimTrait;
//
// if (dim % dimTrait != 0)
// throw new RuntimeException("dim is not divisible by dimTrait");
//
// recalculateTreeLength();
// printInformtion();
//
// }
/**
 * Logs a summary of this likelihood's configuration (trait name, diffusion model id,
 * rate model, tree normalisation, citations, dimensions) to the "dr.evomodel" logger
 * at INFO level.
 * <p>
 * Called from the constructor. When time scaling is on, the tree length is
 * recalculated before it is reported.
 */
protected void printInformtion() {
    // The buffer never leaves this method, so the unsynchronized StringBuilder is sufficient.
    final StringBuilder sb = new StringBuilder("Creating multivariate diffusion model:\n");
    sb.append("\tTrait: ").append(traitName).append("\n");
    sb.append("\tDiffusion process: ").append(diffusionModel.getId()).append("\n");
    sb.append("\tHeterogenity model: ").append(branchRateModel != null ? branchRateModel.getId() : "homogeneous").append("\n");
    sb.append("\tTree normalization: ").append(scaleByTime ? (useTreeLength ? "length" : "height") : "off").append("\n");
    sb.append("\tUsing reciprocal (precision) rates: ").append(reciprocalRates).append("\n");
    if (scaleByTime) {
        recalculateTreeLength();
        // treeLength holds either the total branch length or the root height,
        // depending on useTreeLength.
        if (useTreeLength) {
            sb.append("\tInitial tree length: ").append(treeLength).append("\n");
        } else {
            sb.append("\tInitial tree height: ").append(treeLength).append("\n");
        }
    }
    sb.append(extraInfo());
    sb.append("\tPlease cite:\n");
    sb.append(Citable.Utils.getCitationString(this));
    sb.append("\n\tDiffusion dimension : ").append(dimTrait).append("\n");
    sb.append("\tNumber of observations: ").append(numData).append("\n");
    Logger.getLogger("dr.evomodel").info(sb.toString());
}
/**
 * @return the citation category of this model, always {@code TRAIT_MODELS}
 */
@Override
public Citation.Category getCategory() {
return Citation.Category.TRAIT_MODELS;
}
/**
 * @return a fixed, human-readable description used alongside the citations
 */
@Override
public String getDescription() {
return "Multivariate Diffusion model";
}
/**
 * Lists the publications to cite for this model.
 *
 * @return a new list that always contains the Lemey 2010 phylogeography citation and,
 *         when ascertainment correction is enabled, an additional in-preparation citation
 */
@Override
public List<Citation> getCitations() {
    final List<Citation> result = new ArrayList<Citation>();
    result.add(CommonCitations.LEMEY_2010_PHYLOGEOGRAPHY);
    if (!doAscertainmentCorrect) {
        return result;
    }
    // The ascertainment correction has its own (unpublished) reference.
    final Author[] ascertainmentAuthors = new Author[]{
            new Author("MA", "Suchard"),
            new Author("J", "Novembre"),
            new Author("B", "von Holdt"),
            new Author("G", "Cybis"),
    };
    result.add(new Citation(ascertainmentAuthors, Citation.Status.IN_PREPARATION));
    return result;
}
/**
 * Supplies subclass-specific text that {@link #printInformtion()} appends to the
 * configuration summary. Note that it is invoked during construction of this base class.
 *
 * @return text to append to the log message
 */
protected abstract String extraInfo();
/**
 * @return the trait parameter given at construction; may be null
 */
public CompoundParameter getTraitParameter() {
return traitParameter;
}
/**
 * Enables ascertainment correction with respect to the given taxon and logs the change.
 * <p>
 * The taxon index is validated before any state is modified, so a failed call leaves
 * a previously configured ascertained taxon untouched. On success the cached
 * log-likelihood is invalidated, because {@link #getLogLikelihood()} subtracts the
 * correction from then on.
 *
 * @param taxon taxon to condition on; must be present in the tree
 * @throws RuntimeException if the taxon is not found in the tree
 */
public void setAscertainedTaxon(Taxon taxon) {
    final int taxonIndex = treeModel.getTaxonIndex(taxon);
    if (taxonIndex == -1) {
        throw new RuntimeException("Taxon " + taxon.getId() + " is not in tree " + treeModel.getId());
    }
    ascertainedTaxonIndex = taxonIndex;
    doAscertainmentCorrect = true;
    // Any value cached so far was computed without (or with a different) correction.
    likelihoodKnown = false;
    StringBuilder sb = new StringBuilder("Enabling ascertainment correction for multivariate trait model: ");
    sb.append(getId()).append("\n");
    sb.append("\tTaxon: ").append(taxon.getId()).append("\n");
    Logger.getLogger("dr.evomodel").info(sb.toString());
}
/**
 * Computes the drift accumulated along the branch above {@code node}: for each drift
 * model, its branch rate multiplied by the (unscaled) branch length.
 *
 * @param node node whose parent branch is evaluated
 * @return one shift value per drift model
 * @throws RuntimeException if no drift models were configured
 */
public double[] getShiftForBranchLength(NodeRef node) {
    if (driftModels == null) {
        throw new RuntimeException("getShiftForBranchLength should not be called.");
    }
    final int modelCount = driftModels.size();
    final double[] shift = new double[modelCount];
    final double branchTime = treeModel.getBranchLength(node);
    for (int k = 0; k < modelCount; ++k) {
        final double rate = driftModels.get(k).getBranchRate(treeModel, node);
        shift[k] = rate * branchTime;
    }
    return shift;
}
/**
 * Collects the optimal trait value on the branch above {@code node}, reading one
 * branch rate from each optimal-value model.
 *
 * @param node node whose parent branch is evaluated
 * @return one optimal value per configured model
 * @throws RuntimeException if no optimal-value models were configured
 */
public double[] getOptimalValue(NodeRef node) {
    if (optimalValues == null) {
        throw new RuntimeException("getOptimalValue should not be called.");
    }
    final int modelCount = optimalValues.size();
    final double[] optimum = new double[modelCount];
    for (int k = 0; k < modelCount; ++k) {
        optimum[k] = optimalValues.get(k).getBranchRate(treeModel, node);
    }
    return optimum;
}
/**
 * Returns the selection strength on the branch above {@code node} multiplied by the
 * (unscaled) branch length.
 *
 * @param node node whose parent branch is evaluated
 * @return branch rate of the selection model times the branch length
 * @throws RuntimeException if no strength-of-selection model was configured
 */
public double getTimeScaledSelection(NodeRef node) {
    if (strengthOfSelection == null) {
        throw new RuntimeException("getTimeScaledSelection should not be called.");
    }
    final double branchTime = treeModel.getBranchLength(node);
    return strengthOfSelection.getBranchRate(treeModel, node) * branchTime;
}
/**
 * Applies the tree normalisation to a branch length.
 *
 * @param length branch length to rescale
 * @return {@code length / treeLength} when time scaling is enabled, otherwise {@code length}
 */
protected double rescaleLength(double length) {
    return scaleByTime ? length / treeLength : length;
}
/**
 * Computes the effective length of the branch above {@code node} used for the
 * diffusion precision: the raw branch length, optionally scaled by the branch rate,
 * then normalised via {@link #rescaleLength(double)}, plus the delta parameter's
 * first value for external nodes when a delta parameter is present.
 *
 * @param node node whose parent branch is evaluated
 * @return the rescaled branch length
 */
public double getRescaledBranchLengthForPrecision(NodeRef node) {
    double scaled = treeModel.getBranchLength(node);
    if (hasBranchRateModel) {
        final double rate = branchRateModel.getBranchRate(treeModel, node);
        // Reciprocal rates scale as precision (inverse time); plain rates scale as variance (time).
        scaled = reciprocalRates ? scaled / rate : scaled * rate;
    }
    scaled = rescaleLength(scaled);
    final boolean addTipDelta = deltaParameter != null && treeModel.isExternal(node);
    if (addTipDelta) {
        return scaled + deltaParameter.getParameterValue(0);
    }
    return scaled;
}
// **************************************************************
// ModelListener IMPLEMENTATION
// **************************************************************
/**
 * Reacts to a change in one of the registered models by invalidating the cached
 * likelihood and, when branch caching is on, the affected per-node cache entries.
 * <p>
 * Without branch caching every event simply marks the likelihood as unknown.
 * With caching, the event source decides how much is invalidated; events from the
 * tree are further dispatched on the kind of TreeChangedEvent.
 * <p>
 * NOTE(review): with caching enabled, only the diffusion model, tree, branch rate
 * model and RestrictedPartials instances are recognised here. Events from the other
 * models registered in the constructor (drift, optimal values, strength of selection,
 * sampling density) would reach the final {@code throw} unless a subclass intercepts
 * them - confirm against the subclasses.
 *
 * @param model  model that changed
 * @param object event payload; a TreeModel.TreeChangedEvent or Parameter for tree events,
 *               possibly null or a Parameter for branch rate events
 * @param index  index associated with the change; -1 is treated as "everything" for the rate model
 * @throws RuntimeException on an unrecognised model or tree event
 */
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (!cacheBranches) {
// No per-node cache to maintain: just force a full recomputation.
likelihoodKnown = false;
updateRestrictedNodePartials = true;
if (model == treeModel)
recalculateTreeLength();
return;
}
if (model == diffusionModel) {
updateAllNodes();
}
// fireTreeEvents sends two events here when a node trait is changed,
// ignoring object instance Parameter case
else if (model == treeModel) {
if (object instanceof TreeModel.TreeChangedEvent) {
TreeModel.TreeChangedEvent event = (TreeModel.TreeChangedEvent) object;
if (event.isTreeChanged()) {
recalculateTreeLength();
updateAllNodes();
updateRestrictedNodePartials = true;
} else if (event.isHeightChanged()) {
recalculateTreeLength();
// The normaliser changes for every branch when normalising by total length,
// or when normalising by height and the root itself moved.
if (useTreeLength || (scaleByTime && treeModel.isRoot(event.getNode())))
updateAllNodes();
else {
updateNodeAndChildren(event.getNode());
}
} else if (event.isNodeParameterChanged()) {
updateNodeAndChildren(event.getNode());
} else if (event.isNodeChanged()) {
recalculateTreeLength();
if (useTreeLength || (scaleByTime && treeModel.isRoot(event.getNode())))
updateAllNodes();
else {
updateNodeAndChildren(event.getNode());
}
updateRestrictedNodePartials = true;
} else {
throw new RuntimeException("Unexpected TreeModel TreeChangedEvent occurring in AbstractMultivariateTraitLikelihood");
}
} else if (object instanceof Parameter) {
// Ignoring
} else {
throw new RuntimeException("Unexpected object throwing events in AbstractMultivariateTraitLikelihood");
}
} else if (model == branchRateModel) {
if (index == -1) {
updateAllNodes();
} else {
// A non-null payload is cast to Parameter without a type check.
if (object == null || ((Parameter) object).getDimension() == 2 * (treeModel.getNodeCount() - 1))
updateNode(treeModel.getNode(index)); // This is a branch specific update
else
updateAllNodes(); // Probably an epoch model
}
} else if (model instanceof RestrictedPartials) {
updateAllNodes();
updateRestrictedNodePartials = true;
} else {
throw new RuntimeException("Unknown componentChangedEvent");
}
}
/**
 * Marks the cached log-likelihood of every node as invalid and flags the overall
 * likelihood as unknown. Requires the cache arrays, i.e. branch caching enabled.
 */
protected void updateAllNodes() {
    int nodeNumber = 0;
    while (nodeNumber < treeModel.getNodeCount()) {
        validLogLikelihoods[nodeNumber] = false;
        nodeNumber++;
    }
    likelihoodKnown = false;
}
/**
 * Invalidates the cached log-likelihood of a single node and flags the overall
 * likelihood as unknown.
 *
 * @param node node whose cache entry is invalidated
 */
private void updateNode(NodeRef node) {
validLogLikelihoods[node.getNumber()] = false;
likelihoodKnown = false;
}
/**
 * Invalidates the cached log-likelihood of a node and of each of its direct children,
 * then flags the overall likelihood as unknown.
 *
 * @param node node whose own and children's cache entries are invalidated
 */
private void updateNodeAndChildren(NodeRef node) {
    validLogLikelihoods[node.getNumber()] = false;
    for (int c = 0; c < treeModel.getChildCount(node); c++) {
        final NodeRef child = treeModel.getChild(node, c);
        validLogLikelihoods[child.getNumber()] = false;
    }
    likelihoodKnown = false;
}
/**
 * Sums the branch lengths of all non-root nodes of the tree.
 *
 * @return the total tree length
 */
protected double getTreeLength() {
    double total = 0;
    for (int i = 0; i < treeModel.getNodeCount(); i++) {
        final NodeRef current = treeModel.getNode(i);
        if (treeModel.isRoot(current)) {
            continue; // the root is excluded from the sum
        }
        total += treeModel.getBranchLength(current);
    }
    return total;
}
/**
 * Refreshes the normalisation constant used when scaling by time: the total tree
 * length when {@code useTreeLength} is set, otherwise the height of the root.
 * Does nothing when time scaling is disabled.
 */
public void recalculateTreeLength() {
    if (scaleByTime) {
        treeLength = useTreeLength
                ? getTreeLength()
                : treeModel.getNodeHeight(treeModel.getRoot()); // normalizing by tree height
    }
}
/**
 * @return the branch rate model given at construction; may be null
 */
public BranchRateModel getBranchRateModel() {
return branchRateModel;
}
// **************************************************************
// VariableListener IMPLEMENTATION
// **************************************************************
/**
 * Flags the likelihood as unknown when the delta or trait parameter changes, or on
 * any variable change when branch caching is disabled. Per-node cache entries are not
 * touched here; parameter changes are handled first by the tree model.
 *
 * @param variable variable that changed
 * @param index    index of the changed element (unused)
 * @param type     kind of change (unused)
 */
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
    final boolean ownVariable = variable == deltaParameter || variable == traitParameter;
    if (ownVariable || !cacheBranches) {
        likelihoodKnown = false;
    }
}
// **************************************************************
// Model IMPLEMENTATION
// **************************************************************
/**
 * Saves the current likelihood value, its validity flag and the tree length, and,
 * when branch caching is enabled, copies the per-node cache arrays into their
 * stored counterparts.
 */
protected void storeState() {
    storedTreeLength = treeLength;
    storedLogLikelihood = logLikelihood;
    storedLikelihoodKnown = likelihoodKnown;
    if (!cacheBranches) {
        return;
    }
    final int nodeCount = treeModel.getNodeCount();
    System.arraycopy(cachedLogLikelihoods, 0, storedCachedLogLikelihood, 0, nodeCount);
    System.arraycopy(validLogLikelihoods, 0, storedValidLogLikelihoods, 0, nodeCount);
}
/**
 * Reinstates the likelihood value, validity flag and tree length saved by
 * {@link #storeState()}. With branch caching enabled the current and stored cache
 * arrays are swapped rather than copied. Restricted node partials are always marked
 * for update afterwards.
 */
protected void restoreState() {
    treeLength = storedTreeLength;
    logLikelihood = storedLogLikelihood;
    likelihoodKnown = storedLikelihoodKnown;
    if (cacheBranches) {
        final double[] discardedLikelihoods = cachedLogLikelihoods;
        cachedLogLikelihoods = storedCachedLogLikelihood;
        storedCachedLogLikelihood = discardedLikelihoods;

        final boolean[] discardedFlags = validLogLikelihoods;
        validLogLikelihoods = storedValidLogLikelihoods;
        storedValidLogLikelihoods = discardedFlags;
    }
    updateRestrictedNodePartials = true; // TODO remove or cache? Caching is still not working, see IMTL.restoreState()
}
/**
 * Called when a proposed state is accepted; this class keeps no extra state to commit.
 */
protected void acceptState() {
} // nothing to do
/**
 * @return the tree given at construction
 */
public MultivariateTraitTree getTreeModel() {
return treeModel;
}
/**
 * @return the trait name given at construction
 */
public String getTraitName() {
return traitName;
}
/**
 * @return the diffusion model given at construction
 */
public MultivariateDiffusionModel getDiffusionModel() {
return diffusionModel;
}
// public boolean getInSubstitutionTime() {
// return inSubstitutionTime;
// }
// **************************************************************
// Likelihood IMPLEMENTATION
// **************************************************************
/**
 * @return this object, which acts as its own model
 */
public Model getModel() {
return this;
}
/**
 * Formats the class name followed by the current log-likelihood in parentheses.
 * Calls {@link #getLogLikelihood()}, so it can trigger a likelihood evaluation.
 */
public String toString() {
return getClass().getName() + "(" + getLogLikelihood() + ")";
}
/**
 * Returns the log-likelihood, recomputing it only when the cached value is stale.
 * When ascertainment correction is enabled, the correction for the ascertained taxon
 * is subtracted from the freshly computed value.
 *
 * @return the (possibly cached) log-likelihood
 */
public final double getLogLikelihood() {
    if (likelihoodKnown) {
        return logLikelihood;
    }
    logLikelihood = calculateLogLikelihood();
    if (doAscertainmentCorrect) {
        final double correction = calculateAscertainmentCorrection(ascertainedTaxonIndex);
        logLikelihood -= correction;
    }
    likelihoodKnown = true;
    return logLikelihood;
}
/**
 * Computes the term that {@link #getLogLikelihood()} subtracts from the log-likelihood
 * when ascertainment correction is enabled.
 *
 * @param taxonIndex index of the ascertained taxon in the tree
 * @return the correction to subtract
 */
protected abstract double calculateAscertainmentCorrection(int taxonIndex);
/**
 * @return the value reported in the ".data" log column (see {@link #getColumns()})
 */
public abstract double getLogDataLikelihood();
/**
 * Forces a full recomputation on the next likelihood request: flags the likelihood
 * as unknown and, when branch caching is enabled, invalidates every node's cache entry.
 */
public void makeDirty() {
likelihoodKnown = false;
if (cacheBranches)
updateAllNodes();
}
/**
 * Builds the log columns for this likelihood: "&lt;id&gt;.joint" and "&lt;id&gt;.data",
 * the latter reporting {@link #getLogDataLikelihood()}.
 *
 * @return a new two-element array of columns
 */
public LogColumn[] getColumns() {
    final LogColumn jointColumn = new LikelihoodColumn(getId() + ".joint");
    final LogColumn dataColumn = new NumberColumn(getId() + ".data") {
        public double getDoubleValue() {
            return getLogDataLikelihood();
        }
    };
    return new LogColumn[]{jointColumn, dataColumn};
}
/**
 * Computes the log-likelihood from scratch (before any ascertainment correction).
 * Invoked by {@link #getLogLikelihood()} when the cached value is stale.
 *
 * @return the uncorrected log-likelihood
 */
public abstract double calculateLogLikelihood();
// public double getMaxLogLikelihood() {
// return maxLogLikelihood;
// }
// **************************************************************
// Loggable IMPLEMENTATION
// **************************************************************
// Lazily created by getTreeTraits(); null until first requested.
private TreeTrait[] treeTraits = null;

/**
 * Returns the tree traits exposed by this likelihood, creating them on first use.
 * A single node-level double-array trait named after {@code traitName} is provided;
 * its values come from {@link #getTraitForNode(Tree, NodeRef, String)}.
 *
 * @return a one-element array holding the trait
 * @throws RuntimeException if the root trait has more than one dimension and
 *                          multivariate reporting is disabled
 */
public TreeTrait[] getTreeTraits() {
    if (treeTraits != null) {
        return treeTraits;
    }
    final double[] rootTrait = getRootNodeTrait();
    if (!(rootTrait.length == 1 || reportAsMultivariate)) {
        throw new RuntimeException("Reporting of traits is only supported as multivariate");
    }
    final TreeTrait nodeTrait = new TreeTrait.DA() {
        public String getTraitName() {
            return traitName;
        }

        public Intent getIntent() {
            return Intent.NODE;
        }

        public Class getTraitClass() {
            return Double.class;
        }

        public double[] getTrait(Tree tree, NodeRef node) {
            return getTraitForNode(tree, node, traitName);
        }
    };
    treeTraits = new TreeTrait[]{nodeTrait};
    return treeTraits;
}
/**
 * Looks up a tree trait by name.
 *
 * @param key trait name to search for
 * @return the first trait whose name equals {@code key}, or null if there is none
 */
public TreeTrait getTreeTrait(String key) {
    final TreeTrait[] available = getTreeTraits();
    for (int i = 0; i < available.length; i++) {
        final TreeTrait candidate = available[i];
        if (candidate.getTraitName().equals(key)) {
            return candidate;
        }
    }
    return null;
}
/**
 * @return {@code dim / dimTrait} as computed in the constructor
 */
public final int getNumData() {
return numData;
}
/**
 * @return the trait dimension, i.e. the size of the diffusion model's precision matrix
 */
public final int getDimTrait() {
return dimTrait;
}
/**
 * @return the multivariate trait stored on the tree's root node under {@code traitName}
 */
protected double[] getRootNodeTrait() {
return treeModel.getMultivariateNodeTrait(treeModel.getRoot(), traitName);
}
/**
 * Supplies the trait values reported for a node; used by the tree trait created in
 * {@link #getTreeTraits()}.
 *
 * @param tree      tree containing the node
 * @param node      node to report
 * @param traitName name of the trait
 * @return the trait values for the node
 */
public abstract double[] getTraitForNode(Tree tree, NodeRef node, String traitName);
/**
 * Delegates validation of the given parameter to the diffusion model.
 *
 * @param trait parameter to check
 * @throws XMLParseException if the diffusion model's check raises it
 */
public void check(Parameter trait) throws XMLParseException {
diffusionModel.check(trait);
}
// **************************************************************
// XMLElement IMPLEMENTATION
// **************************************************************
/**
 * Not supported by this class.
 *
 * @param d document that would own the element
 * @throws RuntimeException always
 */
public Element createElement(Document d) {
throw new RuntimeException("Not implemented yet!");
}
// **************************************************************
// XMLObjectParser
// **************************************************************
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
/**
 * @return the XML element name handled by this parser ({@code TRAIT_LIKELIHOOD})
 */
public String getParserName() {
return TRAIT_LIKELIHOOD;
}
/**
 * Builds a concrete multivariate trait likelihood from its XML description.
 * <p>
 * The concrete class depends on the configuration:
 * <ul>
 * <li>not integrated: {@code SampledMultivariateTraitLikelihood};</li>
 * <li>integrated with a multivariate normal root prior:
 *     {@code SemiConjugateMultivariateTraitLikelihood};</li>
 * <li>integrated with a conjugate root prior:
 *     {@code NonPhylogeneticMultivariateTraitLikelihood} when the phylogeny is ignored,
 *     otherwise {@code FullyConjugateMultivariateTraitLikelihood}.</li>
 * </ul>
 * For the fully conjugate case, drift models take precedence: when they are present
 * the optimal-value and selection models are not passed on, and optimal values are
 * only passed on together with a strength-of-selection model.
 *
 * @param xo the XML element to parse
 * @return the constructed likelihood
 * @throws XMLParseException on inconsistent or unsupported configuration
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    MultivariateDiffusionModel diffusionModel = (MultivariateDiffusionModel) xo.getChild(MultivariateDiffusionModel.class);
    MultivariateTraitTree treeModel = (MultivariateTraitTree) xo.getChild(MultivariateTraitTree.class);
    boolean cacheBranches = xo.getAttribute(CACHE_BRANCHES, true);
    boolean integrate = xo.getAttribute(INTEGRATE, false);
    boolean useTreeLength = xo.getAttribute(USE_TREE_LENGTH, false);
    boolean scaleByTime = xo.getAttribute(SCALE_BY_TIME, false);
    boolean reciprocalRates = xo.getAttribute(RECIPROCAL_RATES, false);
    boolean reportAsMultivariate = xo.getAttribute(REPORT_MULTIVARIATE, true);
    boolean standardizeTraits = xo.getAttribute(STANDARDIZE_TRAITS, false);
    BranchRateModel rateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

    // One drift model is required per trait dimension.
    List<BranchRateModel> driftModels = null;
    if (xo.hasChildNamed(DRIFT_MODELS)) {
        driftModels = new ArrayList<BranchRateModel>();
        XMLObject cxo = xo.getChild(DRIFT_MODELS);
        final int number = cxo.getChildCount();
        if (number != diffusionModel.getPrecisionmatrix().length) {
            throw new XMLParseException("Wrong number of drift models (" + number + ") for a trait of" +
                    " dimension " + diffusionModel.getPrecisionmatrix().length + " in " + xo.getId()
            );
        }
        for (int i = 0; i < number; ++i) {
            driftModels.add((BranchRateModel) cxo.getChild(i));
        }
    }

    // Likewise, one optimal-value model per trait dimension.
    List<BranchRateModel> optimalValues = null;
    BranchRateModel strengthOfSelection = null;
    if (xo.hasChildNamed(OPTIMAL_TRAITS)) {
        optimalValues = new ArrayList<BranchRateModel>();
        XMLObject cxo = xo.getChild(OPTIMAL_TRAITS);
        final int numberModels = cxo.getChildCount();
        if (numberModels != diffusionModel.getPrecisionmatrix().length) {
            throw new XMLParseException("Wrong number of optimal trait models (" + numberModels + ") for a trait of" +
                    " dimension " + diffusionModel.getPrecisionmatrix().length + " in " + xo.getId()
            );
        }
        for (int i = 0; i < numberModels; ++i) {
            optimalValues.add((BranchRateModel) cxo.getChild(i));
        }
    }
    if (xo.hasChildNamed(STRENGTH_OF_SELECTION)) {
        XMLObject cxo = xo.getChild(STRENGTH_OF_SELECTION);
        strengthOfSelection = (BranchRateModel) cxo.getChild(BranchRateModel.class);
    }

    TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
    String traitName = TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;
    TreeTraitParserUtilities.TraitsAndMissingIndices returnValue =
            utilities.parseTraitsFromTaxonAttributes(xo, traitName, treeModel, integrate);
    CompoundParameter traitParameter = returnValue.traitParameter;
    List<Integer> missingIndices = returnValue.missingIndices;
    traitName = returnValue.traitName;

    /* TODO Add partially integrated traits here */

    Model samplingDensity = null;
    if (xo.hasChildNamed(SAMPLING_DENSITY)) {
        XMLObject cxo = xo.getChild(SAMPLING_DENSITY);
        samplingDensity = (Model) cxo.getChild(Model.class);
    }

    Parameter deltaParameter = null;
    if (xo.hasChildNamed("delta")) {
        XMLObject cxo = xo.getChild("delta");
        deltaParameter = (Parameter) cxo.getChild(Parameter.class);
    }

    if (standardizeTraits) {
        // Centre and scale each trait dimension across observations, in place, and log
        // the back-transformation. NOTE(review): a dimension with zero variance yields
        // a zero traitSD and therefore non-finite standardized values.
        int numTraits = traitParameter.getParameter(0).getDimension();
        int numObservations = traitParameter.getParameterCount();
        StringBuilder sb = new StringBuilder();
        sb.append("Traits have been standardized. Use following to transform values back to original scale.\n");
        for (int trait = 0; trait < numTraits; ++trait) {
            double[] values = new double[numObservations];
            for (int obs = 0; obs < numObservations; ++obs) {
                values[obs] = traitParameter.getParameter(obs).getParameterValue(trait);
            }
            double traitMean = DiscreteStatistics.mean(values);
            double traitSD = Math.sqrt(DiscreteStatistics.variance(values, traitMean));
            sb.append("\tDimension " + (trait + 1) + ": multiply by " + traitSD + " then add " + traitMean + "\n");
            for (int obs = 0; obs < numObservations; ++obs) {
                traitParameter.getParameter(obs).setParameterValue(trait,
                        (values[obs] - traitMean) / traitSD);
            }
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    // Collect any RestrictedPartials children; they are only supported when integrating.
    List<RestrictedPartials> restrictedPartialsList = null;
    for (int i = 0; i < xo.getChildCount(); ++i) {
        Object cxo = xo.getChild(i);
        if (cxo instanceof RestrictedPartials) {
            if (!integrate) {
                throw new XMLParseException("Restricted partials are currently only implemented" +
                        " for integrated multivariate trait likelihood models");
            }
            if (restrictedPartialsList == null) {
                restrictedPartialsList = new ArrayList<RestrictedPartials>();
            }
            restrictedPartialsList.add((RestrictedPartials) cxo);
        }
    }

    AbstractMultivariateTraitLikelihood like;
    if (integrate) {
        MultivariateDistributionLikelihood rootPrior =
                (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);
        if (rootPrior != null) {
            // The distribution is cast to MultivariateNormalDistribution below, so that is
            // the type that must be verified; otherwise a non-normal prior would surface
            // as a ClassCastException instead of a parse error.
            if (!(rootPrior.getDistribution() instanceof MultivariateNormalDistribution))
                throw new XMLParseException("Only multivariate normal priors allowed for Gibbs sampling the root trait");
            MultivariateNormalDistribution rootDistribution =
                    (MultivariateNormalDistribution) rootPrior.getDistribution();
            like = new SemiConjugateMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                    traitParameter, missingIndices, cacheBranches,
                    scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                    rootDistribution, reciprocalRates, restrictedPartialsList);
        } else {
            XMLObject cxo = xo.getChild(CONJUGATE_ROOT_PRIOR);
            if (cxo == null) {
                throw new XMLParseException("Must specify a conjugate or multivariate normal root prior");
            }
            boolean ignorePhylogeny = xo.getAttribute(IGNORE_PHYLOGENY, false);
            Parameter meanParameter = (Parameter) cxo.getChild(MultivariateDistributionLikelihood.MVN_MEAN)
                    .getChild(Parameter.class);
            if (meanParameter.getDimension() != diffusionModel.getPrecisionmatrix().length) {
                throw new XMLParseException("Root prior mean dimension does not match trait diffusion dimension");
            }
            Parameter sampleSizeParameter = (Parameter) cxo.getChild(PRIOR_SAMPLE_SIZE).getChild(Parameter.class);
            double[] mean = meanParameter.getParameterValues();
            double pseudoObservations = sampleSizeParameter.getParameterValue(0);
            if (ignorePhylogeny) {
                boolean exchangeableTips = xo.getAttribute(EXCHANGEABLE_TIPS, true);
                like = new NonPhylogeneticMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                        traitParameter, deltaParameter, missingIndices, cacheBranches,
                        scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                        mean, pseudoObservations, restrictedPartialsList, reciprocalRates, exchangeableTips);
            } else {
                if (driftModels == null) {
                    if (strengthOfSelection == null) {
                        // Plain diffusion: no drift, no selection.
                        like = new FullyConjugateMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                                traitParameter, deltaParameter, missingIndices, cacheBranches,
                                scaleByTime, useTreeLength,
                                rateModel, null, null, null,
                                samplingDensity, reportAsMultivariate,
                                mean, restrictedPartialsList, pseudoObservations, reciprocalRates);
                    } else {
                        // Selection towards optimal values.
                        like = new FullyConjugateMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                                traitParameter, deltaParameter, missingIndices, cacheBranches,
                                scaleByTime, useTreeLength,
                                rateModel, null, optimalValues, strengthOfSelection,
                                samplingDensity, reportAsMultivariate,
                                mean, restrictedPartialsList, pseudoObservations, reciprocalRates);
                    }
                } else {
                    // Drift models present: optimal values and selection are not passed on.
                    like = new FullyConjugateMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                            traitParameter, deltaParameter, missingIndices, cacheBranches,
                            scaleByTime, useTreeLength,
                            rateModel, driftModels, null, null,
                            samplingDensity, reportAsMultivariate,
                            mean, restrictedPartialsList, pseudoObservations, reciprocalRates);
                }
            }
        }
    } else {
        like = new SampledMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                traitParameter, missingIndices, cacheBranches,
                scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                reciprocalRates);
    }

    if (!integrate && xo.hasChildNamed(TreeTraitParserUtilities.RANDOMIZE)) {
        utilities.randomize(xo);
    }
    if (xo.hasChildNamed(TreeTraitParserUtilities.JITTER)) {
        utilities.jitter(xo, diffusionModel.getPrecisionmatrix().length, missingIndices);
    }
    if (xo.hasChildNamed(CHECK)) {
        XMLObject cxo = xo.getChild(CHECK);
        Parameter check = (Parameter) cxo.getChild(Parameter.class);
        like.check(check);
    }

    // Identical trait values are rejected only when a rate model other than a strict
    // clock is in use, and only if the allow-identical attribute is absent.
    boolean isRRW = (rateModel != null) && (!(rateModel instanceof StrictClockBranchRates));
    if (!xo.hasAttribute(TreeTraitParserUtilities.ALLOW_IDENTICAL) &&
            isRRW &&
            utilities.hasIdenticalTraits(traitParameter, missingIndices, diffusionModel.getPrecisionmatrix().length)) {
        throw new XMLParseException("For multivariate trait analyses, all trait values should be unique.\n" +
                "Check data or add random noise using 'jitter' option.");
    }

    if (xo.hasChildNamed(ASCERTAINMENT)) {
        XMLObject cxo = xo.getChild(ASCERTAINMENT);
        Taxon taxon = (Taxon) cxo.getChild(Taxon.class);
        if (!integrate) {
            throw new XMLParseException("Ascertainment correction is currently only implemented" +
                    " for integrated multivariate trait likelihood models");
        }
        like.setAscertainedTaxon(taxon);
    }
    return like;
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
/**
 * @return a one-sentence description of what this parser produces
 */
public String getParserDescription() {
return "Provides the likelihood of a continuous trait evolving on a tree by a " +
"given diffusion model.";
}
/**
 * @return the syntax rules declared in {@code rules}
 */
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
// Syntax rules for the XML element.
// NOTE(review): parseXMLObject also reads OPTIMAL_TRAITS, STRENGTH_OF_SELECTION,
// SAMPLING_DENSITY and STANDARDIZE_TRAITS, none of which are declared here - confirm
// whether they should be listed.
private final XMLSyntaxRule[] rules = {
new StringAttributeRule(TreeTraitParserUtilities.TRAIT_NAME, "The name of the trait for which a likelihood should be calculated"),
new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
}),
new ElementRule("delta", new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
}, true),
AttributeRule.newBooleanRule(INTEGRATE, true),
// new XORRule(
new ElementRule(MultivariateDistributionLikelihood.class, true),
new ElementRule(CONJUGATE_ROOT_PRIOR, new XMLSyntaxRule[]{
new ElementRule(MultivariateDistributionLikelihood.MVN_MEAN,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
new ElementRule(PRIOR_SAMPLE_SIZE,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
}, true),
// true),
new ElementRule(ASCERTAINMENT, new XMLSyntaxRule[]{
new ElementRule(Taxon.class)
}, true),
new ElementRule(MultivariateDiffusionModel.class),
new ElementRule(MultivariateTraitTree.class),
new ElementRule(BranchRateModel.class, true),
AttributeRule.newDoubleArrayRule("cut", true),
AttributeRule.newBooleanRule(REPORT_MULTIVARIATE, true),
AttributeRule.newBooleanRule(USE_TREE_LENGTH, true),
AttributeRule.newBooleanRule(SCALE_BY_TIME, true),
AttributeRule.newBooleanRule(RECIPROCAL_RATES, true),
AttributeRule.newBooleanRule(CACHE_BRANCHES, true),
AttributeRule.newIntegerRule(RANDOM_SAMPLE, true),
AttributeRule.newBooleanRule(IGNORE_PHYLOGENY, true),
AttributeRule.newBooleanRule(EXCHANGEABLE_TIPS, true),
AttributeRule.newBooleanRule(TreeTraitParserUtilities.SAMPLE_MISSING_TRAITS, true),
new ElementRule(Parameter.class, true),
TreeTraitParserUtilities.randomizeRules(true),
TreeTraitParserUtilities.jitterRules(true),
new ElementRule(CHECK, new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
}, true),
new ElementRule(DRIFT_MODELS, new XMLSyntaxRule[]{
new ElementRule(BranchRateModel.class, 1, Integer.MAX_VALUE),
}, true),
new ElementRule(RestrictedPartials.class, 0, Integer.MAX_VALUE),
};
/**
 * @return the base type of the objects produced by this parser
 */
public Class getReturnType() {
return AbstractMultivariateTraitLikelihood.class;
}
};
/**
 * Hook for subclasses that support restricted partials; this base implementation
 * always rejects the call.
 *
 * @param restrictedPartials restriction to add
 * @throws IllegalArgumentException always, in this base implementation
 */
protected void addRestrictedPartials(RestrictedPartials restrictedPartials) {
throw new IllegalArgumentException("Not implemented for this model type");
}
// Core components supplied to the constructor (package-private).
MultivariateTraitTree treeModel = null;
MultivariateDiffusionModel diffusionModel = null;
String traitName = null;
CompoundParameter traitParameter;
List<Integer> missingIndices;
// Cached log-likelihood and its validity flag, with the copies saved by storeState().
protected double logLikelihood;
// Not read or written anywhere else in this file.
protected double maxLogLikelihood = Double.NEGATIVE_INFINITY;
private double storedLogLikelihood;
protected boolean likelihoodKnown = false;
private boolean storedLikelihoodKnown = false;
// Optional models supplied to the constructor; null when not configured.
protected List<BranchRateModel> driftModels = null;
protected List<BranchRateModel> optimalValues = null;
protected BranchRateModel strengthOfSelection = null;
// True when a non-null branch rate model was given to the constructor.
private boolean hasBranchRateModel = false;
// Normalisation constant for scaleByTime: total tree length or root height
// (see recalculateTreeLength()), and the copy saved by storeState().
private double treeLength;
private double storedTreeLength;
private final boolean reportAsMultivariate;
private final boolean scaleByTime;
private final boolean useTreeLength;
private final boolean reciprocalRates;
// Per-node caches, indexed by node number; allocated only when cacheBranches is true.
protected boolean cacheBranches;
protected double[] cachedLogLikelihoods;
protected double[] storedCachedLogLikelihood;
protected boolean[] validLogLikelihoods;
protected boolean[] storedValidLogLikelihoods;
// Optional value added to the rescaled length of external branches; may be null.
private final Parameter deltaParameter;
// Ascertainment correction state, set by setAscertainedTaxon().
private boolean doAscertainmentCorrect = false;
private int ascertainedTaxonIndex;
// numData = dim / dimTrait; dimTrait is the precision matrix size; dim is the
// dimension of the first trait parameter (0 if there is none).
protected int numData;
protected int dimTrait;
protected int dim;
// Set to true whenever restricted node partials need recomputing.
protected boolean updateRestrictedNodePartials = true;
// Not used in this file; presumably maintained by subclasses - verify.
protected boolean savedUpdateRestrictedNodePartials;
// protected Map<BitSet, RestrictedPartials> restrictedPartialsMap;
}