/*
 * WithinCaseCoalescent.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.epidemiology.casetocase;

import dr.app.tools.NexusExporter;
import dr.evolution.coalescent.*;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.*;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.*;
import dr.xml.*;

import java.io.IOException;
import java.io.PrintStream;
import java.math.BigDecimal;
import java.util.*;

/**
 * Intended to replace the tree prior; each subtree in the partition is considered a tree in its own right, generated
 * by a coalescent process.
 *
 * @author Matthew Hall
 */
public class WithinCaseCoalescent extends CaseToCaseTreeLikelihood {

    public static final String WITHIN_CASE_COALESCENT = "withinCaseCoalescent";

    private enum Mode {TRUNCATE, NORMAL}

    private double[] partitionTreeLogLikelihoods;
    private double[] storedPartitionTreeLogLikelihoods;
    private boolean[] recalculateCoalescentFlags;
    private DemographicModel demoModel;
    private Mode mode;
    private double coalescencesLogLikelihood;
    private double storedCoalescencesLogLikelihood;
    private boolean pleaseReExplode = true;

    public WithinCaseCoalescent(PartitionedTreeModel virusTree, AbstractOutbreak caseData,
                                Parameter maxFirstInfToRoot, DemographicModel demoModel, Mode mode)
            throws TaxonList.MissingTaxonException {

        super(WITHIN_CASE_COALESCENT, virusTree, caseData, maxFirstInfToRoot);

        this.mode = mode;
        this.demoModel = demoModel;
        addModel(demoModel);
        addModel(outbreak);

        partitionTreeLogLikelihoods = new double[outbreak.getCases().size()];
        storedPartitionTreeLogLikelihoods = new double[outbreak.getCases().size()];
        recalculateCoalescentFlags = new boolean[outbreak.getCases().size()];
        Arrays.fill(recalculateCoalescentFlags, true);

        elementsAsTrees = new HashMap<AbstractCase, Treelet>();

        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                elementsAsTrees.put(aCase, null);
            }
        }

        storedElementsAsTrees = new HashMap<AbstractCase, Treelet>();
    }
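
    /**
     * Sums the coalescent log likelihoods of the within-case subtrees ("treelets"). Per-case values are cached in
     * partitionTreeLogLikelihoods and only recomputed when the corresponding entry of recalculateCoalescentFlags
     * has been set since the last evaluation; treelets with a single tip contribute zero.
     */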
    protected double calculateLogLikelihood() {

        //checkPartitions();

        if (pleaseReExplode) {
            explodeTree();
        }

        double logL = 0;

        coalescencesLogLikelihood = 0;

        for (AbstractCase aCase : outbreak.getCases()) {
            int number = outbreak.getCaseIndex(aCase);

            if (aCase.wasEverInfected()) {

                // and then the little tree calculations

                if (recalculateCoalescentFlags[number]) {
                    Treelet treelet = elementsAsTrees.get(aCase);

                    if (treelet.getExternalNodeCount() > 1) {
                        SpecifiedZeroCoalescent coalescent = new SpecifiedZeroCoalescent(treelet, demoModel,
                                treelet.getZeroHeight(), mode == Mode.TRUNCATE);
                        partitionTreeLogLikelihoods[number] = coalescent.calculateLogLikelihood();
                        coalescencesLogLikelihood += partitionTreeLogLikelihoods[number];
                    } else {
                        partitionTreeLogLikelihoods[number] = 0.0;
                    }
                    recalculateCoalescentFlags[number] = false;
                } else {
                    coalescencesLogLikelihood += partitionTreeLogLikelihoods[number];
                }
            } else {
                recalculateCoalescentFlags[number] = false;
            }
        }

        logL += coalescencesLogLikelihood;

        likelihoodKnown = true;

        return logL;
    }

    public void storeState() {
        super.storeState();
        storedElementsAsTrees = new HashMap<AbstractCase, Treelet>(elementsAsTrees);
        storedPartitionTreeLogLikelihoods = Arrays.copyOf(partitionTreeLogLikelihoods,
                partitionTreeLogLikelihoods.length);
        storedCoalescencesLogLikelihood = coalescencesLogLikelihood;
    }

    public void restoreState() {
        super.restoreState();
        elementsAsTrees = storedElementsAsTrees;
        partitionTreeLogLikelihoods = storedPartitionTreeLogLikelihoods;
        coalescencesLogLikelihood = storedCoalescencesLogLikelihood;
    }

    protected void handleModelChangedEvent(Model model, Object object, int index) {

        super.handleModelChangedEvent(model, object, index);

        if (model == treeModel) {
            if (object instanceof PartitionedTreeModel.PartitionsChangedEvent) {
                HashSet<AbstractCase> changedPartitions =
                        ((PartitionedTreeModel.PartitionsChangedEvent) object).getCasesToRecalculate();
                for (AbstractCase aCase : changedPartitions) {
                    recalculateCaseWCC(aCase);
                }
            }
        } else if (model == getBranchMap()) {
            if (object instanceof ArrayList) {
                for (int i = 0; i < ((ArrayList) object).size(); i++) {
                    BranchMapModel.BranchMapChangedEvent event =
                            (BranchMapModel.BranchMapChangedEvent) ((ArrayList) object).get(i);

                    recalculateCaseWCC(event.getOldCase());
                    recalculateCaseWCC(event.getNewCase());

                    NodeRef node = treeModel.getNode(event.getNodeToRecalculate());
                    NodeRef parent = treeModel.getParent(node);

                    if (parent != null) {
                        recalculateCaseWCC(getBranchMap().get(parent.getNumber()));
                    }
                }
            } else {
                throw new RuntimeException("Unanticipated model changed event from BranchMapModel");
            }
        } else if (model == demoModel) {
            Arrays.fill(recalculateCoalescentFlags, true);
        } else if (model == outbreak) {
            if (object instanceof AbstractCase) {
                AbstractCase thisCase = (AbstractCase) object;
                recalculateCaseWCC(thisCase);

                AbstractCase parent = ((PartitionedTreeModel) treeModel).getInfector(thisCase);
                if (parent != null) {
                    recalculateCaseWCC(parent);
                }
            }
        }
    }

    protected void recalculateCaseWCC(int index) {
        elementsAsTrees.put(outbreak.getCase(index), null);
        pleaseReExplode = true;
        recalculateCoalescentFlags[index] = true;
    }

    protected void recalculateCaseWCC(AbstractCase aCase) {
        if (aCase.wasEverInfected()) {
            recalculateCaseWCC(outbreak.getCaseIndex(aCase));
        }
    }

    public void makeDirty() {
        super.makeDirty();
        Arrays.fill(recalculateCoalescentFlags, true);
        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                elementsAsTrees.put(aCase, null);
            }
        }
        pleaseReExplode = true;
    }

    // Tears the tree into small pieces. Indexes correspond to indexes in the outbreak.
    public ArrayList<AbstractCase> postOrderTransmissionTreeTraversal() {
        return traverseTransmissionTree(getBranchMap().get(treeModel.getRoot().getNumber()));
    }

    private ArrayList<AbstractCase> traverseTransmissionTree(AbstractCase aCase) {
        ArrayList<AbstractCase> out = new ArrayList<AbstractCase>();
        HashSet<AbstractCase> children = ((PartitionedTreeModel) treeModel).getInfectees(aCase);
        for (int i = 0; i < getOutbreak().size(); i++) {
            AbstractCase possibleChild = getOutbreak().getCase(i);
            // easiest way to maintain the set ordering of the outbreak?
            if (children.contains(possibleChild)) {
                out.addAll(traverseTransmissionTree(possibleChild));
            }
        }
        out.add(aCase);
        return out;
    }

    // Log-transforms the node heights of a treelet so that the zero point (the time of infection, at height
    // zeroHeight) is pushed to infinite height.
    private Treelet transformTreelet(Treelet treelet) {

        double[] transformedNodeTimes = new double[treelet.getNodeCount()];

        double totalHeight = treelet.getZeroHeight();

        double willMapToZero = totalHeight - 1;

        for (int i = 0; i < treelet.getNodeCount(); i++) {
            NodeRef node = treelet.getNode(i);
            double time = treelet.getNodeHeight(node) - totalHeight;
            transformedNodeTimes[i] = -Math.log(-(time));
        }

        double first = Double.POSITIVE_INFINITY;
        for (double transformedNodeTime : transformedNodeTimes) {
            if (transformedNodeTime < first) {
                first = transformedNodeTime;
            }
        }

        double zeroHeight = -first;

        Treelet copy = new Treelet(treelet, zeroHeight);

        for (int i = 0; i < copy.getNodeCount(); i++) {
            NodeRef node = copy.getNode(i);
            copy.setNodeHeight(node, transformedNodeTimes[i] - first);
        }

        copy.resolveTree();

        return copy;
    }

    private class SpecifiedZeroCoalescent extends Coalescent {

        private double zeroHeight;
        boolean truncate;

        private SpecifiedZeroCoalescent(Tree tree, DemographicModel demographicModel, double zeroHeight,
                                        boolean truncate) {
            super(tree, demographicModel.getDemographicFunction());
            this.zeroHeight = zeroHeight;
            this.truncate = truncate;
        }

        public double calculateLogLikelihood() {
            return calculatePartitionTreeLogLikelihood(getIntervals(), getDemographicFunction(), 0, zeroHeight,
                    truncate);
        }
    }
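
    /*
     * In outline: the coalescent log likelihood for a single within-case subtree, with time shifted so that t = 0
     * is the host's infection date (the intervals run from startTime = -zeroHeight towards zero). With
     * truncate == true, the contribution of every interval with at least two lineages is divided by the
     * probability that the next coalescence happens before t = 0 (the 1 - exp(-choose(k,2) * normalisationArea)
     * factor), so the coalescent is truncated at the infection time. With truncate == false, no such normalisation
     * is applied and the demographic function itself must give zero population size at t = 0, hence the
     * restriction to LinearGrowth.
     */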
    public static double calculatePartitionTreeLogLikelihood(IntervalList intervals,
                                                             DemographicFunction demographicFunction,
                                                             double threshold, double zeroHeight, boolean truncate) {

        double logL = 0.0;

        double startTime = -zeroHeight;
        final int n = intervals.getIntervalCount();

        //TreeIntervals sets up a first zero-length interval with a lineage count of zero - skip this one

        for (int i = 0; i < n; i++) {

            if (truncate) {
                // time zero corresponds to the date of first infection

                final double duration = intervals.getInterval(i);
                final double finishTime = startTime + duration;

                // if this has happened the run is probably pretty unhappy
                if (finishTime == 0) {
                    return Double.NEGATIVE_INFINITY;
                }

                final double intervalArea = demographicFunction.getIntegral(startTime, finishTime);
                final double normalisationArea = demographicFunction.getIntegral(startTime, 0);

                if (intervalArea == 0 && duration > tolerance) {
                    return Double.NEGATIVE_INFINITY;
                }

                final int lineageCount = intervals.getLineageCount(i);

                if (lineageCount >= 2) {

                    final double kChoose2 = Binomial.choose2(lineageCount);

                    if (intervals.getIntervalType(i) == IntervalType.COALESCENT) {

                        logL += -kChoose2 * intervalArea;

                        final double demographicAtCoalPoint = demographicFunction.getDemographic(finishTime);

                        if (duration == 0.0 || demographicAtCoalPoint * (intervalArea / duration) >= threshold) {
                            logL -= Math.log(demographicAtCoalPoint);
                        } else {
                            return Double.NEGATIVE_INFINITY;
                        }

                    } else {
                        double numerator = Math.exp(-kChoose2 * intervalArea)
                                - Math.exp(-kChoose2 * normalisationArea);
                        logL += Math.log(numerator);
                    }

                    // normalisation
                    double normExp = Math.exp(-kChoose2 * normalisationArea);

                    double logDenominator;

                    // the denominator has an irritating tendency to round to zero

                    if (normExp != 1) {
                        logDenominator = Math.log1p(-normExp);
                    } else {
                        logDenominator = handleDenominatorUnderflow(-kChoose2 * normalisationArea);
                    }

                    logL -= logDenominator;
                }

                startTime = finishTime;

            } else {

                if (!(demographicFunction instanceof LinearGrowth)) {
                    throw new RuntimeException("Function must have zero population at t=0 if truncate=false");
                }

                final double duration = intervals.getInterval(i);
                final double finishTime = startTime + duration;

                final double intervalArea = demographicFunction.getIntegral(startTime, finishTime);

                if (intervalArea == 0 && duration != 0) {
                    return Double.NEGATIVE_INFINITY;
                }

                final int lineageCount = intervals.getLineageCount(i);
                final double kChoose2 = Binomial.choose2(lineageCount);

                // common part
                logL += -kChoose2 * intervalArea;

                if (intervals.getIntervalType(i) == IntervalType.COALESCENT) {

                    final double demographicAtCoalPoint = demographicFunction.getDemographic(finishTime);

                    // if value at end is many orders of magnitude different than mean over interval, reject the
                    // interval. This is protection against cases where a ridiculous infinitesimal population size at
                    // the end of a linear interval drives coalescent values to infinity.

                    if (duration == 0.0 || demographicAtCoalPoint * (intervalArea / duration) >= threshold) {
                        // if( duration == 0.0 || demographicAtCoalPoint >= threshold * (duration/intervalArea) )
                        logL -= Math.log(demographicAtCoalPoint);
                    } else {
                        // remove this at some stage
                        // System.err.println("Warning: " + i + " " + demographicAtCoalPoint + " "
                        //         + (intervalArea/duration));
                        return Double.NEGATIVE_INFINITY;
                    }
                }

                startTime = finishTime;
            }
        }

        return logL;
    }

    private static double handleDenominatorUnderflow(double input) {
        BigDecimal bigDec = new BigDecimal(input);
        BigDecimal expBigDec = BigDecimalUtils.exp(bigDec, bigDec.scale());
        BigDecimal one = new BigDecimal(1.0);
        BigDecimal oneMinusExpBigDec = one.subtract(expBigDec);
        BigDecimal logOneMinusExpBigDec = BigDecimalUtils.ln(oneMinusExpBigDec, oneMinusExpBigDec.scale());
        return logOneMinusExpBigDec.doubleValue();
    }

    public void debugTreelet(Tree treelet, String fileName) {
        try {
            FlexibleTree treeCopy = new FlexibleTree(treelet);
            for (int j = 0; j < treeCopy.getNodeCount(); j++) {
                FlexibleNode node = (FlexibleNode) treeCopy.getNode(j);
                node.setAttribute("Number", node.getNumber());
            }
            NexusExporter testTreesOut = new NexusExporter(new PrintStream(fileName));
            testTreesOut.exportTree(treeCopy);
        } catch (IOException e) {
            System.out.println("IOException: " + e.getMessage());
        }
    }

    public LogColumn[] passColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.passColumns()));

        if (outbreak instanceof CategoryOutbreak) {
            for (int i = 0; i < outbreak.size(); i++) {
                if (outbreak.getCase(i).wasEverInfected()) {
                    final int finalI = i;
                    columns.add(new LogColumn.Abstract("coal_LL_" + i) {
                        protected String getFormattedValue() {
                            return String.valueOf(partitionTreeLogLikelihoods[finalI]);
                        }
                    });
                }
            }

            columns.add(new LogColumn.Abstract("total_coal_LL") {
                protected String getFormattedValue() {
                    return String.valueOf(coalescencesLogLikelihood);
                }
            });

            return columns.toArray(new LogColumn[columns.size()]);
        }

        return null;
    }

    //************************************************************************
    // AbstractXMLObjectParser implementation
    //************************************************************************
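
    /*
     * A minimal sketch of the XML this parser accepts. Only the withinCaseCoalescent element, the truncate
     * attribute and the maxFirstInfToRoot, demographicModel and startingNetwork element names come from the
     * constants below; the element names and idrefs of the nested tree model, outbreak, parameter and demographic
     * model are illustrative assumptions, not taken from a real analysis file.
     *
     *   <withinCaseCoalescent truncate="true">
     *       <partitionedTreeModel idref="treeModel"/>
     *       <categoryOutbreak idref="outbreak"/>
     *       <maxFirstInfToRoot>
     *           <parameter idref="maxFirstInfToRoot"/>
     *       </maxFirstInfToRoot>
     *       <demographicModel>
     *           <constantSize idref="withinCaseDemography"/>
     *       </demographicModel>
     *   </withinCaseCoalescent>
     */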
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public static final String STARTING_NETWORK = "startingNetwork";
        public static final String MAX_FIRST_INF_TO_ROOT = "maxFirstInfToRoot";
        public static final String DEMOGRAPHIC_MODEL = "demographicModel";
        public static final String TRUNCATE = "truncate";

        public String getParserName() {
            return WITHIN_CASE_COALESCENT;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            PartitionedTreeModel virusTree = (PartitionedTreeModel) xo.getChild(TreeModel.class);
            AbstractOutbreak caseSet = (AbstractOutbreak) xo.getChild(AbstractOutbreak.class);
            CaseToCaseTreeLikelihood likelihood;

            Parameter earliestFirstInfection = (Parameter) xo.getElementFirstChild(MAX_FIRST_INF_TO_ROOT);

            DemographicModel demoModel = (DemographicModel) xo.getElementFirstChild(DEMOGRAPHIC_MODEL);

            Mode mode = xo.hasAttribute(TRUNCATE) && xo.getBooleanAttribute(TRUNCATE)
                    ? Mode.TRUNCATE : Mode.NORMAL;

            try {
                likelihood = new WithinCaseCoalescent(virusTree, caseSet, earliestFirstInfection, demoModel, mode);
            } catch (TaxonList.MissingTaxonException e) {
                throw new XMLParseException(e.toString());
            }

            return likelihood;
        }

        public String getParserDescription() {
            return "This element provides a tree prior for a partitioned tree, with each partitioned tree generated " +
                    "by a coalescent process";
        }

        public Class getReturnType() {
            return WithinCaseCoalescent.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(PartitionedTreeModel.class, "The tree"),
                new ElementRule(CategoryOutbreak.class, "The set of cases", 0, 1),
                new ElementRule(STARTING_NETWORK, String.class, "A CSV file containing a specified starting network",
                        true),
                new ElementRule(MAX_FIRST_INF_TO_ROOT, Parameter.class, "The maximum time from the first infection " +
                        "to the root node"),
                new ElementRule(DEMOGRAPHIC_MODEL, DemographicModel.class, "The demographic model for within-case " +
                        "evolution"),
                AttributeRule.newBooleanRule(TRUNCATE, true)
        };
    };
}