/*
* PartitionedTreeModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.epidemiology.casetocase;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.*;
import dr.math.MathUtils;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
/**
* TreeModel plus partition information
*
* todo a lot of methods should eventually move here
*/
public class PartitionedTreeModel extends TreeModel {
// Outbreak whose cases are the partition elements that tree nodes get assigned to.
private final AbstractOutbreak outbreak;
// Lookup from node number to the case (partition element) currently assigned to that node.
private BranchMapModel branchMap;
// Set once at construction from outbreak.infectedSize().
private final int elementCount;
// Default model id used by the constructors that are not given an explicit id.
public final static String PARTITIONED_TREE_MODEL = "partitionedTreeModel";
// Nodes whose partition-changed notifications were deferred because a tree edit was in progress;
// emptied by flushQueue().
Set<NodeRef> partitionsQueue = new HashSet<NodeRef>();
/**
 * Builds a partitioned tree model and gives its internal nodes a randomly generated starting partition.
 *
 * @param id       identifier handed to the TreeModel constructor
 * @param tree     tree handed to the TreeModel constructor
 * @param outbreak outbreak supplying the cases among which the nodes are partitioned
 */
public PartitionedTreeModel(String id, Tree tree, AbstractOutbreak outbreak) {
    super(id, tree);
    this.outbreak = outbreak;
    this.elementCount = outbreak.infectedSize();
    this.branchMap = new BranchMapModel(this);
    // the initial random partition is generated with creep disabled
    partitionAccordingToRandomTT(false);
}
/**
 * Builds a partitioned tree model whose starting partition is read from a transmission tree file.
 *
 * @param id                 identifier handed to the TreeModel constructor
 * @param tree               tree handed to the TreeModel constructor
 * @param outbreak           outbreak supplying the cases among which the nodes are partitioned
 * @param startingTTFileName path of the CSV file describing the starting transmission tree
 */
public PartitionedTreeModel(String id, Tree tree, AbstractOutbreak outbreak, String startingTTFileName) {
    super(id, tree);
    this.outbreak = outbreak;
    this.elementCount = outbreak.infectedSize();
    this.branchMap = new BranchMapModel(this);
    partitionAccordingToSpecificTT(startingTTFileName);
}
/**
 * Convenience constructor: wraps an existing tree model under the default id
 * {@link #PARTITIONED_TREE_MODEL} and generates a random starting partition.
 *
 * @param treeModel tree to copy from
 * @param outbreak  outbreak supplying the cases
 */
public PartitionedTreeModel(TreeModel treeModel, AbstractOutbreak outbreak) {
    this(PARTITIONED_TREE_MODEL, treeModel, outbreak);
}
/**
 * Convenience constructor: wraps an existing tree model under the default id
 * {@link #PARTITIONED_TREE_MODEL} and reads the starting partition from a file.
 *
 * @param treeModel          tree to copy from
 * @param outbreak           outbreak supplying the cases
 * @param startingTTFileName path of the CSV file describing the starting transmission tree
 */
public PartitionedTreeModel(TreeModel treeModel, AbstractOutbreak outbreak, String startingTTFileName) {
    this(PARTITIONED_TREE_MODEL, treeModel, outbreak, startingTTFileName);
}
/**
 * Fires a model-changed event on this model carrying a {@link PartitionsChangedEvent} for the given cases.
 *
 * @param casesToRecalculate cases that listeners should recalculate
 */
public void partitionsChangingAlert(HashSet<AbstractCase> casesToRecalculate) {
    // TreeLikelihood and TreeParameter listeners are irrelevant
    final PartitionsChangedEvent event = new PartitionsChangedEvent(casesToRecalculate);
    listenerHelper.fireModelChanged(this, event);
}
/**
 * Single-case form of {@link #partitionsChangingAlert(HashSet)}: wraps the case in a one-element set
 * and forwards it.
 *
 * @param caseToRecalculate the case that listeners should recalculate
 */
public void partitionChangingAlert(AbstractCase caseToRecalculate) {
    final HashSet<AbstractCase> singleCase = new HashSet<AbstractCase>();
    singleCase.add(caseToRecalculate);
    partitionsChangingAlert(singleCase);
}
/**
 * Fires a partition-changed alert covering every case currently present in the branch map.
 */
public void universalAlert() {
    final AbstractCase[] assignments = branchMap.getArrayCopy();
    final HashSet<AbstractCase> everyCase = new HashSet<AbstractCase>(Arrays.asList(assignments));
    partitionsChangingAlert(everyCase);
}
/**
 * @return the branch map held by this model (the live object, not a copy)
 */
public BranchMapModel getBranchMap() {
    return this.branchMap;
}
/**
 * Intentionally a no-op: this model does not react to model-changed events.
 */
protected void handleModelChangedEvent(Model model, Object object, int index) {
    // shouldn't be any
}
/**
 * Event object handed to model listeners by {@code partitionsChangingAlert}; it carries the set of
 * cases that need to be recalculated.
 */
public class PartitionsChangedEvent {

    private final HashSet<AbstractCase> casesToRecalculate;

    /**
     * @param casesToRecalculate the affected cases; the set is stored as given, not copied
     */
    public PartitionsChangedEvent(HashSet<AbstractCase> casesToRecalculate) {
        this.casesToRecalculate = casesToRecalculate;
    }

    /**
     * @return the set of affected cases that was supplied at construction
     */
    public HashSet<AbstractCase> getCasesToRecalculate() {
        return this.casesToRecalculate;
    }
}
/**
 * Reports that the partition element of a node is changing. While a tree edit is in progress the node
 * is queued; otherwise the alert for the node's case is fired immediately.
 *
 * @param node the node concerned
 */
public void pushNodePartitionsChangedEvent(NodeRef node) {
    final int number = node.getNumber();
    if (inTreeEdit()) {
        // defer until the edit has finished; see flushQueue()
        partitionsQueue.add(node);
        return;
    }
    partitionChangingAlert(branchMap.get(number));
}
/**
 * Lets the superclass process the change and then alerts partition listeners: for every case when all
 * values changed, otherwise only for the cases adjacent to the node that owns the parameter.
 */
public void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
    super.handleVariableChangedEvent(variable, index, type);
    if (type != Parameter.ChangeType.ALL_VALUES_CHANGED) {
        final NodeRef owner = getNodeOfParameter((Parameter) variable);
        partitionsChangingAlert(adjacentElements(owner));
        return;
    }
    // this signals events where the values in all dimensions of a parameter are changed
    universalAlert();
}
/**
 * Collects the cases assigned to a node, its parent and its first two children.
 * Missing neighbours (null references) are skipped.
 *
 * @param node the node at the centre of the neighbourhood
 * @return the set of cases assigned to the node and its existing neighbours
 */
public HashSet<AbstractCase> adjacentElements(NodeRef node) {
    final NodeRef[] neighbourhood = {node, getParent(node), getChild(node, 0), getChild(node, 1)};
    final HashSet<AbstractCase> touched = new HashSet<AbstractCase>();
    for (NodeRef candidate : neighbourhood) {
        if (candidate == null) {
            continue;
        }
        touched.add(branchMap.get(candidate.getNumber()));
    }
    return touched;
}
/**
 * Fires the partition-changed alerts that were deferred during a tree edit and empties the queue.
 * For each queued node the alert is fired for the node's case and, when the parent belongs to a
 * different case, for the parent's case as well.
 *
 * @throws RuntimeException if called while a tree edit is still in progress
 */
private void flushQueue() {
    if (inTreeEdit()) {
        // the two literals are concatenated, so the first must end with a space
        throw new RuntimeException("Wait until you've finished editing the tree before flushing the partition " +
                "queue");
    }
    for (NodeRef node : partitionsQueue) {
        AbstractCase nodeElement = branchMap.get(node.getNumber());
        partitionChangingAlert(nodeElement);
        NodeRef parent = getParent(node);
        if (parent != null && branchMap.get(node.getNumber()) != branchMap.get(parent.getNumber())) {
            partitionChangingAlert(branchMap.get(parent.getNumber()));
        }
    }
    partitionsQueue.clear();
}
/**
 * Registers a partition change for both nodes, then delegates the actual edit to the superclass.
 *
 * @param p parent node
 * @param c child node being attached
 */
public void addChild(NodeRef p, NodeRef c) {
    // notifications go out (or are queued) before the structure changes
    pushNodePartitionsChangedEvent(p);
    pushNodePartitionsChangedEvent(c);

    super.addChild(p, c);
}
/**
 * Registers a partition change for both nodes, then delegates the actual edit to the superclass.
 *
 * @param p parent node
 * @param c child node being detached
 */
public void removeChild(NodeRef p, NodeRef c) {
    // notifications go out (or are queued) before the structure changes
    pushNodePartitionsChangedEvent(p);
    pushNodePartitionsChangedEvent(c);

    super.removeChild(p, c);
}
/**
 * Alerts listeners about the cases adjacent to the node, then sets the height via the superclass.
 *
 * @param n      node whose height is changing
 * @param height new height
 */
public void setNodeHeight(NodeRef n, double height) {
    final HashSet<AbstractCase> affected = adjacentElements(n);
    partitionsChangingAlert(affected);
    super.setNodeHeight(n, height);
}
/**
 * Ends the tree edit through the superclass and then fires the queued partition notifications.
 * Any changes to the partitions must be complete before this is called.
 */
public void endTreeEdit() {
    super.endTreeEdit();

    // todo in the end, want to check the tree partitions are sane here before flushing the queue
    flushQueue();
}
/**
 * Checks this model's own branch map, printing a message for the first problem found.
 *
 * @return true if no problem was found
 */
public boolean checkPartitions() {
    final boolean verbose = true;
    return checkPartitions(branchMap, verbose);
}
/**
 * Checks that every internal node's partition element contains at least one tip, and that for every
 * tip the MRCA of its case's tips is assigned to that same case.
 *
 * Note that the {@code map} argument is not consulted: all lookups go through this model's own
 * branch map.
 *
 * @param map     unused by the current implementation
 * @param verbose if true, prints a message for the first internal node found without a tip
 * @return true if every internal node's element contains a tip
 * @throws BadPartitionException if a case's tips have an MRCA assigned to a different case
 */
protected boolean checkPartitions(BranchMapModel map, boolean verbose) {
    boolean problemSeen = false;
    for (int i = 0; i < getInternalNodeCount(); i++) {
        boolean reachesTip = false;
        for (Integer memberNumber : samePartitionElement(getInternalNode(i))) {
            if (isExternal(getNode(memberNumber))) {
                reachesTip = true;
                break;
            }
        }
        if (reachesTip || problemSeen) {
            // only the first problem is recorded and reported
            continue;
        }
        problemSeen = true;
        if (verbose) {
            System.out.println("Node "+(i+getExternalNodeCount()) + " is not connected to a tip");
        }
    }
    // @todo wasteful - something accessible should keep a list of cases
    for (int tip = 0; tip < getExternalNodeCount(); tip++) {
        final AbstractCase tipCase = branchMap.get(tip);
        final NodeRef mrca = caseMRCA(tipCase);
        if (branchMap.get(mrca.getNumber()) != tipCase) {
            throw new BadPartitionException("Node partition disconnected");
        }
    }
    return !problemSeen;
}
/**
 * Returns the numbers of the nodes that are in the same partition element as the given node but are
 * neither the node itself nor one of its descendants.
 *
 * The walk climbs from parent to parent for as long as the parent belongs to the same case. Whenever
 * both children of such a parent are in the element, the sibling subtree is included too.
 *
 * @param node starting node
 * @return node numbers of the matching nodes
 */
public HashSet<Integer> samePartitionElementUpTree(NodeRef node) {
    final HashSet<Integer> found = new HashSet<Integer>();
    final AbstractCase elementCase = branchMap.get(node.getNumber());
    for (NodeRef below = node, above = getParent(node);
         above != null && branchMap.get(above.getNumber()) == elementCase;
         below = above, above = getParent(above)) {
        found.add(above.getNumber());
        if (countChildrenInSameElement(above) == 2) {
            final NodeRef otherBranch = sibling(this, below);
            found.add(otherBranch.getNumber());
            found.addAll(samePartitionElementDownTree(otherBranch));
        }
    }
    return found;
}
/**
 * Returns the numbers of the nodes that are strict descendants of the given node and can be reached
 * from it without leaving its partition element.
 *
 * @param node starting node (not included in the result)
 * @return node numbers of the matching descendants
 */
public HashSet<Integer> samePartitionElementDownTree(NodeRef node) {
    final HashSet<Integer> found = new HashSet<Integer>();
    final AbstractCase elementCase = branchMap.get(node.getNumber());
    final int childCount = getChildCount(node);
    for (int i = 0; i < childCount; i++) {
        final NodeRef child = getChild(node, i);
        if (branchMap.get(child.getNumber()) != elementCase) {
            // a different case starts here, so this subtree is not followed
            continue;
        }
        found.add(child.getNumber());
        found.addAll(samePartitionElementDownTree(child));
    }
    return found;
}
/**
 * Returns the numbers of all nodes in the same partition element as the given node, including the
 * node itself.
 *
 * @param node any node of the element
 * @return node numbers of the whole element
 */
public Integer[] samePartitionElement(NodeRef node) {
    final HashSet<Integer> members = new HashSet<Integer>();
    members.add(node.getNumber());
    members.addAll(samePartitionElementUpTree(node));
    members.addAll(samePartitionElementDownTree(node));
    return members.toArray(new Integer[0]);
}
/**
 * Lists the external nodes assigned to a case.
 *
 * @param thisCase the case to look for (compared by identity)
 * @return the numbers of the matching tips, in ascending order
 */
public int[] allTipsForThisCase(AbstractCase thisCase) {
    final int tipCount = getExternalNodeCount();

    // first pass: count the matches so the array can be sized exactly
    int matches = 0;
    for (int i = 0; i < tipCount; i++) {
        if (branchMap.get(i) == thisCase) {
            matches++;
        }
    }

    // second pass: record the tip numbers
    final int[] tipNumbers = new int[matches];
    int next = 0;
    for (int i = 0; i < tipCount; i++) {
        if (branchMap.get(i) == thisCase) {
            tipNumbers[next++] = i;
        }
    }
    return tipNumbers;
}
/**
 * Finds the topmost node of a case's partition element: starting at the MRCA of the case's tips, it
 * climbs while the parent is assigned to the same case.
 *
 * @param thisCase the case of interest
 * @return the topmost node of the element, or null if the case was never infected
 * @throws BadPartitionException if the MRCA of the case's tips is assigned to a different case
 */
public NodeRef getEarliestNodeInElement(AbstractCase thisCase) {
    if (!thisCase.wasEverInfected()) {
        return null;
    }
    final NodeRef mrca = caseMRCA(thisCase);
    if (branchMap.get(mrca.getNumber()) != thisCase) {
        throw new BadPartitionException("Node partition element disconnected");
    }
    NodeRef earliest = mrca;
    NodeRef above = getParent(earliest);
    // stop at the root, or at the first parent that belongs to another case
    while (above != null && branchMap.get(earliest.getNumber()) == branchMap.get(above.getNumber())) {
        earliest = above;
        above = getParent(earliest);
    }
    return earliest;
}
/**
 * Returns every case in the transmission chain below the given case: its infectees, their infectees,
 * and so on.
 *
 * The direct infectees are iterated from their own set while results are accumulated in a separate
 * one. Adding to a HashSet while iterating over that same set throws ConcurrentModificationException
 * as soon as an infectee has infectees of its own.
 *
 * @param thisCase the case whose descendants are wanted
 * @return all direct and indirect infectees; empty if the case was never infected or infected nobody
 */
public HashSet<AbstractCase> getDescendants(AbstractCase thisCase) {
    // getInfectees already returns an empty set for a never-infected case
    final HashSet<AbstractCase> directInfectees = getInfectees(thisCase);
    final HashSet<AbstractCase> out = new HashSet<AbstractCase>(directInfectees);
    for (AbstractCase infectee : directInfectees) {
        out.addAll(getDescendants(infectee));
    }
    return out;
}
/**
 * Returns the case that infected this case.
 *
 * Starting at the MRCA of the case's tips, the walk climbs until it meets a node assigned to a
 * different case; that case is the infector.
 *
 * @param thisCase the case of interest
 * @return the infector, or null if the case was never infected or its element reaches the root
 * @throws BadPartitionException if the MRCA of the case's tips is assigned to a different case
 */
public AbstractCase getInfector(AbstractCase thisCase) {
    if (!thisCase.wasEverInfected()) {
        return null;
    }
    final NodeRef mrca = caseMRCA(thisCase);
    if (branchMap.get(mrca.getNumber()) != thisCase) {
        throw new BadPartitionException("Node partition element disconnected");
    }
    NodeRef cursor = mrca;
    while (branchMap.get(cursor.getNumber()) == thisCase) {
        cursor = getParent(cursor);
        if (cursor == null) {
            // ran past the root without leaving the element
            return null;
        }
    }
    return branchMap.get(cursor.getNumber());
}
/**
 * @return the case assigned to the root node
 */
public AbstractCase getRootCase() {
    final int rootNumber = getRoot().getNumber();
    return branchMap.get(rootNumber);
}
/**
 * Returns the cases directly infected by the given case.
 *
 * @param thisCase the case of interest
 * @return the direct infectees; empty if the case was never infected
 */
public HashSet<AbstractCase> getInfectees(AbstractCase thisCase) {
    if (!thisCase.wasEverInfected()) {
        return new HashSet<AbstractCase>();
    }
    final NodeRef elementTop = getEarliestNodeInElement(thisCase);
    return getInfecteesInClade(elementTop);
}
/**
 * Collects the cases that branch off from the given node's case within the clade below the node.
 * Children assigned to the same case are descended into; a child assigned to another case contributes
 * that case and is not explored further.
 *
 * @param node root of the clade
 * @return the cases found; empty for an external node
 */
public HashSet<AbstractCase> getInfecteesInClade(NodeRef node) {
    final HashSet<AbstractCase> infectees = new HashSet<AbstractCase>();
    if (isExternal(node)) {
        return infectees;
    }
    final AbstractCase ownCase = branchMap.get(node.getNumber());
    for (int i = 0; i < getChildCount(node); i++) {
        final NodeRef child = getChild(node, i);
        final AbstractCase childCase = branchMap.get(child.getNumber());
        if (childCase == ownCase) {
            infectees.addAll(getInfecteesInClade(child));
        } else {
            infectees.add(childCase);
        }
    }
    return infectees;
}
/**
 * Returns the infector of the case assigned to this node: the case of the first ancestor whose
 * assignment differs from that of the node below it.
 *
 * @param node the node of interest
 * @return the infector, or null if the root is reached without a change of case
 */
public AbstractCase getInfector(NodeRef node) {
    NodeRef cursor = node;
    while (!(isRoot(cursor) || cursor.getNumber() == getRoot().getNumber())) {
        final AbstractCase cursorCase = branchMap.get(cursor.getNumber());
        final NodeRef above = getParent(cursor);
        final AbstractCase aboveCase = branchMap.get(above.getNumber());
        if (aboveCase != cursorCase) {
            return aboveCase;
        }
        cursor = above;
    }
    return null;
}
/**
 * Returns the partition element (case) of the parent of this node.
 * No check is made that the node has a parent.
 *
 * @param node a node
 * @return the case assigned to the node's parent
 */
public AbstractCase getParentCase(NodeRef node) {
    final NodeRef parent = getParent(node);
    return branchMap.get(parent.getNumber());
}
/**
 * @return the number of partition elements, as taken from the outbreak at construction
 */
public int getElementCount() {
    return this.elementCount;
}
/**
 * Counts the children of a node that are in the same partition element as the node itself.
 *
 * @param node the node of interest
 * @return the number of such children, or -1 if the node is external
 */
public int countChildrenInSameElement(NodeRef node) {
    if (isExternal(node)) {
        return -1;
    }
    final AbstractCase ownCase = branchMap.get(node.getNumber());
    int sameElement = 0;
    for (int i = 0; i < getChildCount(node); i++) {
        final NodeRef child = getChild(node, i);
        if (branchMap.get(child.getNumber()) == ownCase) {
            sameElement++;
        }
    }
    return sameElement;
}
/**
 * Returns the first child of the node's parent that is not the node itself.
 *
 * @param tree the tree containing the node
 * @param node the node whose sibling is wanted
 * @return the sibling, or null if the node is the root or the parent has no other child
 */
public static NodeRef sibling(TreeModel tree, NodeRef node) {
    if (tree.isRoot(node)) {
        return null;
    }
    final NodeRef parent = tree.getParent(node);
    for (int i = 0; i < tree.getChildCount(parent); i++) {
        final NodeRef candidate = tree.getChild(parent, i);
        if (candidate != node) {
            return candidate;
        }
    }
    return null;
}
/**
 * Finds the most recent common ancestor of all tips assigned to a case.
 *
 * @param aCase              the case of interest
 * @param checkConnectedness if true, verify that the MRCA itself is assigned to the case
 * @return the MRCA node
 * @throws BadPartitionException if the check is requested and the MRCA belongs to another case
 */
public NodeRef caseMRCA(AbstractCase aCase, boolean checkConnectedness) {
    final int[] tipNumbers = allTipsForThisCase(aCase);
    final NodeRef ancestor = TreeUtils.getCommonAncestor(this, tipNumbers);
    if (checkConnectedness && branchMap.get(ancestor.getNumber()) != aCase) {
        throw new BadPartitionException("A partition element is disconnected");
    }
    return ancestor;
}
/**
 * Finds the MRCA of a case's tips with the connectedness check switched on.
 *
 * @param aCase the case of interest
 * @return the MRCA node
 */
public NodeRef caseMRCA(AbstractCase aCase) {
    final boolean checkConnectedness = true;
    return caseMRCA(aCase, checkConnectedness);
}
/**
 * Collects the external nodes at or below the given node. Internal nodes are expanded through
 * their children at index 0 and 1.
 *
 * @param node root of the subtree
 * @return the tips of the subtree (the node itself if it is external)
 */
private HashSet<NodeRef> getDescendantTips(NodeRef node) {
    final HashSet<NodeRef> tips = new HashSet<NodeRef>();
    final Deque<NodeRef> pending = new ArrayDeque<NodeRef>();
    pending.push(node);
    while (!pending.isEmpty()) {
        final NodeRef current = pending.pop();
        if (isExternal(current)) {
            tips.add(current);
        } else {
            // push child 1 first so that child 0 is expanded first
            pending.push(getChild(current, 1));
            pending.push(getChild(current, 0));
        }
    }
    return tips;
}
/**
 * Tests whether at least one tip at or below the node is assigned to the same case as the node.
 *
 * @param node the node of interest
 * @return true if such a tip exists
 */
public boolean isAncestral(NodeRef node) {
    final AbstractCase ownCase = branchMap.get(node.getNumber());
    final Iterator<NodeRef> tips = getDescendantTips(node).iterator();
    while (tips.hasNext()) {
        if (branchMap.get(tips.next().getNumber()) == ownCase) {
            return true;
        }
    }
    return false;
}
/**
 * Tests whether the MRCA of the potential blocker's tips is the MRCA of {@code aCase}'s tips or one
 * of its ancestors.
 *
 * @param aCase            the case of interest
 * @param potentialBlocker the case that may block it
 * @return true if the blocker's MRCA lies on the path from {@code aCase}'s MRCA to the root
 */
public boolean isRootBlockedBy(AbstractCase aCase, AbstractCase potentialBlocker) {
    final NodeRef ownMRCA = caseMRCA(aCase);
    final NodeRef blockerMRCA = caseMRCA(potentialBlocker);
    return directDescendant(ownMRCA, blockerMRCA);
}
/**
 * Tests whether any other ever-infected case in the outbreak blocks this case, in the sense of
 * {@link #isRootBlockedBy(AbstractCase, AbstractCase)}.
 *
 * @param aCase the case of interest
 * @return true if at least one other case blocks it
 */
public boolean isRootBlocked(AbstractCase aCase) {
    for (AbstractCase other : outbreak.getCases()) {
        final boolean eligible = other.wasEverInfected && other != aCase;
        if (eligible && isRootBlockedBy(aCase, other)) {
            return true;
        }
    }
    return false;
}
/**
 * Collects the external nodes assigned to a case.
 *
 * @param aCase the case of interest (compared by identity)
 * @return the matching tips
 */
private HashSet<NodeRef> getTipsInThisPartitionElement(AbstractCase aCase) {
    final HashSet<NodeRef> tips = new HashSet<NodeRef>();
    final int tipCount = getExternalNodeCount();
    // todo check that external nodes come first
    for (int i = 0; i < tipCount; i++) {
        if (branchMap.get(i) != aCase) {
            continue;
        }
        tips.add(getExternalNode(i));
    }
    return tips;
}
/**
 * Tests whether {@code possibleAncestor} lies on the path from {@code node} up to the root.
 * The path starts at {@code node} itself, so a node counts as its own ancestor here.
 *
 * @param node             the lower node
 * @param possibleAncestor the candidate ancestor
 * @return true if the candidate is the node or one of its ancestors
 */
private boolean directDescendant(NodeRef node, NodeRef possibleAncestor) {
    for (NodeRef cursor = node; cursor != null; cursor = getParent(cursor)) {
        if (cursor == possibleAncestor) {
            return true;
        }
    }
    return false;
}
/**
 * Tests whether one of the two nodes is the other or an ancestor of the other.
 *
 * @return true if either node lies on the other's path to the root
 */
private boolean directRelationship(NodeRef node1, NodeRef node2) {
    if (directDescendant(node1, node2)) {
        return true;
    }
    return directDescendant(node2, node1);
}
/**
 * Populates the branch map for external nodes: each tip is given the ever-infected case that lists
 * the tip's taxon among its associated taxa. If several cases match, the last one visited wins.
 * Entries for tips without a match are left untouched.
 *
 * @param map array indexed by node number; modified in place
 * @return the same array
 */
private AbstractCase[] prepareExternalNodeMap(AbstractCase[] map) {
    for (int tip = 0; tip < getExternalNodeCount(); tip++) {
        final TreeModel.Node tipNode = (TreeModel.Node) getExternalNode(tip);
        final Taxon tipTaxon = tipNode.taxon;
        for (AbstractCase candidate : outbreak.getCases()) {
            if (!candidate.wasEverInfected()) {
                continue;
            }
            for (Taxon caseTaxon : candidate.getAssociatedTaxa()) {
                if (caseTaxon.equals(tipTaxon)) {
                    map[tipNode.getNumber()] = candidate;
                }
            }
        }
    }
    return map;
}
/**
 * Reads a starting transmission tree from a CSV file and partitions the tree accordingly.
 *
 * The CSV file should have a header, and then lines matching each case to its infector: the first
 * field is the case name and the second is the infector's name. Double quotes are stripped before
 * splitting. An infector of "Start" marks a case without an infector. Blank lines are ignored.
 *
 * The reader is closed even if reading fails part-way, and the underlying IOException is kept as the
 * cause of the exception that is thrown.
 *
 * @param networkFileName path of the CSV file
 * @throws RuntimeException if the file cannot be read, or a non-blank line has fewer than two fields
 */
private void partitionAccordingToSpecificTT(String networkFileName){
    System.out.println("Using specified starting transmission tree.");
    HashMap<AbstractCase, AbstractCase> specificParentMap = new HashMap<AbstractCase, AbstractCase>();
    try {
        BufferedReader reader = new BufferedReader(new FileReader(networkFileName));
        try {
            // skip header line
            reader.readLine();
            String currentLine;
            while ((currentLine = reader.readLine()) != null) {
                currentLine = currentLine.replace("\"", "");
                if (currentLine.trim().length() == 0) {
                    // tolerate empty lines, e.g. at the end of the file
                    continue;
                }
                String[] splitLine = currentLine.split("\\,");
                if (splitLine.length < 2) {
                    throw new RuntimeException("Malformed line in starting transmission tree file "
                            + networkFileName + ": " + currentLine);
                }
                AbstractCase infectee = outbreak.getCase(splitLine[0]);
                AbstractCase infector = null;
                if (!splitLine[1].equals("Start")) {
                    infector = outbreak.getCase(splitLine[1]);
                }
                specificParentMap.put(infectee, infector);
            }
        } finally {
            reader.close();
        }
    } catch (IOException e) {
        throw new RuntimeException("Cannot read file: " + networkFileName, e);
    }
    partitionAccordingToSpecificTT(specificParentMap);
}
/**
 * Partitions the tree according to a map from each case to its infector.
 *
 * The tips are mapped first. The map is then sanity-checked: it may only mention ever-infected cases,
 * and exactly one ever-infected case of the outbreak may lack an infector. The internal nodes are
 * assigned from the root downwards, starting with that case, and the result is verified.
 *
 * @param map infector of each case; a null (or missing) value marks the index case
 * @throws RuntimeException if the map fails a sanity check or does not fit the tree
 */
private void partitionAccordingToSpecificTT(HashMap<AbstractCase, AbstractCase> map) {
    final AbstractCase[] tipAssignments = prepareExternalNodeMap(new AbstractCase[getNodeCount()]);
    branchMap.setAll(tipAssignments, true);

    //various sanity checks
    for (AbstractCase mentioned : map.keySet()) {
        if (!mentioned.wasEverInfected) {
            throw new RuntimeException("This starting transmission tree involves never-infected cases");
        }
    }

    AbstractCase indexCase = null;
    int withoutInfector = 0;
    for (AbstractCase candidate : outbreak.getCases()) {
        if (candidate.wasEverInfected() && map.get(candidate) == null) {
            indexCase = candidate;
            withoutInfector++;
        }
    }
    if (withoutInfector == 0) {
        throw new RuntimeException("Given starting transmission tree appears to have a cycle");
    }
    if (withoutInfector > 1) {
        throw new RuntimeException("Given starting transmission tree appears not to be connected");
    }

    specificallyPartitionDownwards(getRoot(), indexCase, map);

    if (!checkPartitions()) {
        throw new RuntimeException("Given starting transmission tree is not compatible with the starting tree");
    }
}
/**
 * Recursively assigns the internal nodes at and below {@code node} to cases, following the supplied
 * case-to-infector map. External nodes are left untouched.
 *
 * @param node     root of the subtree to assign
 * @param thisCase case assigned to the node above (or the index case at the root)
 * @param map      infector of each case
 * @throws RuntimeException if the transmission tree cannot be reconciled with the phylogeny
 */
private void specificallyPartitionDownwards(NodeRef node, AbstractCase thisCase,
HashMap<AbstractCase, AbstractCase> map){
if(isExternal(node)){
return;
}
// tentatively give the node the incoming case; isAncestral() reads this assignment
branchMap.set(node.getNumber(), thisCase, true);
if(isAncestral(node)){
// a tip of thisCase lies below, so the node keeps thisCase and so does the recursion
for(int i=0; i<getChildCount(node); i++){
specificallyPartitionDownwards(getChild(node, i), thisCase, map);
}
} else {
// no tip of thisCase below: undo the tentative assignment while a replacement is worked out
branchMap.set(node.getNumber(), null, true);
// cases whose infector, according to the map, is thisCase
HashSet<AbstractCase> children = new HashSet<AbstractCase>();
for(AbstractCase aCase : outbreak.getCases()){
if(map.get(aCase)==thisCase){
children.add(aCase);
}
}
// start from all infectees and drop those whose tips are not below this node
HashSet<AbstractCase> relevantChildren = new HashSet<AbstractCase>(children);
for(AbstractCase child: children){
NodeRef caseMRCA = caseMRCA(child);
//either ALL the tips need to be a descendant of this node, or none. Otherwise not compatible.
// NOTE(review): directDescendant(node, caseMRCA) is also true when node == caseMRCA, because the
// walk starts at node itself; the caseMRCA==node branch below therefore looks unreachable - confirm.
if(directDescendant(node, caseMRCA)){
throw new RuntimeException("Starting transmission tree is incompatible with starting phylogeny");
}
if(caseMRCA==node){
//I'm afraid I must insist...
relevantChildren = new HashSet<AbstractCase>();
relevantChildren.add(child);
break;
}
// climb from the infectee's MRCA; reaching null means this node is not one of its ancestors
NodeRef currentNode = caseMRCA;
while(currentNode!=node && currentNode!=null){
currentNode = getParent(currentNode);
}
if(currentNode==null){
relevantChildren.remove(child);
}
}
if(relevantChildren.size()==1){
//this ends an infection branch
AbstractCase child = relevantChildren.iterator().next();
branchMap.set(node.getNumber(), child, true);
} else {
//this can't end an infection branch
branchMap.set(node.getNumber(), thisCase, true);
}
// recurse with whichever case the node finally received
for(int i=0; i<getChildCount(node); i++){
specificallyPartitionDownwards(getChild(node, i), branchMap.get(node.getNumber()), map);
}
}
}
/*
 Generates a random starting partition: the tips are mapped to their cases first, and the internal
 nodes are then assigned recursively from the root.

 todo - The trouble with initialising this without the likelihood class is that lots of starting trees might
 todo - fail. Need to think about how best to deal with this.

 Generally allowCreep is a bad idea, since it tends to place infections after tip times and tip times
 are frequently noninfectiousness times. Might be useful for some pathogens, however.
*/
private void partitionAccordingToRandomTT(boolean allowCreep) {
    System.out.println("Generating a random starting partition of the tree");
    final AbstractCase[] tipAssignments = prepareExternalNodeMap(new AbstractCase[getNodeCount()]);
    branchMap.setAll(tipAssignments, true);
    randomlyAssignNode(getRoot(), allowCreep);
}
/**
 * Recursively assigns a case to {@code node} and to the internal nodes below it.
 *
 * @param node       root of the subtree to assign
 * @param allowCreep if true, a non-root node may defer its choice so that the parent's case can
 *                   extend down to it later
 * @return the case chosen for the node, or null when the choice was deferred
 * @throws RuntimeException if the topology forces more than one case onto the node
 */
private AbstractCase randomlyAssignNode(NodeRef node, boolean allowCreep){
if(isExternal(node)){
// tips already have their case in the branch map
return branchMap.get(node.getNumber());
} else {
//If this is a descendant of a case MRCA and an ancestor of one of that case's tips, it must be
//assigned that case. If it is that of two cases then this tree is incompatible
ArrayList<AbstractCase> forcedByTopology = new ArrayList<AbstractCase>();
for(AbstractCase aCase : outbreak.getCases()){
if(aCase.wasEverInfected) {
// connectedness is not checked here because the internal nodes are not all assigned yet
NodeRef caseMRCA = caseMRCA(aCase, false);
HashSet<NodeRef> caseTips = getTipsInThisPartitionElement(aCase);
for (NodeRef caseTip : caseTips) {
if (directDescendant(node, caseMRCA) && directDescendant(caseTip, node)) {
if(!forcedByTopology.contains(aCase)) {
forcedByTopology.add(aCase);
}
}
}
}
}
if(forcedByTopology.size()>1){
throw new RuntimeException("Starting phylogeny is incompatible with this tip partition");
} else if(forcedByTopology.size()==1){
// exactly one case is forced: take it and carry on with the internal children
branchMap.set(node.getNumber(), forcedByTopology.get(0), true);
for (int i = 0; i < getChildCount(node); i++) {
if(!isExternal(getChild(node, i))){
randomlyAssignNode(getChild(node, i), allowCreep);
}
}
return forcedByTopology.get(0);
} else {
//not mandated by the topology
//three choices - case of child 1, case of child 2, case of parent, unless this is the root
// one candidate per child; the array has two slots, so at most two children are supported
AbstractCase[] choices = new AbstractCase[2];
for (int i = 0; i < getChildCount(node); i++) {
if(!isExternal(getChild(node, i))){
choices[i] = randomlyAssignNode(getChild(node, i), allowCreep);
} else {
choices[i] = branchMap.get(getChild(node,i).getNumber());
}
}
//if both choices are null and we're at the root, try again
while(isRoot(node) && choices[0]==null && choices[1]==null){
for (int i = 0; i < getChildCount(node); i++) {
if(!isExternal(getChild(node, i))){
choices[i] = randomlyAssignNode(getChild(node, i), allowCreep);
} else {
choices[i] = branchMap.get(getChild(node,i).getNumber());
}
}
}
int randomSelection;
if (isRoot(node)) {
//must make a choice at this point
randomSelection = MathUtils.nextInt(2);
//they can't both be null
if(choices[randomSelection]==null){
randomSelection = 1-randomSelection;
}
AbstractCase winner = choices[randomSelection];
// give the winner to the root and to every still-unassigned node reachable from it
fillDownTree(node, winner);
return winner;
} else {
// a draw of 2 (only possible with allowCreep) defers the choice to the parent
randomSelection = MathUtils.nextInt(allowCreep ? 3 : 2);
}
if (randomSelection != 2) {
AbstractCase winner = choices[randomSelection];
AbstractCase loser = choices[1-randomSelection];
// check that this isn't going to cause a timings problem
// The winner's child height is compared with the loser's infection point on the other child's
// branch (branch position * branch length + child height); if it is higher, the loser wins.
// NOTE(review): loser is dereferenced without a null check; an entry of choices can be null
// when allowCreep let a child defer its choice - confirm this cannot happen in practice.
if(getNodeHeight(getChild(node, randomSelection)) >
loser.getInfectionBranchPosition().getParameterValue(0)
*getBranchLength(getChild(node, 1-randomSelection))
+ getNodeHeight(getChild(node, 1-randomSelection))) {
winner = loser;
}
if(winner!=null) {
fillDownTree(node, winner);
} else {
branchMap.set(node.getNumber(), null, true);
}
return winner;
} else {
//parent partition element will creep to here, but we don't know what that is yet
return null;
}
}
}
}
/**
 * Assigns a case to the node if it has none yet, and then does the same for its children at index 0
 * and 1. Nodes that already have a case stop the descent.
 *
 * @param node  node to start from
 * @param aCase case to assign to unassigned nodes
 */
private void fillDownTree(NodeRef node, AbstractCase aCase) {
    final int number = node.getNumber();
    if (branchMap.get(number) != null) {
        // already assigned: leave this node and everything below it alone
        return;
    }
    branchMap.set(number, aCase, true);
    for (int i = 0; i < 2; i++) {
        fillDownTree(getChild(node, i), aCase);
    }
}
}