/*
* WithinCaseCoalescent.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.epidemiology.casetocase;
import dr.app.tools.NexusExporter;
import dr.evolution.coalescent.*;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.*;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.*;
import dr.xml.*;
import java.io.IOException;
import java.io.PrintStream;
import java.math.BigDecimal;
import java.util.*;
/**
 * Tree prior for a partitioned tree, intended to replace the usual tree prior: each subtree in the partition is
 * treated as a tree in its own right, generated by a within-case coalescent process.
 *
 * @author Matthew Hall
 */
public class WithinCaseCoalescent extends CaseToCaseTreeLikelihood {

    public static final String WITHIN_CASE_COALESCENT = "withinCaseCoalescent";

    /** log(2): switch-over point between the two evaluation branches of log(1 - e^x). */
    private static final double LOG_TWO = Math.log(2.0);

    /**
     * How the within-case coalescent is evaluated.
     * TRUNCATE: time zero is the date of infection and each interval is normalised against the integral up to
     * time zero (see calculatePartitionTreeLogLikelihood).
     * NORMAL: an ordinary coalescent; the demographic function must be a LinearGrowth.
     */
    private enum Mode {TRUNCATE, NORMAL}

    // coalescent log likelihood of each case's treelet, indexed by outbreak case index
    private double[] partitionTreeLogLikelihoods;
    private double[] storedPartitionTreeLogLikelihoods;
    // true where the treelet of that case index must be re-evaluated on the next likelihood call
    private boolean[] recalculateCoalescentFlags;

    private DemographicModel demoModel;
    private Mode mode;

    // sum of partitionTreeLogLikelihoods over ever-infected cases
    private double coalescencesLogLikelihood;
    private double storedCoalescencesLogLikelihood;

    // set whenever an elementsAsTrees entry is invalidated; triggers explodeTree() on the next evaluation
    private boolean pleaseReExplode = true;

    /**
     * Creates the within-case coalescent prior.
     *
     * @param virusTree         the partitioned phylogeny
     * @param caseData          the outbreak whose cases define the partition
     * @param maxFirstInfToRoot passed through to the superclass
     * @param demoModel         demographic model shared by every within-case coalescent
     * @param mode              TRUNCATE or NORMAL evaluation (see {@link Mode})
     * @throws TaxonList.MissingTaxonException propagated from the superclass constructor
     */
    public WithinCaseCoalescent(PartitionedTreeModel virusTree, AbstractOutbreak caseData, Parameter maxFirstInfToRoot,
                                DemographicModel demoModel, Mode mode)
            throws TaxonList.MissingTaxonException {
        super(WITHIN_CASE_COALESCENT, virusTree, caseData, maxFirstInfToRoot);
        this.mode = mode;
        this.demoModel = demoModel;
        addModel(demoModel);
        addModel(outbreak);

        int caseCount = outbreak.getCases().size();
        partitionTreeLogLikelihoods = new double[caseCount];
        storedPartitionTreeLogLikelihoods = new double[caseCount];
        recalculateCoalescentFlags = new boolean[caseCount];
        Arrays.fill(recalculateCoalescentFlags, true);

        // only ever-infected cases own a treelet; null entries are (re)built by explodeTree()
        elementsAsTrees = new HashMap<AbstractCase, Treelet>();
        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                elementsAsTrees.put(aCase, null);
            }
        }
        storedElementsAsTrees = new HashMap<AbstractCase, Treelet>();
    }

    /**
     * Sums the coalescent log likelihoods of every ever-infected case's treelet. Cached values are reused for
     * cases whose recalculation flag is clear.
     *
     * @return the total within-case coalescent log likelihood
     */
    protected double calculateLogLikelihood() {
        if (pleaseReExplode) {
            explodeTree();
        }

        coalescencesLogLikelihood = 0;
        for (AbstractCase aCase : outbreak.getCases()) {
            int number = outbreak.getCaseIndex(aCase);
            if (!aCase.wasEverInfected()) {
                // never-infected cases contribute nothing
                recalculateCoalescentFlags[number] = false;
                continue;
            }
            if (recalculateCoalescentFlags[number]) {
                partitionTreeLogLikelihoods[number] = calculateTreeletLogLikelihood(elementsAsTrees.get(aCase));
                recalculateCoalescentFlags[number] = false;
            }
            coalescencesLogLikelihood += partitionTreeLogLikelihoods[number];
        }

        likelihoodKnown = true;
        return coalescencesLogLikelihood;
    }

    // Coalescent log likelihood of a single treelet; a treelet with at most one tip contributes 0.
    private double calculateTreeletLogLikelihood(Treelet treelet) {
        if (treelet.getExternalNodeCount() <= 1) {
            return 0.0;
        }
        SpecifiedZeroCoalescent coalescent = new SpecifiedZeroCoalescent(treelet, demoModel,
                treelet.getZeroHeight(), mode == Mode.TRUNCATE);
        return coalescent.calculateLogLikelihood();
    }

    /**
     * Saves the treelets and the cached log likelihoods. Copies are taken, so later edits to the live state
     * never reach the saved snapshot.
     */
    public void storeState() {
        super.storeState();
        storedElementsAsTrees = new HashMap<AbstractCase, Treelet>(elementsAsTrees);
        System.arraycopy(partitionTreeLogLikelihoods, 0, storedPartitionTreeLogLikelihoods, 0,
                partitionTreeLogLikelihoods.length);
        storedCoalescencesLogLikelihood = coalescencesLogLikelihood;
    }

    /**
     * Restores the state saved by {@link #storeState()}. Values are copied back rather than aliased, so the live
     * fields and the snapshot never share a map or array.
     */
    public void restoreState() {
        super.restoreState();
        elementsAsTrees = new HashMap<AbstractCase, Treelet>(storedElementsAsTrees);
        System.arraycopy(storedPartitionTreeLogLikelihoods, 0, partitionTreeLogLikelihoods, 0,
                storedPartitionTreeLogLikelihoods.length);
        coalescencesLogLikelihood = storedCoalescencesLogLikelihood;
    }

    /**
     * Marks the treelets affected by a change in the tree, branch map, demographic model or outbreak for
     * recalculation.
     *
     * @throws RuntimeException if the branch map sends an event that is not an ArrayList of
     *                          BranchMapChangedEvents
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        super.handleModelChangedEvent(model, object, index);

        if (model == treeModel) {
            if (object instanceof PartitionedTreeModel.PartitionsChangedEvent) {
                HashSet<AbstractCase> changedPartitions =
                        ((PartitionedTreeModel.PartitionsChangedEvent) object).getCasesToRecalculate();
                for (AbstractCase aCase : changedPartitions) {
                    recalculateCaseWCC(aCase);
                }
            }
        } else if (model == getBranchMap()) {
            if (!(object instanceof ArrayList)) {
                throw new RuntimeException("Unanticipated model changed event from BranchMapModel");
            }
            for (Object item : (ArrayList<?>) object) {
                BranchMapModel.BranchMapChangedEvent event = (BranchMapModel.BranchMapChangedEvent) item;
                recalculateCaseWCC(event.getOldCase());
                recalculateCaseWCC(event.getNewCase());

                // the case owning the parent branch is affected by the move as well
                NodeRef node = treeModel.getNode(event.getNodeToRecalculate());
                NodeRef parent = treeModel.getParent(node);
                if (parent != null) {
                    recalculateCaseWCC(getBranchMap().get(parent.getNumber()));
                }
            }
        } else if (model == demoModel) {
            // the shared demographic function changed: every treelet is affected
            Arrays.fill(recalculateCoalescentFlags, true);
        } else if (model == outbreak) {
            if (object instanceof AbstractCase) {
                AbstractCase thisCase = (AbstractCase) object;
                recalculateCaseWCC(thisCase);
                AbstractCase parent = ((PartitionedTreeModel) treeModel).getInfector(thisCase);
                if (parent != null) {
                    recalculateCaseWCC(parent);
                }
            }
        }
    }

    /**
     * Invalidates the treelet of the case at the given outbreak index and flags it for recalculation.
     */
    protected void recalculateCaseWCC(int index) {
        elementsAsTrees.put(outbreak.getCase(index), null);
        pleaseReExplode = true;
        recalculateCoalescentFlags[index] = true;
    }

    /**
     * Invalidates the treelet of the given case. Cases that were never infected have no treelet and are ignored.
     */
    protected void recalculateCaseWCC(AbstractCase aCase) {
        if (aCase.wasEverInfected()) {
            recalculateCaseWCC(outbreak.getCaseIndex(aCase));
        }
    }

    /**
     * Invalidates every treelet and every cached per-case log likelihood.
     */
    public void makeDirty() {
        super.makeDirty();
        Arrays.fill(recalculateCoalescentFlags, true);
        for (AbstractCase aCase : outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                elementsAsTrees.put(aCase, null);
            }
        }
        pleaseReExplode = true;
    }

    /**
     * Post-order traversal of the transmission tree, starting from the case that owns the root branch.
     *
     * @return cases in post-order (every infectee before its infector)
     */
    public ArrayList<AbstractCase> postOrderTransmissionTreeTraversal() {
        return traverseTransmissionTree(getBranchMap().get(treeModel.getRoot().getNumber()));
    }

    // Recursive post-order step. Children are visited in outbreak index order to keep the output deterministic.
    private ArrayList<AbstractCase> traverseTransmissionTree(AbstractCase aCase) {
        ArrayList<AbstractCase> out = new ArrayList<AbstractCase>();
        HashSet<AbstractCase> children = ((PartitionedTreeModel) treeModel).getInfectees(aCase);
        for (int i = 0; i < getOutbreak().size(); i++) {
            AbstractCase possibleChild = getOutbreak().getCase(i);
            // easiest way to maintain the set ordering of the outbreak
            if (children.contains(possibleChild)) {
                out.addAll(traverseTransmissionTree(possibleChild));
            }
        }
        out.add(aCase);
        return out;
    }

    // Returns a copy of the treelet with node times t (relative to its zero height) mapped to -log(-t).
    // Heights are shifted so that the smallest transformed time becomes zero.
    private Treelet transformTreelet(Treelet treelet) {
        double[] transformedNodeTimes = new double[treelet.getNodeCount()];
        double totalHeight = treelet.getZeroHeight();

        for (int i = 0; i < treelet.getNodeCount(); i++) {
            NodeRef node = treelet.getNode(i);
            double time = treelet.getNodeHeight(node) - totalHeight;
            transformedNodeTimes[i] = -Math.log(-(time));
        }

        double first = Double.POSITIVE_INFINITY;
        for (double transformedNodeTime : transformedNodeTimes) {
            if (transformedNodeTime < first) {
                first = transformedNodeTime;
            }
        }

        double zeroHeight = -first;
        Treelet copy = new Treelet(treelet, zeroHeight);
        for (int i = 0; i < copy.getNodeCount(); i++) {
            NodeRef node = copy.getNode(i);
            copy.setNodeHeight(node, transformedNodeTimes[i] - first);
        }
        copy.resolveTree();
        return copy;
    }

    /**
     * Coalescent over a treelet whose time zero lies zeroHeight above the treelet's tips. Static because it uses
     * no state from the enclosing likelihood.
     */
    private static class SpecifiedZeroCoalescent extends Coalescent {

        private double zeroHeight;
        boolean truncate;

        private SpecifiedZeroCoalescent(Tree tree, DemographicModel demographicModel, double zeroHeight,
                                        boolean truncate) {
            super(tree, demographicModel.getDemographicFunction());
            this.zeroHeight = zeroHeight;
            this.truncate = truncate;
        }

        public double calculateLogLikelihood() {
            return calculatePartitionTreeLogLikelihood(getIntervals(), getDemographicFunction(), 0, zeroHeight,
                    truncate);
        }
    }

    /**
     * Coalescent log likelihood of a set of intervals. Time starts at -zeroHeight and increases with each
     * interval.
     *
     * @param intervals           the coalescent intervals of the treelet
     * @param demographicFunction the population-size function
     * @param threshold           an interval ending in a coalescence is rejected (-inf returned) when
     *                            N(end) * (area / duration) falls below this value
     * @param zeroHeight          distance from the first interval's start to time zero
     * @param truncate            if true, time zero is the date of infection and each interval is normalised by
     *                            log(1 - exp(-C(k,2) * area(start, 0))). If false, the demographic function must
     *                            be a LinearGrowth.
     * @return the log likelihood, or Double.NEGATIVE_INFINITY for an impossible configuration
     * @throws RuntimeException if truncate is false and the demographic function is not a LinearGrowth
     */
    public static double calculatePartitionTreeLogLikelihood(IntervalList intervals,
                                                             DemographicFunction demographicFunction,
                                                             double threshold,
                                                             double zeroHeight, boolean truncate) {
        double logL = 0.0;
        double startTime = -zeroHeight;
        final int n = intervals.getIntervalCount();

        //TreeIntervals sets up a first zero-length interval with a lineage count of zero - skip this one
        for (int i = 0; i < n; i++) {
            if (truncate) {
                // time zero corresponds to the date of first infection
                final double duration = intervals.getInterval(i);
                final double finishTime = startTime + duration;

                // if this has happened the run is probably pretty unhappy
                if (finishTime == 0) {
                    return Double.NEGATIVE_INFINITY;
                }

                final double intervalArea = demographicFunction.getIntegral(startTime, finishTime);
                final double normalisationArea = demographicFunction.getIntegral(startTime, 0);

                if (intervalArea == 0 && duration > tolerance) {
                    return Double.NEGATIVE_INFINITY;
                }

                final int lineageCount = intervals.getLineageCount(i);

                if (lineageCount >= 2) {
                    final double kChoose2 = Binomial.choose2(lineageCount);

                    if (intervals.getIntervalType(i) == IntervalType.COALESCENT) {
                        logL += -kChoose2 * intervalArea;

                        final double demographicAtCoalPoint = demographicFunction.getDemographic(finishTime);

                        if (duration == 0.0 || demographicAtCoalPoint * (intervalArea / duration) >= threshold) {
                            logL -= Math.log(demographicAtCoalPoint);
                        } else {
                            return Double.NEGATIVE_INFINITY;
                        }
                    } else {
                        double numerator = Math.exp(-kChoose2 * intervalArea) - Math.exp(-kChoose2 * normalisationArea);
                        logL += Math.log(numerator);
                    }

                    // normalisation: log(1 - exp(-C(k,2) * area(start, 0))), computed without the cancellation
                    // that a naive log1p(-exp(x)) suffers when exp(x) is close to 1
                    final double logDenominator = logOneMinusExp(-kChoose2 * normalisationArea);
                    if (logDenominator == Double.NEGATIVE_INFINITY) {
                        // zero normalisation area: the normalising probability is zero
                        return Double.NEGATIVE_INFINITY;
                    }
                    logL -= logDenominator;
                }

                startTime = finishTime;
            } else {
                if (!(demographicFunction instanceof LinearGrowth)) {
                    throw new RuntimeException("Function must have zero population at t=0 if truncate=false");
                }

                final double duration = intervals.getInterval(i);
                final double finishTime = startTime + duration;

                final double intervalArea = demographicFunction.getIntegral(startTime, finishTime);
                if (intervalArea == 0 && duration != 0) {
                    return Double.NEGATIVE_INFINITY;
                }
                final int lineageCount = intervals.getLineageCount(i);

                final double kChoose2 = Binomial.choose2(lineageCount);
                // common part
                logL += -kChoose2 * intervalArea;

                if (intervals.getIntervalType(i) == IntervalType.COALESCENT) {
                    final double demographicAtCoalPoint = demographicFunction.getDemographic(finishTime);

                    // Reject the interval if the value at its end is many orders of magnitude away from the
                    // interval mean. This protects against an infinitesimal population size at the end of a
                    // linear interval driving the coalescent value to infinity.
                    if (duration == 0.0 || demographicAtCoalPoint * (intervalArea / duration) >= threshold) {
                        logL -= Math.log(demographicAtCoalPoint);
                    } else {
                        return Double.NEGATIVE_INFINITY;
                    }
                }
                startTime = finishTime;
            }
        }
        return logL;
    }

    /**
     * Computes log(1 - e^x) for x &lt;= 0, accurate across the whole range. Uses log(-expm1(x)) when x is near
     * zero and log1p(-exp(x)) otherwise.
     *
     * @return -infinity for x == 0, NaN for x &gt; 0
     */
    private static double logOneMinusExp(double x) {
        return x > -LOG_TWO ? Math.log(-Math.expm1(x)) : Math.log1p(-Math.exp(x));
    }

    /**
     * Writes a copy of the treelet to a Nexus file, labelling every node with its number in a "Number"
     * attribute. Intended for debugging. I/O failures are reported on stderr rather than thrown.
     *
     * @param treelet  the tree to export
     * @param fileName path of the Nexus file to (over)write
     */
    public void debugTreelet(Tree treelet, String fileName) {
        FlexibleTree treeCopy = new FlexibleTree(treelet);
        for (int j = 0; j < treeCopy.getNodeCount(); j++) {
            FlexibleNode node = (FlexibleNode) treeCopy.getNode(j);
            node.setAttribute("Number", node.getNumber());
        }

        PrintStream out = null;
        try {
            out = new PrintStream(fileName);
            NexusExporter testTreesOut = new NexusExporter(out);
            testTreesOut.exportTree(treeCopy);
        } catch (IOException e) {
            System.err.println("IOException writing treelet to " + fileName + ": " + e.getMessage());
        } finally {
            // the stream was previously never closed, leaking a file handle per call
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Returns the superclass log columns. For a CategoryOutbreak it also adds one column per ever-infected
     * case ("coal_LL_i") and a "total_coal_LL" column. The superclass columns are always returned; previously
     * they were discarded (null returned) for other outbreak types.
     */
    public LogColumn[] passColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.passColumns()));
        if (outbreak instanceof CategoryOutbreak) {
            for (int i = 0; i < outbreak.size(); i++) {
                if (outbreak.getCase(i).wasEverInfected()) {
                    final int finalI = i;
                    columns.add(new LogColumn.Abstract("coal_LL_" + i) {
                        protected String getFormattedValue() {
                            return String.valueOf(partitionTreeLogLikelihoods[finalI]);
                        }
                    });
                }
            }
            columns.add(new LogColumn.Abstract("total_coal_LL") {
                protected String getFormattedValue() {
                    return String.valueOf(coalescencesLogLikelihood);
                }
            });
        }
        return columns.toArray(new LogColumn[columns.size()]);
    }

    //************************************************************************
    // AbstractXMLObjectParser implementation
    //************************************************************************

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String STARTING_NETWORK = "startingNetwork";
        public static final String MAX_FIRST_INF_TO_ROOT = "maxFirstInfToRoot";
        public static final String DEMOGRAPHIC_MODEL = "demographicModel";
        public static final String TRUNCATE = "truncate";

        public String getParserName() {
            return WITHIN_CASE_COALESCENT;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            PartitionedTreeModel virusTree = (PartitionedTreeModel) xo.getChild(TreeModel.class);

            AbstractOutbreak caseSet = (AbstractOutbreak) xo.getChild(AbstractOutbreak.class);

            CaseToCaseTreeLikelihood likelihood;

            Parameter earliestFirstInfection = (Parameter) xo.getElementFirstChild(MAX_FIRST_INF_TO_ROOT);

            DemographicModel demoModel = (DemographicModel) xo.getElementFirstChild(DEMOGRAPHIC_MODEL);

            // short-circuit: the attribute is optional and must only be read when present
            Mode mode = xo.hasAttribute(TRUNCATE) && xo.getBooleanAttribute(TRUNCATE) ? Mode.TRUNCATE : Mode.NORMAL;

            try {
                likelihood = new WithinCaseCoalescent(virusTree, caseSet, earliestFirstInfection, demoModel, mode);
            } catch (TaxonList.MissingTaxonException e) {
                throw new XMLParseException(e.toString());
            }

            return likelihood;
        }

        public String getParserDescription() {
            return "This element provides a tree prior for a partitioned tree, with each partitioned tree generated " +
                    "by a coalescent process";
        }

        public Class getReturnType() {
            return WithinCaseCoalescent.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(PartitionedTreeModel.class, "The tree"),
                new ElementRule(CategoryOutbreak.class, "The set of cases", 0, 1),
                new ElementRule(STARTING_NETWORK, String.class, "A CSV file containing a specified starting network",
                        true),
                new ElementRule(MAX_FIRST_INF_TO_ROOT, Parameter.class, "The maximum time from the first infection to " +
                        "the root node"),
                new ElementRule(DEMOGRAPHIC_MODEL, DemographicModel.class, "The demographic model for within-case " +
                        "evolution"),
                AttributeRule.newBooleanRule(TRUNCATE, true)
        };
    };
}