/* * InhibitionAssayLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; import dr.inference.loggers.LogColumn; import dr.inference.loggers.NumberColumn; import dr.inference.model.*; import dr.xml.*; import org.w3c.dom.Document; import org.w3c.dom.Element; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.logging.Logger; /** * @author Marc Suchard * @author Jack O'Brien */ public class InhibitionAssayLikelihood extends AbstractModelLikelihood /*implements NodeAttributeProvider*/ { public static final String TRAIT_LIKELIHOOD = "inhibitionLikelihood"; public static final String TRAIT_NAME = "traitName"; public static final String ROOT_PRIOR = "rootPrior"; public static final String MODEL = "diffusionModel"; public static final String TREE = "tree"; public static final String TRAIT_PARAMETER = "traitParameter"; public static final String SET_TRAIT = "setOutcomes"; public static final String MISSING = "missingIndicator"; public static final String CACHE_BRANCHES = "cacheBranches"; public static final String IN_REAL_TIME = "inRealTime"; public static final String PRECISION = "precision"; public InhibitionAssayLikelihood(TreeModel treeModel, // List<Integer> missingIndices, MatrixParameter dataParameter, Parameter precision) { super(TRAIT_LIKELIHOOD); this.treeModel = treeModel; this.dataParameter = dataParameter; addModel(treeModel); addVariable(dataParameter); addVariable(precision); N = treeModel.getExternalNodeCount(); StringBuffer sb = new StringBuffer("Creating inhibition assay model:\n"); sb.append("\tPlease cite O'Brien and Suchard (in preparation) if you publish results using this model."); Logger.getLogger("dr.evomodel").info(sb.toString()); } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** protected void handleModelChangedEvent(Model model, Object object, int index) { likelihoodKnown = false; } private double getBranchMean(TreeModel tree, NodeRef node) { if (tree.isRoot(node)) { return tree.getNodeTrait(node, "mean"); } else { double rate; if (isClusterChangeOnBranchAbove(tree, node)) { rate = tree.getNodeTrait(node, "mean"); } else { rate = getBranchMean(tree, tree.getParent(node)); } return rate; } } public final boolean isClusterChangeOnBranchAbove(TreeModel tree, NodeRef node) { return tree.getNodeTrait(node, "indicator") == 1; } public final boolean areNodesInSameCluster(TreeModel tree, NodeRef node1, NodeRef node2) { return clusterStart(tree, node1) == clusterStart(tree, node2); } public final NodeRef clusterStart(TreeModel tree, NodeRef node) { if (tree.isRoot(node) || isClusterChangeOnBranchAbove(tree, node)) return node; return clusterStart(tree, tree.getParent(node)); } // ************************************************************** // VariableListener IMPLEMENTATION // ************************************************************** protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { likelihoodKnown = false; } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the precalculated state: in this case the intervals */ protected void storeState() { storedLikelihoodKnown = likelihoodKnown; storedLogLikelihood = logLikelihood; storedTreeLength = treeLength; } /** * Restores the precalculated state: that is the intervals of the tree. */ protected void restoreState() { likelihoodKnown = storedLikelihoodKnown; logLikelihood = storedLogLikelihood; treeLength = storedTreeLength; } protected void acceptState() { } // nothing to do public TreeModel getTreeModel() { return treeModel; } // public MultivariateDiffusionModel getDiffusionModel() { // return diffusionModel; // } // // public boolean getInSubstitutionTime() { // return inSubstitutionTime; // } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** public Model getModel() { return this; } public String toString() { return getClass().getName() + "(" + getLogLikelihood() + ")"; } public final double getLogLikelihood() { if (!likelihoodKnown) { logLikelihood = calculateLogLikelihood(); likelihoodKnown = true; } return logLikelihood; } public void makeDirty() { likelihoodKnown = false; } /** * Calculate the log likelihood of the current state. * * @return the log likelihood. */ public double calculateLogLikelihood() { logLikelihood = 0; // for(int i=0; i<treeModel.getNodeCount(); i++) { // treeModel.setNodeTrait(treeModel.getNode(i),"indicator",0.0); // } // // treeModel.setNodeTrait(treeModel.getInternalNode(2),"indicator",1.0); // // // treeModel.setNodeTrait(treeModel.getRoot(),"mean",2.0); // // // for(int i=0; i<treeModel.getExternalNodeCount(); i++) { // System.err.println("mean for tip "+i+" = "+getBranchMean(treeModel,treeModel.getExternalNode(i))); // } double[] mean = new double[N]; for (int i = 0; i < N; i++) mean[i] = getBranchMean(treeModel, treeModel.getExternalNode(i)); final double[][] data = dataParameter.getParameterAsMatrix(); final boolean[][] commonCluster = determineCommonClusters(); for (int i = 0; i < N; i++) { for (int j = i; j < N; j++) { if (i != j) { if (commonCluster[i][j]) { // todo do something to logLikelihood } else { // todo do something else to logLikelihood } } } } return logLikelihood; } private boolean[][] determineCommonClusters() { boolean[][] commonCluster = new boolean[N][N]; for (int i = 0; i < N; i++) { for (int j = i; j < N; j++) { if (i != j) { NodeRef nodeI = treeModel.getExternalNode(i); NodeRef nodeJ = treeModel.getExternalNode(j); if (areNodesInSameCluster(treeModel, nodeI, nodeJ)) { commonCluster[i][j] = commonCluster[j][i] = true; } else { commonCluster[i][j] = commonCluster[j][i] = false; } } } } return commonCluster; } public double getMaxLogLikelihood() { return maxLogLikelihood; } public int[] getRestrictedGrowthFunction() { boolean[][] commonCluster = determineCommonClusters(); int totalClusters = 1; int[] map = new int[N]; map[0] = 0; // first taxon is always in first cluster for (int i = 1; i < N; i++) { // iterate over all remaining taxa boolean notFound = true; for (int j = 0; notFound && j < i; j++) { if (commonCluster[i][j]) { notFound = false; map[i] = map[j]; // i and j are in the same cluster } } if (notFound) { // i is in a new cluster map[i] = totalClusters; totalClusters++; } } return map; } public String getClusterString() { // returns the restricted growth representation of the clusters int[] map = getRestrictedGrowthFunction(); StringBuffer sb = new StringBuffer("{"); sb.append(map[0]); for (int i = 1; i < N; i++) sb.append("," + map[i]); sb.append("}"); return sb.toString(); } // ************************************************************** // Loggable IMPLEMENTATION // ************************************************************** /** * @return the log columns. */ public LogColumn[] getColumns() { return new LogColumn[]{ new LikelihoodColumn(getId()), // new NumberClustersColumn(getId()), new ClustersColumn(getId()) }; } private String[] attributeLabel = null; public String[] getNodeAttributeLabel() { if (attributeLabel == null) { double[] trait = treeModel.getMultivariateNodeTrait(treeModel.getRoot(), "trait"); attributeLabel = new String[trait.length]; if (trait.length == 1) attributeLabel[0] = traitName; else { for (int i = 1; i <= trait.length; i++) attributeLabel[i - 1] = traitName + i; } } return attributeLabel; } public String[] getAttributeForNode(Tree tree, NodeRef node) { double trait[] = treeModel.getMultivariateNodeTrait(node, "trait"); // StringBuffer sb = new StringBuffer(); // sb.append("{"); // for(int i=0; i<trait.length-1; i++) { // sb.append(trait[i]); // sb.append(","); // } // sb.append(trait[trait.length-1]); // sb.append("}"); String[] value = new String[trait.length]; for (int i = 0; i < trait.length; i++) value[i] = Double.toString(trait[i]); // return new String[] {sb.toString()}; //To change body of implemented methods use File | Settings | File Templates. return value; } private class LikelihoodColumn extends NumberColumn { public LikelihoodColumn(String label) { super(label); } public double getDoubleValue() { return getLogLikelihood(); } } // private class NumberClustersColumn extends NumberColumn { // // public NumberClustersColumn(String label) { // super(label); // } // // public double getDoubleValue() { // int total = 1; // // return total; // } // } private class ClustersColumn extends LogColumn.Abstract { public ClustersColumn(String label) { super(label); } protected String getFormattedValue() { return getClusterString(); } } // ************************************************************** // XMLElement IMPLEMENTATION // ************************************************************** public Element createElement(Document d) { throw new RuntimeException("Not implemented yet!"); } // ************************************************************** // XMLObjectParser // ************************************************************** public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return TRAIT_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { // System.err.println("did i get here?"); // MultivariateDiffusionModel diffusionModel = (MultivariateDiffusionModel) xo.getChild(MultivariateDiffusionModel.class); TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class); // CompoundParameter traitParameter = (CompoundParameter) xo.getSocketChild(TRAIT_PARAMETER); MatrixParameter data = (MatrixParameter) xo.getChild(MatrixParameter.class); XMLObject cxo = xo.getChild(PRECISION); Parameter precision = (Parameter) cxo.getChild(Parameter.class); int numTips = treeModel.getExternalNodeCount(); if (numTips != data.getColumnDimension() || numTips != data.getRowDimension()) throw new XMLParseException("Dimensions of matrix '" + data.getId() + "' do not match the number of taxa in '" + treeModel.getId() + "'"); // boolean cacheBranches = false; // if (xo.hasAttribute(CACHE_BRANCHES)) // cacheBranches = xo.getBooleanAttribute(CACHE_BRANCHES); // // boolean inSubstitutionTime = false; // if (xo.hasAttribute(IN_REAL_TIME)) // inSubstitutionTime = !xo.getBooleanAttribute(IN_REAL_TIME); // List<Integer> missingIndices = null; // String traitName = "trait"; // if (xo.hasAttribute(TRAIT_NAME)) { // // traitName = xo.getStringAttribute(TRAIT_NAME); // // // Fill in attributeValues // int taxonCount = treeModel.getTaxonCount(); // for (int i = 0; i < taxonCount; i++) { // String taxonName = treeModel.getTaxonId(i); // System.err.println("taxon "+i+" = "+taxonName); // String paramName = taxonName + ".trait"; // Parameter traitParam = getTraitParameterByName(traitParameter, paramName); // if (traitParam == null) // throw new RuntimeException("Missing trait parameters at tree tips"); // String object = (String) treeModel.getTaxonAttribute(i, traitName); // if (object == null) // throw new RuntimeException("Trait \"" + traitName + "\" not found for taxa \"" + taxonName + "\""); // else { // StringTokenizer st = new StringTokenizer(object); // int count = st.countTokens(); // if (count != traitParam.getDimension()) // throw new RuntimeException("Trait length must match trait parameter dimension"); // for (int j = 0; j < count; j++) { // String oneValue = st.nextToken(); // double value = Double.NaN; // if (oneValue.compareTo("NA") == 0) { // // Missing values not yet handled. // } else { // try { // value = (new Double(oneValue)).doubleValue(); // } catch (NumberFormatException e) { // throw new RuntimeException(e.getMessage()); // } // } // traitParam.setParameterValue(j, value); // } // } // } // // Find missing values // double[] allValues = traitParameter.getParameterValues(); // missingIndices = new ArrayList<Integer>(); // for (int i = 0; i < allValues.length; i++) { // if ((new Double(allValues[i])).isNaN()) { // traitParameter.setParameterValue(i, 0); // missingIndices.add(i); // } // } // // if (xo.hasSocket(MISSING)) { // XMLObject cxo = (XMLObject) xo.getChild(MISSING); // Parameter missingParameter = new Parameter.Default(allValues.length, 0.0); // for (int i : missingIndices) { // missingParameter.setParameterValue(i, 1.0); // } // missingParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, allValues.length)); ///* CompoundParameter missingParameter = new CompoundParameter(MISSING); // System.err.println("TRAIT: "+traitParameter.toString()); // System.err.println("CNT: "+traitParameter.getNumberOfParameters()); // for(int i : missingIndices) { // Parameter thisParameter = traitParameter.getIndicatorParameter(i); // missingParameter.addVariable(thisParameter); // }*/ // replaceParameter(cxo, missingParameter); // } // } return new InhibitionAssayLikelihood(treeModel, data, precision); } private Parameter getTraitParameterByName(CompoundParameter traits, String name) { // Parameter found = null; // System.err.println("LOOKING FOR: "+name); for (int i = 0; i < traits.getParameterCount(); i++) { Parameter found = traits.getParameter(i); // System.err.println("COMPARE TO: "+found.getStatisticName()); if (found.getStatisticName().compareTo(name) == 0) return found; } return null; } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "Provides the likelihood of a continuous trait evolving on a tree by a " + "given diffusion model."; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { // new StringAttributeRule(TRAIT_NAME, "The name of the trait for which a likelihood should be calculated"), // AttributeRule.newBooleanRule(IN_REAL_TIME, true), // new ElementRule(MultivariateDiffusionModel.class), new ElementRule(TreeModel.class), new ElementRule(MatrixParameter.class), // new ElementRule("precision", Parameter) // new ElementRule(Parameter.class) new ElementRule(PRECISION, new XMLSyntaxRule[]{ new ElementRule(Parameter.class)}) }; public Class getReturnType() { return AbstractMultivariateTraitLikelihood.class; } // public void replaceParameter(XMLObject xo, Parameter newParam) throws XMLParseException { // // for (int i = 0; i < xo.getChildCount(); i++) { // // if (xo.getChild(i) instanceof Parameter) { // // XMLObject rxo = null; // Object obj = xo.getRawChild(i); // // if (obj instanceof Reference) { // rxo = ((Reference) obj).getReferenceObject(); // } else if (obj instanceof XMLObject) { // rxo = (XMLObject) obj; // } else { // throw new XMLParseException("object reference not available"); // } // // if (rxo.getChildCount() > 0) { // throw new XMLParseException("No child elements allowed in parameter element."); // } // // if (rxo.hasAttribute(XMLParser.IDREF)) { // throw new XMLParseException("References to " + xo.getName() + " parameters are not allowed in treeModel."); // } // // if (rxo.hasAttribute(ParameterParser.VALUE)) { // throw new XMLParseException("Parameters in " + xo.getName() + " have values set automatically."); // } // // if (rxo.hasAttribute(ParameterParser.UPPER)) { // throw new XMLParseException("Parameters in " + xo.getName() + " have bounds set automatically."); // } // // if (rxo.hasAttribute(ParameterParser.LOWER)) { // throw new XMLParseException("Parameters in " + xo.getName() + " have bounds set automatically."); // } // // if (rxo.hasAttribute(XMLParser.ID)) { // // newParam.setId(rxo.getStringAttribute(XMLParser.ID)); // } // // rxo.setNativeObject(newParam); // // return; // } // } // } }; private TreeModel treeModel = null; private MatrixParameter dataParameter = null; private final Parameter precision = null; private final int N; MultivariateDiffusionModel diffusionModel = null; String traitName = null; // private boolean jeffreysPrior = false; CompoundParameter traitParameter; List<Integer> missingIndices; ArrayList dataList = new ArrayList(); private double logLikelihood; private final double maxLogLikelihood = Double.NEGATIVE_INFINITY; private double storedLogLikelihood; private boolean likelihoodKnown = false; private boolean storedLikelihoodKnown = false; // private double[] cachedLikelihoods = null; private final HashMap<NodeRef, Double> cachedLikelihoods = null; private double treeLength; private double storedTreeLength; private boolean inSubstitutionTime; }