/* * TransmissionLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.transmission; import dr.evolution.coalescent.Coalescent; import dr.evolution.coalescent.DemographicFunction; import dr.evolution.coalescent.Intervals; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.Taxon; import dr.evolution.util.TaxonList; import dr.evolution.util.Units; import dr.evomodel.coalescent.DemographicModel; import dr.evomodel.tree.TreeModel; import dr.inference.model.AbstractModelLikelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.xml.*; /** * A likelihood function for a complete transmission history. Takes a viruses tree * and a demographic model. The transmission history consists of a number of * hosts with known history of transmission. The viruses tree should have tip * attributes specifying which host they are from (host=""). * * @author Andrew Rambaut * @author Alexei Drummond * @version $Id: TransmissionLikelihood.java,v 1.13 2005/06/15 17:20:54 rambaut Exp $ */ public class TransmissionLikelihood extends AbstractModelLikelihood implements Units { // PUBLIC STUFF public static final String TRANSMISSION_LIKELIHOOD = "transmissionLikelihood"; public static final String SOURCE_PATIENT = "sourcePatient"; public TransmissionLikelihood(Tree hostTree, Tree virusTree, DemographicModel sourceDemographic, TransmissionDemographicModel transmissionModel) throws TaxonList.MissingTaxonException { this(TRANSMISSION_LIKELIHOOD, hostTree, virusTree, sourceDemographic, transmissionModel); } public TransmissionLikelihood(String name, Tree hostTree, Tree virusTree, DemographicModel sourceDemographic, TransmissionDemographicModel transmissionModel) throws TaxonList.MissingTaxonException { super(name); this.hostTree = hostTree; if (hostTree instanceof TreeModel) { addModel((TreeModel) hostTree); } this.virusTree = virusTree; if (virusTree instanceof TreeModel) { addModel((TreeModel) virusTree); } this.sourceDemographic = sourceDemographic; addModel(sourceDemographic); this.transmissionModel = transmissionModel; addModel(transmissionModel); for (int i = 0; i < virusTree.getExternalNodeCount(); i++) { Taxon hostTaxon = (Taxon) virusTree.getTaxonAttribute(i, "host"); if (hostTaxon == null) throw new TaxonList.MissingTaxonException("One or more of the viruses tree's taxa are missing the 'host' attribute"); int host = hostTree.getTaxonIndex(hostTaxon); if (host == -1) throw new TaxonList.MissingTaxonException("One of the viruses tree's host attribute, " + hostTaxon.getId() + ", was not found as a taxon in the host tree"); } setupHosts(); } public TransmissionLikelihood(TransmissionHistoryModel transmissionHistoryModel, Tree virusTree, DemographicModel sourceDemographic, TransmissionDemographicModel transmissionModel) throws TaxonList.MissingTaxonException { this(TRANSMISSION_LIKELIHOOD, transmissionHistoryModel, virusTree, sourceDemographic, transmissionModel); } public TransmissionLikelihood(String name, TransmissionHistoryModel transmissionHistoryModel, Tree virusTree, DemographicModel sourceDemographic, TransmissionDemographicModel transmissionModel) throws TaxonList.MissingTaxonException { super(name); this.transmissionHistoryModel = transmissionHistoryModel; addModel((TransmissionHistoryModel) transmissionHistoryModel); this.virusTree = virusTree; if (virusTree instanceof TreeModel) { addModel((TreeModel) virusTree); } this.sourceDemographic = sourceDemographic; addModel(sourceDemographic); this.transmissionModel = transmissionModel; addModel(transmissionModel); for (int i = 0; i < virusTree.getExternalNodeCount(); i++) { Taxon hostTaxon = (Taxon) virusTree.getTaxonAttribute(i, "host"); if (hostTaxon == null) throw new TaxonList.MissingTaxonException("One or more of the viruses tree's taxa are missing the 'host' attribute"); int host = transmissionHistoryModel.getHostIndex(hostTaxon); if (host == -1) throw new TaxonList.MissingTaxonException("One of the viruses tree's host attribute, " + hostTaxon.getId() + ", was not found as a taxon in the transmission history"); } setupHosts(); } private void setupHosts() { if (transmissionHistoryModel != null) { hostCount = transmissionHistoryModel.getHostCount(); } else { hostCount = hostTree.getTaxonCount(); } intervals = new Intervals[hostCount]; for (int i = 0; i < hostCount; i++) { // 3 times virusTree tip count will be enough events... intervals[i] = new Intervals(virusTree.getExternalNodeCount() * 3); } donorHost = new int[hostCount]; donorHost[0] = -1; transmissionTime = new double[hostCount]; transmissionTime[0] = Double.POSITIVE_INFINITY; donorSize = new double[hostCount]; if (transmissionHistoryModel != null) { for (int i = 0; i < transmissionHistoryModel.getTransmissionEventCount(); i++) { TransmissionHistoryModel.TransmissionEvent event = transmissionHistoryModel.getTransmissionEvent(i); int host1 = transmissionHistoryModel.getHostIndex(event.getDonor()); int host2 = transmissionHistoryModel.getHostIndex(event.getRecipient()); donorHost[host2] = host1; transmissionTime[host2] = event.getTransmissionTime(); } } else { setupHostsTree(hostTree.getRoot()); } } private int setupHostsTree(NodeRef node) { int host; if (hostTree.isExternal(node)) { host = node.getNumber(); } else { // This traversal assumes that the first child is the donor // and the second is the recipient int host1 = setupHostsTree(hostTree.getChild(node, 0)); int host2 = setupHostsTree(hostTree.getChild(node, 1)); donorHost[host2] = host1; transmissionTime[host2] = hostTree.getNodeHeight(node); host = host1; } return host; } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** protected final void handleModelChangedEvent(Model model, Object object, int index) { if (model == virusTree) { // treeModel has changed so recalculate the intervals } else if (model == hostTree) { // hosts treeModel has changed so recalculate the hosts and intervals } else if (model == transmissionHistoryModel) { // transmissionHistoryModel has changed so recalculate the hosts and intervals } else { // demographicModel has changed so we don't need to recalculate the intervals } likelihoodKnown = false; } // ************************************************************** // VariableListener IMPLEMENTATION // ************************************************************** protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the precalculated state: in this case the intervals */ protected final void storeState() { } /** * Restores the precalculated state: that is the intervals of the tree. */ protected final void restoreState() { likelihoodKnown = false; } protected final void acceptState() { } // nothing to do // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** public final Model getModel() { return this; } public final double getLogLikelihood() { if (!likelihoodKnown) { logLikelihood = calculateLogLikelihood(); likelihoodKnown = true; } return logLikelihood; } public final void makeDirty() { likelihoodKnown = false; } /** * Calculates the log likelihood of this set of coalescent intervals, * given a demographic model. */ public double calculateLogLikelihood() { makeDirty(); setupHosts(); for (int i = 0; i < hostCount; i++) { intervals[i].resetEvents(); donorSize[i] = -1; } try { setupIntervals(virusTree.getRoot()); } catch (IncompatibleException re) { // register the compatibility failure return Double.NEGATIVE_INFINITY; } /* if (intervalsKnown == false) { if (!hostsKnown) { setupHosts(); hostsKnown = true; } for (int i = 0; i < hostCount; i++) { intervals[i].resetEvents(); donorSize[i] = -1; } try { setupIntervals(virusTree.getRoot()); if (isCompatible(virusTree.getRoot()) == -1) { System.out.println("compatibility failed!"); makeDirty(); // register the compatibility failure return Double.NEGATIVE_INFINITY; } //System.out.println("intervals set up successfully!"); intervalsKnown = true; savedHostTree = new FlexibleTree(hostTree); } catch (IncompatibleException re) { System.out.println("intervals setup failed!"); if (savedHostTree == null) { throw new RuntimeException(re.getMessage()); } else { makeDirty(); // register the compatibility failure return Double.NEGATIVE_INFINITY; } } }*/ for (int i = 0; i < hostCount; i++) { donorSize[i] = -1; } DemographicFunction demoFunction = sourceDemographic.getDemographicFunction(); double logL = Coalescent.calculateLogLikelihood(intervals[0], demoFunction); for (int i = 1; i < hostCount; i++) { double ds = getDonorSize(i); demoFunction = transmissionModel.getDemographicFunction(transmissionTime[i], ds, i); logL += Coalescent.calculateLogLikelihood(intervals[i], demoFunction); } return logL; } private double getDonorSize(int host) { if (donorSize[host] > 0.0) { return donorSize[host]; } DemographicFunction demoFunction; if (donorHost[host] == 0) { demoFunction = sourceDemographic.getDemographicFunction(); } else { double ds = getDonorSize(donorHost[host]); demoFunction = transmissionModel.getDemographicFunction(transmissionTime[host], ds, host); } donorSize[host] = demoFunction.getDemographic(transmissionTime[host]); return donorSize[host]; } private int setupIntervals(NodeRef node) throws IncompatibleException { double height = virusTree.getNodeHeight(node); int host; if (virusTree.isExternal(node)) { Taxon hostTaxon = (Taxon) virusTree.getTaxonAttribute(node.getNumber(), "host"); if (transmissionHistoryModel != null) { host = transmissionHistoryModel.getHostIndex(hostTaxon); } else { host = hostTree.getTaxonIndex(hostTaxon); } intervals[host].addSampleEvent(height); } else { // Tree should be bifurcating... int host1 = setupIntervals(virusTree.getChild(node, 0)); int host2 = setupIntervals(virusTree.getChild(node, 1)); while (height > transmissionTime[host1]) { double time = transmissionTime[host1]; intervals[host1].addNothingEvent(time); host1 = donorHost[host1]; intervals[host1].addSampleEvent(time); } while (height > transmissionTime[host2]) { double time = transmissionTime[host2]; intervals[host2].addNothingEvent(time); host2 = donorHost[host2]; intervals[host2].addSampleEvent(time); } if (host1 != host2) { throw new IncompatibleException("Virus tree is not compatible with transmission history"); } host = host1; intervals[host].addCoalescentEvent(height); } return host; } // ************************************************************** // Units IMPLEMENTATION // ************************************************************** /** * Sets the units these coalescent intervals are * measured in. */ public final void setUnits(Type u) { transmissionModel.setUnits(u); } /** * Returns the units these coalescent intervals are * measured in. */ public final Type getUnits() { return transmissionModel.getUnits(); } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return TRANSMISSION_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { DemographicModel demoModel0 = (DemographicModel) xo.getElementFirstChild(SOURCE_PATIENT); TransmissionDemographicModel demoModel1 = (TransmissionDemographicModel) xo.getChild(TransmissionDemographicModel.class); Tree virusTree = (Tree) xo.getElementFirstChild("parasiteTree"); TransmissionLikelihood likelihood = null; if (xo.getChild(TransmissionHistoryModel.class) != null) { TransmissionHistoryModel history = (TransmissionHistoryModel) xo.getChild(TransmissionHistoryModel.class); try { likelihood = new TransmissionLikelihood(history, virusTree, demoModel0, demoModel1); } catch (TaxonList.MissingTaxonException e) { throw new XMLParseException(e.toString()); } } else { Tree hostTree = (Tree) xo.getElementFirstChild("hostTree"); try { likelihood = new TransmissionLikelihood(hostTree, virusTree, demoModel0, demoModel1); } catch (TaxonList.MissingTaxonException e) { throw new XMLParseException(e.toString()); } } return likelihood; } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents a likelihood function for transmission."; } public Class getReturnType() { return TransmissionLikelihood.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules;{ rules = new XMLSyntaxRule[]{ new ElementRule(SOURCE_PATIENT, DemographicModel.class, "This describes the demographic process for the source donor patient."), new ElementRule(TransmissionDemographicModel.class, "This describes the demographic process for the recipient patients."), new XORRule( new ElementRule("hostTree", new XMLSyntaxRule[]{new ElementRule(Tree.class)}), new ElementRule(TransmissionHistoryModel.class, "This describes the transmission history of the patients.") ), new ElementRule("parasiteTree", new XMLSyntaxRule[]{new ElementRule(Tree.class)}) }; } }; class IncompatibleException extends Exception { /** * */ private static final long serialVersionUID = 8439923064799668934L; public IncompatibleException(String name) { super(name); } } /** * The demographic models. */ private DemographicModel sourceDemographic = null; private TransmissionDemographicModel transmissionModel = null; /** * The host tree. */ private Tree hostTree = null; private TransmissionHistoryModel transmissionHistoryModel = null; /** * The viruses tree. */ private Tree virusTree = null; /** * The number of hosts. */ private int hostCount; /** * The intervals for each host. */ private Intervals[] intervals; /** * The donor host for each recipient host (-1 for initial host). */ private int[] donorHost; /** * The time of transmission into this host (POSITIVE_INFINITY for initial host). */ private double[] transmissionTime; /** * The size of the donor population at time of transmission into recipient host. */ private double[] donorSize; private boolean likelihoodKnown = false; private double logLikelihood; }