/*
* MultinomialLatentLiabilityLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.alignment.PatternList;
import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.*;
import dr.math.distributions.Distribution;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.xml.*;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
/**
 * Latent liability likelihood for discrete (binary or multinomial) tip data.
 *
 * <p>The log-likelihood is an indicator: {@code 0.0} when every tip's latent liabilities are
 * consistent with its observed discrete state, {@code Double.NEGATIVE_INFINITY} otherwise.
 * For each pattern, the number of classes is read from {@code numClasses}:
 * <ul>
 *   <li>1 class: one liability, always consistent;</li>
 *   <li>2 classes: one liability whose sign must match the state (positive for 1, negative
 *       for 0; exactly zero is accepted for either);</li>
 *   <li>k &gt; 2 classes: k liabilities, and the one at the observed state's index must be a
 *       maximum.</li>
 * </ul>
 */
public class MultinomialLatentLiabilityLikelihood extends AbstractModelLikelihood implements LatentTruncation, Citable, SoftThresholdLikelihood {

    public final static String MULTINOMIAL_LATENT_LIABILITY_LIKELIHOOD = "multinomialLatentLiabilityLikelihood";

    /**
     * Builds the likelihood and caches the discrete tip states from {@code patternList}.
     *
     * @param treeModel         tree whose external nodes define the tips
     * @param patternList       observed discrete tip data, one pattern per character
     * @param tipTraitParameter per-tip latent liabilities (one sub-parameter per tip, indexed by node number)
     * @param numClasses        number of classes for each pattern
     * @throws IllegalArgumentException if a tree taxon is missing from {@code patternList}
     */
    public MultinomialLatentLiabilityLikelihood(TreeModel treeModel, PatternList patternList, CompoundParameter tipTraitParameter, Parameter numClasses) {
        super(MULTINOMIAL_LATENT_LIABILITY_LIKELIHOOD);
        this.treeModel = treeModel;
        this.patternList = patternList;
        this.tipTraitParameter = tipTraitParameter;
        this.numClasses = numClasses;

        // Changes to the liabilities invalidate the cached likelihood (see handleVariableChangedEvent).
        addVariable(tipTraitParameter);

        setTipDataValuesForAllNodes();

        StringBuilder sb = new StringBuilder();
        sb.append("Constructing a latent liability likelihood model:\n");
        sb.append("\tBinary patterns: ").append(patternList.getId()).append("\n");
        sb.append("\tPlease cite:\n").append(Citable.Utils.getCitationString(this));
        Logger.getLogger("dr.evomodel.continous").info(sb.toString());
    }

    /**
     * Fills {@code tipData} for every external node, matching tree taxa to pattern-list taxa by id.
     */
    private void setTipDataValuesForAllNodes() {
        final int tipCount = treeModel.getExternalNodeCount();
        if (tipData == null) {
            tipData = new int[tipCount][patternList.getPatternCount()];
        }

        for (int i = 0; i < tipCount; i++) {
            NodeRef node = treeModel.getExternalNode(i);
            String id = treeModel.getTaxonId(i);
            int index = patternList.getTaxonIndex(id);
            if (index < 0) {
                // Without this check the lookup below fails with an opaque ArrayIndexOutOfBoundsException.
                throw new IllegalArgumentException("Taxon '" + id
                        + "' from the tree was not found in pattern list '" + patternList.getId() + "'");
            }
            setTipDataValuesForNode(node, index);
            if (DEBUG) {
                System.err.println("\t For node: " + i + " with ID " + id + " you get taxon " + index + " with ID " + patternList.getTaxonId(index));
            }
        }
    }

    /**
     * Copies the observed states of pattern-list taxon {@code index} into the row of
     * {@code tipData} addressed by the node's number.
     */
    private void setTipDataValuesForNode(NodeRef node, int index) {
        int nodeNumber = node.getNumber();

        for (int datum = 0; datum < patternList.getPatternCount(); ++datum) {
            tipData[nodeNumber][datum] = (int) patternList.getPattern(datum)[index];
            if (DEBUG) {
                // NOTE(review): the liability is read at the pattern index, which is not the
                // liability offset once a pattern with more than two classes has been seen.
                Parameter oneTipTraitParameter = tipTraitParameter.getParameter(nodeNumber);
                System.err.println("Data = " + tipData[nodeNumber][datum] + " : " + oneTipTraitParameter.getParameterValue(datum));
            }
        }
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // No sub-models are registered.
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        likelihoodKnown = false;
    }

    @Override
    protected void storeState() {
        storedLogLikelihood = logLikelihood;
        storedLikelihoodKnown = likelihoodKnown;
    }

    @Override
    protected void restoreState() {
        logLikelihood = storedLogLikelihood;
        // Restore the validity flag too: a store taken while the cache was dirty must not
        // resurrect a stale value as if it were current.
        likelihoodKnown = storedLikelihoodKnown;
    }

    @Override
    protected void acceptState() {
        // do nothing
    }

    public void makeDirty() {
        likelihoodKnown = false;
    }

    public Model getModel() {
        return this;
    }

    /**
     * Returns the cached log-likelihood, recomputing it first if the cache is invalid.
     */
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = computeLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    public String toString() {
        return getClass().getName() + "(" + getLogLikelihood() + ")";
    }

    /**
     * Returns {@code 0.0} if every tip is consistent with its data, otherwise negative infinity.
     */
    protected double computeLogLikelihood() {
        boolean valid = true;
        for (int tip = 0; tip < tipData.length && valid; ++tip) {
            valid = validTraitForTip(tip);
        }
        return valid ? 0.0 : Double.NEGATIVE_INFINITY;
    }

    /**
     * Checks whether the latent liabilities of tip {@code tip} agree with its observed states.
     *
     * @param tip node number of the external node
     * @return true if every pattern is consistent
     */
    public boolean validTraitForTip(int tip) {
        boolean valid = true;
        Parameter oneTipTraitParameter = tipTraitParameter.getParameter(tip);
        int[] data = tipData[tip];

        // Offset of the current pattern's first liability within this tip's parameter;
        // patterns with k > 2 classes consume k entries, all others consume one.
        int offset = 0;

        for (int index = 0; index < data.length && valid; ++index) {
            int datum = data[index];
            int dim = (int) numClasses.getParameterValue(index);

            if (dim == 1) {
                valid = true;
                offset++;
            } else if (dim == 2) {
                double trait = oneTipTraitParameter.getParameterValue(offset);
                if (trait == 0) {
                    valid = true;
                } else if (trait > 0.0) {
                    valid = (datum == 1);
                } else {
                    valid = (datum == 0);
                }
                offset++;
            } else {
                double[] trait = new double[dim];
                for (int l = 0; l < dim; l++) {
                    trait[l] = oneTipTraitParameter.getParameterValue(offset + l);
                }
                valid = isMax(trait, datum);
                offset += dim;
            }
        }
        return valid;
    }

    /**
     * Returns true if {@code trait[datum]} is greater than or equal to every entry of {@code trait}.
     * NOTE(review): a datum outside [0, trait.length) (e.g. a missing/ambiguous state code) throws
     * ArrayIndexOutOfBoundsException here; confirm how such states should be treated.
     */
    private boolean isMax(double[] trait, int datum) {
        boolean isMax = true;
        for (int j = 0; j < trait.length && isMax; j++) {
            isMax = (trait[datum] >= trait[j]);
        }
        return isMax;
    }

    public double getNormalizationConstant(Distribution working) {
        return normalizationDelegate.getNormalizationConstant(working); // delegate to abstract Delegate
    }

    private final LatentTruncation.Delegate normalizationDelegate = new Delegate() {

        protected double computeNormalizationConstant(Distribution working) {
            double constant = 0.0;
            // TODO: not yet implemented; always returns 0.0
            return constant;
        }
    };

    public void setPathParameter(double beta) {
        pathParameter = beta;
    }

    @Override
    public double getLikelihoodCorrection() {
        return 0;
    }

    // **************************************************************
    // XMLObjectParser
    // **************************************************************

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public final static String TIP_TRAIT = "tipTrait";
        public final static String NUM_CLASSES = "numClasses";

        public String getParserName() {
            return MULTINOMIAL_LATENT_LIABILITY_LIKELIHOOD;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            AbstractMultivariateTraitLikelihood traitLikelihood = (AbstractMultivariateTraitLikelihood)
                    xo.getChild(AbstractMultivariateTraitLikelihood.class);
            PatternList patternList = (PatternList) xo.getChild(PatternList.class);
            TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
            CompoundParameter tipTraitParameter = (CompoundParameter) xo.getElementFirstChild(TIP_TRAIT);
            Parameter numClasses = (Parameter) xo.getElementFirstChild(NUM_CLASSES);

            int numTaxa = treeModel.getTaxonCount();
            int numData = traitLikelihood.getNumData();
            int dimTrait = traitLikelihood.getDimTrait();

            if (tipTraitParameter.getDimension() != numTaxa * numData * dimTrait) {
                throw new XMLParseException("Tip trait parameter is wrong dimension in latent liability model");
            }

            return new MultinomialLatentLiabilityLikelihood(treeModel, patternList, tipTraitParameter, numClasses);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "Provides the likelihood of a latent liability model on multivariate trait data";
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(AbstractMultivariateTraitLikelihood.class, "The model for the latent random variables"),
                new ElementRule(TIP_TRAIT, CompoundParameter.class, "The parameter of tip locations from the tree"),
                new ElementRule(NUM_CLASSES, Parameter.class, "Number of multinomial classes in each dimention"),
                new ElementRule(PatternList.class, "The multinomial tip data"),
                new ElementRule(TreeModel.class, "The tree model"),
        };

        public Class getReturnType() {
            return MultinomialLatentLiabilityLikelihood.class;
        }
    };

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TRAIT_MODELS;
    }

    @Override
    public String getDescription() {
        return "Latent Liability model";
    }

    @Override
    public List<Citation> getCitations() {
        List<Citation> citations = new ArrayList<Citation>();
        citations.add(CommonCitations.CYBIS_2015_ASSESSING);
        return citations;
    }

    private final TreeModel treeModel;
    private final PatternList patternList;
    private final CompoundParameter tipTraitParameter;
    private final Parameter numClasses;

    // Observed discrete state per [node number][pattern].
    private int[][] tipData;

    private boolean likelihoodKnown = false;
    private boolean storedLikelihoodKnown = false;
    private double logLikelihood;
    private double storedLogLikelihood;

    // Enables verbose per-tip diagnostics on System.err; off by default.
    private static final boolean DEBUG = false;

    private double pathParameter = 1;
}