/* * CaseToCaseTreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.epidemiology.casetocase; import java.io.IOException; import java.io.PrintStream; import java.util.*; import dr.app.tools.NexusExporter; import dr.evolution.tree.FlexibleNode; import dr.evolution.tree.FlexibleTree; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; import dr.evolution.tree.TreeTraitProvider; import dr.evolution.util.Taxon; import dr.evolution.util.TaxonList; import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution; import dr.evomodel.tree.TreeModel; import dr.oldevomodel.treelikelihood.AbstractTreeLikelihood; import dr.inference.loggers.LogColumn; import dr.inference.loggers.Loggable; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.util.Author; import dr.util.Citable; import dr.util.Citation; import org.apache.commons.math.stat.descriptive.moment.Mean; import org.apache.commons.math.stat.descriptive.moment.Variance; import org.apache.commons.math.stat.descriptive.rank.Median; /** * Handles manipulation of the tree partition, and likelihood of the infection times. * * @author Matthew Hall * @author Andrew Rambaut * @version $Id: $ */ public abstract class CaseToCaseTreeLikelihood extends AbstractTreeLikelihood implements Loggable, Citable, TreeTraitProvider { protected static final boolean DEBUG = false; protected static double tolerance = 1E-10; /* The phylogenetic tree. */ protected int noTips; protected int noCases; /* Mapping of outbreak to branches on the tree; old version is stored before operators are applied */ /* Matches outbreak to external nodes */ private double estimatedLastSampleTime; protected TreeTraitProvider.Helper treeTraits = new Helper(); /** * The set of cases */ protected AbstractOutbreak outbreak; protected double[] infectionTimes; private double[] storedInfectionTimes; protected double[] infectiousPeriods; private double[] storedInfectiousPeriods; protected double[] infectiousTimes; private double[] storedInfectiousTimes; protected double[] latentPeriods; private double[] storedLatentPeriods; protected boolean[] recalculateCaseFlags; protected HashMap<AbstractCase,Treelet> elementsAsTrees; protected HashMap<AbstractCase,Treelet> storedElementsAsTrees; //because of the way the former works, we need a maximum value of the time from first infection to root node. protected Parameter maxFirstInfToRoot; // latent periods protected boolean hasLatentPeriods; // PUBLIC STUFF // Name public static final String CASE_TO_CASE_TREE_LIKELIHOOD = "caseToCaseTreeLikelihood"; public static final String PARTITIONS_KEY = "partition"; // Basic constructor. public CaseToCaseTreeLikelihood(PartitionedTreeModel tree, AbstractOutbreak caseData, Parameter maxFirstInfToRoot) throws TaxonList.MissingTaxonException { this(CASE_TO_CASE_TREE_LIKELIHOOD, tree, caseData, maxFirstInfToRoot); } // Constructor for an instance with a non-default name public CaseToCaseTreeLikelihood(String name, PartitionedTreeModel tree, AbstractOutbreak caseData, Parameter maxFirstInfToRoot) { super(name, caseData, tree); if(stateCount!=treeModel.getExternalNodeCount()){ throw new RuntimeException("There are duplicate tip outbreak."); } noTips = tree.getExternalNodeCount(); //subclasses should add outbreak as a model if it contains any information that ever changes outbreak = caseData; noCases = outbreak.getCases().size(); addModel(outbreak); estimatedLastSampleTime = getLatestTaxonTime(); //map outbreak to tips addModel(tree.getBranchMap()); hasLatentPeriods = outbreak.hasLatentPeriods(); infectionTimes = new double[outbreak.size()]; infectiousPeriods = new double[outbreak.size()]; if(hasLatentPeriods){ infectiousTimes = new double[outbreak.size()]; latentPeriods = new double[outbreak.size()]; } recalculateCaseFlags = new boolean[outbreak.size()]; Arrays.fill(recalculateCaseFlags, true); this.maxFirstInfToRoot = maxFirstInfToRoot; treeTraits.addTrait(PARTITIONS_KEY, new TreeTrait.S() { public String getTraitName() { return PARTITIONS_KEY; } public Intent getIntent() { return Intent.NODE; } public String getTrait(Tree tree, NodeRef node) { return getNodePartition(tree, node); } }); if(DEBUG){ treeTraits.addTrait("NodeNumber", new TreeTrait.S() { public String getTraitName() { return "NodeNumber"; } public Intent getIntent() { return Intent.NODE; } public String getTrait(Tree tree, NodeRef node) { return Integer.toString(node.getNumber()); } }); } likelihoodKnown = false; } public AbstractOutbreak getOutbreak(){ return outbreak; } public boolean hasLatentPeriods(){ return hasLatentPeriods; } /* Get the date of the last tip */ private double getLatestTaxonTime(){ double latestTime = Double.NEGATIVE_INFINITY; for(int i=0; i<treeModel.getExternalNodeCount(); i++){ Taxon taxon = treeModel.getNodeTaxon(treeModel.getExternalNode(i)); if(taxon.getDate().getTimeValue() > latestTime){ latestTime = taxon.getDate().getTimeValue(); } } return latestTime; } private NodeRef[] getChildren(NodeRef node){ NodeRef[] children = new NodeRef[treeModel.getChildCount(node)]; for(int i=0; i< treeModel.getChildCount(node); i++){ children[i] = treeModel.getChild(node,i); } return children; } protected void explodeTree(){ for(int i=0; i<outbreak.size(); i++){ AbstractCase aCase = outbreak.getCase(i); if(aCase.wasEverInfected() && elementsAsTrees.get(aCase)==null){ NodeRef partitionRoot = ((PartitionedTreeModel)treeModel).getEarliestNodeInElement(aCase); double extraHeight; if(treeModel.isRoot(partitionRoot)){ extraHeight = maxFirstInfToRoot.getParameterValue(0) * aCase.getInfectionBranchPosition().getParameterValue(0); } else { extraHeight = treeModel.getBranchLength(partitionRoot) * aCase.getInfectionBranchPosition().getParameterValue(0); } FlexibleNode newRoot = new FlexibleNode(); FlexibleTree littleTree = new FlexibleTree(newRoot); littleTree.beginTreeEdit(); if (!treeModel.isExternal(partitionRoot)) { for (int j = 0; j < treeModel.getChildCount(partitionRoot); j++) { copyElementToTreelet(littleTree, treeModel.getChild(partitionRoot, j), newRoot, aCase); } } littleTree.endTreeEdit(); littleTree.resolveTree(); Treelet treelet = new Treelet(littleTree, littleTree.getRootHeight() + extraHeight); elementsAsTrees.put(aCase, treelet); } } } private void copyElementToTreelet(FlexibleTree littleTree, NodeRef oldNode, NodeRef newParent, AbstractCase element){ if(element.wasEverInfected()) { if (getBranchMap().get(oldNode.getNumber()) == element) { if (treeModel.isExternal(oldNode)) { NodeRef newTip = new FlexibleNode(new Taxon(treeModel.getNodeTaxon(oldNode).getId())); littleTree.addChild(newParent, newTip); littleTree.setBranchLength(newTip, treeModel.getBranchLength(oldNode)); } else { NodeRef newChild = new FlexibleNode(); littleTree.addChild(newParent, newChild); littleTree.setBranchLength(newChild, treeModel.getBranchLength(oldNode)); for (int i = 0; i < treeModel.getChildCount(oldNode); i++) { copyElementToTreelet(littleTree, treeModel.getChild(oldNode, i), newChild, element); } } } else { // we need a new tip NodeRef transmissionTip = new FlexibleNode( new Taxon("Transmission_" + getBranchMap().get(oldNode.getNumber()).getName())); double parentTime = getNodeTime(treeModel.getParent(oldNode)); double childTime = getInfectionTime(getBranchMap().get(oldNode.getNumber())); littleTree.addChild(newParent, transmissionTip); littleTree.setBranchLength(transmissionTip, childTime - parentTime); } } } protected class Treelet extends FlexibleTree { private double zeroHeight; protected Treelet(FlexibleTree tree, double zeroHeight){ super(tree); this.zeroHeight = zeroHeight; } protected double getZeroHeight(){ return zeroHeight; } protected void setZeroHeight(double rootBranchLength){ this.zeroHeight = zeroHeight; } } // find all partitions of the descendant tips of the current node. If map is specified then it makes a map of node // number to possible partitions; map can be null. public HashSet<AbstractCase> descendantTipPartitions(NodeRef node, HashMap<Integer, HashSet<AbstractCase>> map){ HashSet<AbstractCase> out = new HashSet<AbstractCase>(); if(treeModel.isExternal(node)){ out.add(getBranchMap().get(node.getNumber())); if(map!=null){ map.put(node.getNumber(), out); } return out; } else { for(int i=0; i< treeModel.getChildCount(node); i++){ out.addAll(descendantTipPartitions(treeModel.getChild(node, i), map)); } if(map!=null){ map.put(node.getNumber(), out); } return out; } } // change flags to indicate that something needs recalculation further down the tree protected static void flagForDescendantRecalculation(TreeModel tree, NodeRef node, boolean[] flags){ flags[node.getNumber()]=true; for(int i=0; i<tree.getChildCount(node); i++){ flags[tree.getChild(node,i).getNumber()]=true; } NodeRef currentNode=node; while(!tree.isRoot(currentNode) && !flags[currentNode.getNumber()]){ currentNode = tree.getParent(currentNode); flags[currentNode.getNumber()]=true; } } public void flagForDescendantRecalculation(TreeModel tree, NodeRef node){ flagForDescendantRecalculation(tree, node, updateNode); } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** protected void handleModelChangedEvent(Model model, Object object, int index) { if(!(model instanceof AbstractPeriodPriorDistribution)){ if (model == treeModel) { if(object instanceof PartitionedTreeModel.PartitionsChangedEvent){ HashSet<AbstractCase> changedPartitions = ((PartitionedTreeModel.PartitionsChangedEvent)object).getCasesToRecalculate(); for(AbstractCase aCase : changedPartitions){ recalculateCase(aCase); } } } else if (model == getBranchMap()){ if(object instanceof ArrayList){ for(int i=0; i<((ArrayList) object).size(); i++){ BranchMapModel.BranchMapChangedEvent event = (BranchMapModel.BranchMapChangedEvent)((ArrayList) object).get(i); recalculateCase(event.getOldCase()); recalculateCase(event.getNewCase()); NodeRef node = treeModel.getNode(event.getNodeToRecalculate()); NodeRef parent = treeModel.getParent(node); if(parent!=null){ recalculateCase(getBranchMap().get(parent.getNumber())); } } } else { throw new RuntimeException("Unanticipated model changed event from BranchMapModel"); } } else if (model == outbreak){ if(object instanceof AbstractCase){ recalculateCase((AbstractCase)object); } else { for (AbstractCase aCase : outbreak.getCases()) { recalculateCase(aCase); } } } fireModelChanged(model); likelihoodKnown = false; } } protected void recalculateCase(int index){ recalculateCaseFlags[index] = true; } protected void recalculateCase(AbstractCase aCase){ if(aCase.wasEverInfected()) { recalculateCase(outbreak.getCaseIndex(aCase)); } } // ************************************************************** // VariableListener IMPLEMENTATION // ************************************************************** protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { fireModelChanged(); likelihoodKnown = false; } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the precalculated state (in this case the node labels and subtree likelihoods) */ protected void storeState() { super.storeState(); storedInfectionTimes = Arrays.copyOf(infectionTimes, infectionTimes.length); storedInfectiousPeriods = Arrays.copyOf(infectiousPeriods, infectiousPeriods.length); if(hasLatentPeriods){ storedInfectiousTimes = Arrays.copyOf(infectiousTimes, infectionTimes.length); storedLatentPeriods = Arrays.copyOf(latentPeriods, latentPeriods.length); } } /** * Restores the precalculated state. */ protected void restoreState() { super.restoreState(); infectionTimes = storedInfectionTimes; infectiousPeriods = storedInfectiousPeriods; if(hasLatentPeriods){ infectiousTimes = storedInfectiousTimes; latentPeriods = storedLatentPeriods; } } protected final void acceptState() { } // nothing to do // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** public final BranchMapModel getBranchMap(){ return ((PartitionedTreeModel)treeModel).getBranchMap(); } public final PartitionedTreeModel getTreeModel(){ return (PartitionedTreeModel)treeModel; } public void makeDirty() { likelihoodKnown = false; Arrays.fill(recalculateCaseFlags, true); } protected void prepareTimings(){ infectionTimes = getInfectionTimes(true); if(hasLatentPeriods){ infectiousTimes = getInfectiousTimes(true); } infectiousPeriods = getInfectiousPeriods(true); if(hasLatentPeriods){ latentPeriods = getLatentPeriods(true); } Arrays.fill(recalculateCaseFlags, false); } /** * Calculates the log likelihood of this set of node labels given the tree. */ protected abstract double calculateLogLikelihood(); // if no infectious models, just need to check whether any infections occur after the infector was no // longer infectious protected boolean isAllowed(){ return isAllowed(treeModel.getRoot()); } private boolean isAllowed(NodeRef node){ if(!treeModel.isRoot(node)){ AbstractCase childCase = getBranchMap().get(node.getNumber()); AbstractCase parentCase = getBranchMap().get(treeModel.getParent(node).getNumber()); if(childCase!=parentCase){ double infectionTime = infectionTimes[outbreak.getCaseIndex(childCase)]; if(infectionTime>parentCase.getEndTime() || (hasLatentPeriods && infectionTime<infectiousTimes[outbreak.getCaseIndex(parentCase)])){ return false; } } } return treeModel.isExternal(node) || (isAllowed(treeModel.getChild(node, 0)) && isAllowed(treeModel.getChild(node, 1))); } /* Return the double time at which the given node occurred */ public double getNodeTime(NodeRef node){ double nodeHeight = getHeight(node); return estimatedLastSampleTime-nodeHeight; } public double heightToTime(double height){ return estimatedLastSampleTime-height; } public double timeToHeight(double time){ return estimatedLastSampleTime-time; } private double getHeight(NodeRef node){ return treeModel.getNodeHeight(node); } public double getInfectionTime(AbstractCase thisCase){ if(!recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]){ return infectionTimes[outbreak.getCaseIndex(thisCase)]; } else { if(thisCase.wasEverInfected()) { NodeRef child = ((PartitionedTreeModel)treeModel).getEarliestNodeInElement(thisCase); NodeRef parent = treeModel.getParent(child); if (parent != null) { double min = heightToTime(treeModel.getNodeHeight(parent)); // Let the likelihood evaluate to zero due to culling dates if it must... double max = heightToTime(treeModel.getNodeHeight(child)); return getInfectionTime(min, max, thisCase); } else { return getRootInfectionTime(getBranchMap()); } } else { return Double.POSITIVE_INFINITY; } } } private double getInfectionTime(double min, double max, AbstractCase infected){ final double branchLength = max-min; return min + branchLength*(1-infected.getInfectionBranchPosition().getParameterValue(0)); } public double[] getInfectionTimes(boolean recalculate){ if(recalculate) { for(int i=0; i<noCases; i++){ if(recalculateCaseFlags[i]){ infectionTimes[i] = getInfectionTime(outbreak.getCase(i)); } } } return infectionTimes; } public void setInfectionTime(AbstractCase thisCase, double time){ setInfectionHeight(thisCase, timeToHeight(time)); } public void setInfectionHeight(AbstractCase thisCase, double height){ if(thisCase.wasEverInfected()) { NodeRef child = ((PartitionedTreeModel)treeModel).getEarliestNodeInElement(thisCase); NodeRef parent = treeModel.getParent(child); double minHeight = treeModel.getNodeHeight(child); double maxHeight = parent != null ? treeModel.getNodeHeight(parent) : minHeight + maxFirstInfToRoot.getParameterValue(0); if (height < minHeight || height > maxHeight) { throw new RuntimeException("Trying to set an infection time outside the branch on which it must occur"); } double branchPosition = (height - minHeight) / (maxHeight - minHeight); thisCase.setInfectionBranchPosition(branchPosition); } } public double getInfectiousTime(AbstractCase thisCase){ if(!hasLatentPeriods){ return getInfectionTime(thisCase); } else { if (recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]) { if(thisCase.wasEverInfected()) { String latentCategory = ((CategoryOutbreak) outbreak).getLatentCategory(thisCase); Parameter latentPeriod = ((CategoryOutbreak) outbreak).getLatentPeriod(latentCategory); infectiousTimes[outbreak.getCaseIndex(thisCase)] = getInfectionTime(thisCase) + latentPeriod.getParameterValue(0); } else { infectiousTimes[outbreak.getCaseIndex(thisCase)] = Double.POSITIVE_INFINITY; } } return infectiousTimes[outbreak.getCaseIndex(thisCase)]; } } public double[] getInfectiousTimes(boolean recalculate){ if(recalculate){ for(int i=0; i<noCases; i++){ if(recalculateCaseFlags[i]){ infectiousTimes[i] = getInfectiousTime(outbreak.getCase(i)); } } } return infectiousTimes; } public double getInfectiousPeriod(AbstractCase thisCase){ if(recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]){ if(thisCase.wasEverInfected()) { if (!hasLatentPeriods) { double infectionTime = getInfectionTime(thisCase); double cullTime = thisCase.getEndTime(); infectiousPeriods[outbreak.getCaseIndex(thisCase)] = cullTime - infectionTime; } else { double infectiousTime = getInfectiousTime(thisCase); double cullTime = thisCase.getEndTime(); infectiousPeriods[outbreak.getCaseIndex(thisCase)] = cullTime - infectiousTime; } } else { infectiousPeriods[outbreak.getCaseIndex(thisCase)] = 0; } } return infectiousPeriods[outbreak.getCaseIndex(thisCase)]; } public double[] getInfectiousPeriods(boolean recalculate){ if(recalculate){ for(int i=0; i<noCases; i++){ if(recalculateCaseFlags[i]){ infectiousPeriods[i] = getInfectiousPeriod(outbreak.getCase(i)); } } } return infectiousPeriods; } public Double[] getNonzeroInfectiousPeriods(){ ArrayList<Double> out = new ArrayList<Double>(); for(int i=0; i<noCases; i++){ AbstractCase thisCase = outbreak.getCase(i); if(thisCase.wasEverInfected()){ out.add(getInfectiousPeriod(thisCase)); } } return out.toArray(new Double[out.size()]); } public double getLatentPeriod(AbstractCase thisCase){ if(!hasLatentPeriods || !thisCase.wasEverInfected()){ return 0.0; } if(recalculateCaseFlags[outbreak.getCaseIndex(thisCase)]){ latentPeriods[outbreak.getCaseIndex(thisCase)] = getInfectiousTime(thisCase) - getInfectionTime(thisCase); } return latentPeriods[outbreak.getCaseIndex(thisCase)]; } public double[] getLatentPeriods(boolean recalculate){ if(recalculate){ for(int i=0; i<noCases; i++){ if(recalculateCaseFlags[i]){ latentPeriods[i] = getLatentPeriod(outbreak.getCase(i)); } } } return latentPeriods; } public Double[] getNonzeroLatentPeriods(){ ArrayList<Double> out = new ArrayList<Double>(); for(int i=0; i<noCases; i++){ AbstractCase thisCase = outbreak.getCase(i); if(thisCase.wasEverInfected()){ out.add(getLatentPeriod(thisCase)); } } return out.toArray(new Double[out.size()]); } public double[] getInfectedPeriods(boolean recalculate){ if(!hasLatentPeriods){ return getInfectiousPeriods(recalculate); } else { double[] out = new double[noCases]; for(int i=0; i<noCases; i++){ out[i] = getInfectedPeriod(outbreak.getCase(i)); } return out; } } public Double[] getNonzeroInfectedPeriods(){ ArrayList<Double> out = new ArrayList<Double>(); for(int i=0; i<noCases; i++){ AbstractCase thisCase = outbreak.getCase(i); if(thisCase.wasEverInfected()){ out.add(getInfectedPeriod(thisCase)); } } return out.toArray(new Double[out.size()]); } public double getInfectedPeriod(AbstractCase thisCase){ if(thisCase.wasEverInfected) { return thisCase.getEndTime() - getInfectionTime(thisCase); } return 0; } // return an array of the mean, median, variance and standard deviation of the given array // @todo this is pretty wasteful since it gets called so many times per log entry public static Double[] getSummaryStatistics(Double[] variable){ double[] primitiveVariable = new double[variable.length]; for(int i=0; i<variable.length; i++){ primitiveVariable[i] = variable[i]; } Double[] out = new Double[4]; out[0] = (new Mean()).evaluate(primitiveVariable); out[1] = (new Median()).evaluate(primitiveVariable); out[2] = (new Variance()).evaluate(primitiveVariable); out[3] = Math.sqrt(out[2]); return out; } private double getRootInfectionTime(BranchMapModel branchMap){ NodeRef root = treeModel.getRoot(); AbstractCase rootCase = branchMap.get(root.getNumber()); final double branchLength = maxFirstInfToRoot.getParameterValue(0); return heightToTime(treeModel.getNodeHeight(root) + branchLength * rootCase.getInfectionBranchPosition().getParameterValue(0)); } protected double getRootInfectionTime(){ AbstractCase rootCase = getBranchMap().get(treeModel.getRoot().getNumber()); return getInfectionTime(rootCase); } public void outputTreeToFile(String fileName, boolean includeTransmissionNodes){ outputTreeToFile(getBranchMap(), fileName, includeTransmissionNodes); } public void outputTreeToFile(BranchMapModel map, String fileName, boolean includeTransmissionNodes){ try{ FlexibleTree treeCopy; if(!includeTransmissionNodes){ treeCopy = new FlexibleTree(treeModel); for(int j=0; j<treeCopy.getNodeCount(); j++){ FlexibleNode node = (FlexibleNode)treeCopy.getNode(j); node.setAttribute("Number", node.getNumber()); node.setAttribute("Time", heightToTime(node.getHeight())); node.setAttribute(PARTITIONS_KEY, map.get(node.getNumber())); } } else { treeCopy = addTransmissionNodes(treeModel); } NexusExporter testTreesOut = new NexusExporter(new PrintStream(fileName)); testTreesOut.exportTree(treeCopy); } catch (IOException ignored) {System.out.println("IOException");} } public FlexibleTree addTransmissionNodes(Tree tree){ prepareTimings(); FlexibleTree outTree = new FlexibleTree(tree, true); for(int j=0; j<outTree.getNodeCount(); j++){ FlexibleNode node = (FlexibleNode)outTree.getNode(j); node.setAttribute("Number", node.getNumber()); node.setAttribute("Time", heightToTime(node.getHeight())); node.setAttribute(PARTITIONS_KEY, getBranchMap().get(node.getNumber())); } for(AbstractCase aCase : outbreak.getCases()){ if(aCase.wasEverInfected()) { NodeRef originalNode = ((PartitionedTreeModel)treeModel).getEarliestNodeInElement(aCase); int infectionNodeNo = originalNode.getNumber(); if (!treeModel.isRoot(originalNode)) { NodeRef originalParent = treeModel.getParent(originalNode); double nodeTime = getNodeTime(originalNode); double infectionTime = getInfectionTime(aCase); double heightToBreakBranch = getHeight(originalNode) + (nodeTime - infectionTime); FlexibleNode newNode = (FlexibleNode) outTree.getNode(infectionNodeNo); FlexibleNode oldParent = (FlexibleNode) outTree.getParent(newNode); outTree.beginTreeEdit(); outTree.removeChild(oldParent, newNode); FlexibleNode infectionNode = new FlexibleNode(); infectionNode.setHeight(heightToBreakBranch); infectionNode.setLength(oldParent.getHeight() - heightToBreakBranch); infectionNode.setAttribute(PARTITIONS_KEY, getNodePartition(treeModel, originalParent)); infectionNode.setAttribute("Time", heightToTime(heightToBreakBranch)); newNode.setLength(nodeTime - infectionTime); outTree.addChild(oldParent, infectionNode); outTree.addChild(infectionNode, newNode); outTree.endTreeEdit(); } else { double nodeTime = getNodeTime(originalNode); double infectionTime = getInfectionTime(aCase); double heightToInstallRoot = getHeight(originalNode) + (nodeTime - infectionTime); FlexibleNode newNode = (FlexibleNode) outTree.getNode(infectionNodeNo); outTree.beginTreeEdit(); FlexibleNode infectionNode = new FlexibleNode(); infectionNode.setHeight(heightToInstallRoot); infectionNode.setAttribute("Time", heightToTime(heightToInstallRoot)); infectionNode.setAttribute(PARTITIONS_KEY, "Origin"); outTree.addChild(infectionNode, newNode); newNode.setLength(heightToInstallRoot - getHeight(originalNode)); outTree.setRoot(infectionNode); outTree.endTreeEdit(); } } } outTree = new FlexibleTree((FlexibleNode)outTree.getRoot()); for(int i=0; i<outTree.getNodeCount(); i++){ NodeRef node = outTree.getNode(i); NodeRef parent = outTree.getParent(node); if(parent!=null && outTree.getNodeHeight(node)>outTree.getNodeHeight(parent)){ try{ NexusExporter exporter = new NexusExporter(new PrintStream("fancyProblem.nex")); exporter.exportTree(outTree); } catch(IOException e){ e.printStackTrace(); } try{ ((PartitionedTreeModel)treeModel).checkPartitions(); } catch(BadPartitionException e){ System.out.print("Rewiring messed up because of partition problem."); } throw new RuntimeException("Rewiring messed up; investigate"); } } return outTree; } //************************************************************************ // Loggable implementation //************************************************************************ public LogColumn[] getColumns(){ LogColumn[] columns = new LogColumn[outbreak.infectedSize()]; int count = 0; for(int i=0; i<outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns[count] = new LogColumn.Abstract(infected.toString() + "_infector") { protected String getFormattedValue() { if (((PartitionedTreeModel)treeModel).getInfector(infected) == null) { return "Start"; } else { return ((PartitionedTreeModel)treeModel).getInfector(infected).toString(); } } }; count++; } } return columns; } public LogColumn[] passColumns(){ ArrayList<LogColumn> columns = new ArrayList<LogColumn>(); for(int i=0; i< outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infection_date") { protected String getFormattedValue() { return String.valueOf(getInfectionTime(infected)); } }); } } if(hasLatentPeriods){ for(int i=0; i< outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infectious_date") { protected String getFormattedValue() { return String.valueOf(getInfectiousTime(infected)); } }); } } for(int i=0; i< outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_latent_period") { protected String getFormattedValue() { return String.valueOf(getLatentPeriod(infected)); } }); } } } for(int i=0; i< outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infectious_period") { protected String getFormattedValue() { return String.valueOf(getInfectiousPeriod(infected)); } }); } } if(hasLatentPeriods){ for(int i=0; i< outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_infected_period") { protected String getFormattedValue() { return String.valueOf( getInfectiousPeriod(infected) + getLatentPeriod(infected)); } }); } } } columns.add(new LogColumn.Abstract("infectious_period.mean"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectiousPeriods())[0]); } }); columns.add(new LogColumn.Abstract("infectious_period.median"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectiousPeriods())[1]); } }); columns.add(new LogColumn.Abstract("infectious_period.var") { protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectiousPeriods())[2]); } }); columns.add(new LogColumn.Abstract("infectious_period.stdev"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectiousPeriods())[3]); } }); if(hasLatentPeriods){ columns.add(new LogColumn.Abstract("latent_period.mean"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroLatentPeriods())[0]); } }); columns.add(new LogColumn.Abstract("latent_period.median"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroLatentPeriods())[1]); } }); columns.add(new LogColumn.Abstract("latent_period.var") { protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroLatentPeriods())[2]); } }); columns.add(new LogColumn.Abstract("latent_period.stdev"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroLatentPeriods())[3]); } }); columns.add(new LogColumn.Abstract("infected_period.mean"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectedPeriods())[0]); } }); columns.add(new LogColumn.Abstract("infected_period.median"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectedPeriods())[1]); } }); columns.add(new LogColumn.Abstract("infected_period.var") { protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectedPeriods())[2]); } }); columns.add(new LogColumn.Abstract("infected_period.stdev"){ protected String getFormattedValue() { return String.valueOf(CaseToCaseTreeLikelihood .getSummaryStatistics(getNonzeroInfectedPeriods())[3]); } }); for(int i=0; i< outbreak.size(); i++){ final AbstractCase infected = outbreak.getCase(i); if(infected.wasEverInfected()) { columns.add(new LogColumn.Abstract(infected.toString() + "_ibp") { protected String getFormattedValue() { return String.valueOf(infected.getInfectionBranchPosition().getParameterValue(0)); } }); } } } return columns.toArray(new LogColumn[columns.size()]); } @Override public Citation.Category getCategory() { return Citation.Category.TREE_PRIORS; } @Override public String getDescription() { return "Case to Case Transmission Tree model"; } public List<Citation> getCitations() { return Arrays.asList(new Citation( new Author[]{new Author("M", "Hall"), new Author("M", "Woolhouse"), new Author("A", "Rambaut")}, "Epidemic Reconstruction in a Phylogenetics Framework: Transmission Trees as Partitions of the Node Set", 2016, "PLOS Comput Biol", 11, 0, 0, "10.1371/journal.pcbi.1004613", Citation.Status.PUBLISHED)); } // ************************************************************** // TreeTraitProvider IMPLEMENTATION // ************************************************************** public TreeTrait[] getTreeTraits() { return treeTraits.getTreeTraits(); } public TreeTrait getTreeTrait(String key) { return treeTraits.getTreeTrait(key); } public String getNodePartition(Tree tree, NodeRef node) { if (tree != treeModel) { // we're trying to annotate a partitioned tree, we hope try{ NodeRef oldNode = treeModel.getNode((Integer)tree.getNodeAttribute(node,"Number")); if(treeModel.getNodeHeight(oldNode)!=tree.getNodeHeight(node)){ throw new RuntimeException("Can only reconstruct states on treeModel given to constructor or a " + "partitioned tree derived from it"); } else { return getBranchMap().get(oldNode.getNumber()).toString(); } } catch(NullPointerException e){ if(tree.isRoot(node)){ return "Start"; } else { NodeRef parent = tree.getParent(node); int originalParentNumber = (Integer)tree.getNodeAttribute(parent,"Number"); return getBranchMap().get(originalParentNumber).toString(); } } } else { return getBranchMap().get(node.getNumber()).toString(); } } public Integer[] getParentsArray(){ Integer[] out = new Integer[outbreak.size()]; for(AbstractCase thisCase : outbreak.getCases()){ if(thisCase.wasEverInfected()) { out[outbreak.getCaseIndex(thisCase)] = outbreak.getCaseIndex(((PartitionedTreeModel)treeModel).getInfector(thisCase)); } else { out[outbreak.getCaseIndex(thisCase)] = null; } } return out; } public AbstractCase getInfector(int i){ return ((PartitionedTreeModel)treeModel).getInfector(getOutbreak().getCase(i)); } }