/*
 * BirthDeathSerialSamplingModel.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.speciation;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;

import java.util.Collections;
import java.util.List;
import java.util.Set;

/**
 * Beginning of tree prior for birth-death + serial sampling + extant sample proportion. More Tanja magic...
 * <p>
 * The model can be parameterised in one of two ways, depending on the constructor used:
 * <ul>
 * <li>directly, through {@code lambda}, {@code mu}, {@code psi} and {@code p}; or</li>
 * <li>through {@code R0}, {@code recoveryRate} and {@code samplingProbability}, from which
 * {@link #birth()}, {@link #death()} and {@link #psi()} derive the rates.</li>
 * </ul>
 * Only the variables of the chosen parameterisation are non-null.
 *
 * @author Alexei Drummond
 */
public class BirthDeathSerialSamplingModel extends MaskableSpeciationModel implements Citable {

    // R0
    Variable<Double> R0;

    // recovery rate
    Variable<Double> recoveryRate;

    // sampling probability
    Variable<Double> samplingProbability;

    // birth rate
    Variable<Double> lambda;

    // death rate
    Variable<Double> mu;

    // serial sampling rate
    Variable<Double> psi;

    // extant sampling proportion
    Variable<Double> p;

    // is the death rate expressed relative to the birth rate?
    boolean relativeDeath = false;

    // Replaces the former boolean "sampledIndividualsRemainInfectious":
    // the additional parameter 0 <= r <= 1 has to be estimated.
    // for r=1, this is sampledRemainInfectiousProb=0
    // for r=0, this is sampledRemainInfectiousProb=1
    // Within this class r is only stored and registered; the likelihood below does not read it.
    Variable<Double> r;

    boolean hasFinalSample = false;

    // the origin of the infection, x0 > tree.getRoot();
    Variable<Double> origin;

    /**
     * Creates a model named "birthDeathSerialSamplingModel" parameterised by explicit rates.
     *
     * @param lambda         birth rate
     * @param mu             death rate (multiplied by the birth rate when {@code relativeDeath} is true)
     * @param psi            serial sampling rate
     * @param p              extant sampling proportion
     * @param relativeDeath  whether {@code mu} is relative to the birth rate
     * @param r              the r parameter (0 <= r <= 1)
     * @param hasFinalSample whether there is a final sample at time zero
     * @param origin         the origin of the process, or null if it is not sampled
     * @param units          the units of the model
     */
    public BirthDeathSerialSamplingModel(
            Variable<Double> lambda,
            Variable<Double> mu,
            Variable<Double> psi,
            Variable<Double> p,
            boolean relativeDeath,
            Variable<Double> r,
            boolean hasFinalSample,
            Variable<Double> origin,
            Type units) {

        this("birthDeathSerialSamplingModel", lambda, mu, psi, p, relativeDeath, r, hasFinalSample, origin, units);
    }

    /**
     * Creates a named model parameterised by explicit rates. Every supplied variable is registered
     * with the model and given default bounds; {@code origin} is optional and skipped when null.
     *
     * @param modelName      the name of the model
     * @param lambda         birth rate
     * @param mu             death rate (multiplied by the birth rate when {@code relativeDeath} is true)
     * @param psi            serial sampling rate
     * @param p              extant sampling proportion
     * @param relativeDeath  whether {@code mu} is relative to the birth rate
     * @param r              the r parameter (0 <= r <= 1)
     * @param hasFinalSample whether there is a final sample at time zero
     * @param origin         the origin of the process, or null if it is not sampled
     * @param units          the units of the model
     */
    public BirthDeathSerialSamplingModel(
            String modelName,
            Variable<Double> lambda,
            Variable<Double> mu,
            Variable<Double> psi,
            Variable<Double> p,
            boolean relativeDeath,
            Variable<Double> r,
            boolean hasFinalSample,
            Variable<Double> origin,
            Type units) {

        super(modelName, units);

        this.relativeDeath = relativeDeath;

        this.lambda = lambda;
        addVariable(lambda);
        lambda.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));

        this.mu = mu;
        addVariable(mu);
        mu.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));

        this.psi = psi;
        addVariable(psi);
        psi.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));

        this.p = p;
        addVariable(p);
        p.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1));

        this.hasFinalSample = hasFinalSample;

        this.r = r;
        addVariable(r);
        r.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1));

        this.origin = origin;
        if (origin != null) {
            addVariable(origin);
            origin.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));
        }
    }

    /**
     * Creates a named model parameterised by R0, recovery rate and sampling probability.
     * This parameterisation has no final sample and no relative death rate; {@code lambda},
     * {@code mu}, {@code psi}, {@code p} and {@code r} are left null.
     *
     * @param modelName           the name of the model
     * @param R0                  R0; the birth rate is {@code R0 * recoveryRate}
     * @param recoveryRate        recovery rate
     * @param samplingProbability sampling probability; splits the recovery rate into
     *                            death ({@code 1 - samplingProbability}) and sampling
     * @param origin              the origin of the process, or null if it is not sampled
     * @param units               the units of the model
     */
    public BirthDeathSerialSamplingModel(
            String modelName,
            Variable<Double> R0,
            Variable<Double> recoveryRate,
            Variable<Double> samplingProbability,
            Variable<Double> origin,
            Type units) {

        super(modelName, units);

        this.relativeDeath = false;
        this.hasFinalSample = false;

        this.R0 = R0;
        addVariable(R0);
        R0.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));

        this.recoveryRate = recoveryRate;
        addVariable(recoveryRate);
        recoveryRate.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));

        this.samplingProbability = samplingProbability;
        addVariable(samplingProbability);
        samplingProbability.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1));

        this.origin = origin;
        if (origin != null) {
            addVariable(origin);
            origin.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));
        }
    }

    /**
     * @param b   birth rate
     * @param d   death rate
     * @param p   proportion sampled at final time point
     * @param psi rate of sampling per lineage per unit time
     * @param t   time
     * @return the probability of no sampled descendants after time, t
     */
    public static double p0(double b, double d, double p, double psi, double t) {
        double c1 = c1(b, d, psi);
        double c2 = c2(b, d, p, psi);

        double expc1trc2 = Math.exp(-c1 * t) * (1.0 - c2);

        return (b + d + psi + c1 * ((expc1trc2 - (1.0 + c2)) / (expc1trc2 + (1.0 + c2)))) / (2.0 * b);
    }

    /**
     * Computes, in log space, the quantity
     * {@code 2(1 - c2^2) + exp(-c1 t)(1 - c2)^2 + exp(c1 t)(1 + c2)^2}.
     * The value is assembled directly as a logarithm because {@code exp(c1 * t)} becomes
     * too large to evaluate when {@code c1 * t} is big.
     *
     * @param b   birth rate
     * @param d   death rate
     * @param p   proportion sampled at final time point
     * @param psi rate of sampling per lineage per unit time
     * @param t   time
     * @return the natural logarithm of the expression above
     */
    public static double q(double b, double d, double p, double psi, double t) {
        double c1 = c1(b, d, psi);
        double c2 = c2(b, d, p, psi);

        // c1*t + 2*log(exp(-c1*t)(1 - c2) + (1 + c2)) is the log of the expression in the Javadoc
        return c1 * t + 2.0 * Math.log(Math.exp(-c1 * t) * (1.0 - c2) + (1.0 + c2));
    }

    /**
     * @return sqrt((b - d - psi)^2 + 4 b psi)
     */
    private static double c1(double b, double d, double psi) {
        return Math.abs(Math.sqrt(Math.pow(b - d - psi, 2.0) + 4.0 * b * psi));
    }

    /**
     * @return -(b - d - 2 b p - psi) / c1
     */
    private static double c2(double b, double d, double p, double psi) {
        return -(b - d - 2.0 * b * p - psi) / c1(b, d, psi);
    }

    /**
     * @param t time
     * @return {@link #p0(double, double, double, double, double)} evaluated at this model's current rates
     */
    public double p0(double t) {
        return p0(birth(), death(), p(), psi(), t);
    }

    /**
     * @param t time
     * @return {@link #q(double, double, double, double, double)} (a log value) evaluated at this
     *         model's current rates
     */
    public double q(double t) {
        return q(birth(), death(), p(), psi(), t);
    }

    private double c1() {
        return c1(birth(), death(), psi());
    }

    private double c2() {
        return c2(birth(), death(), p(), psi());
    }

    /**
     * @return the birth rate: the mask's birth rate if masked, otherwise {@code lambda} or,
     *         in the R0 parameterisation, {@code R0 * recoveryRate}
     */
    public double birth() {
        if (mask != null) return mask.birth();

        if (lambda != null) {
            return lambda.getValue(0);
        }

        double r0 = R0.getValue(0);
        double rr = recoveryRate.getValue(0);
        return r0 * rr;
    }

    /**
     * @return the death rate: the mask's death rate if masked, otherwise {@code mu}
     *         (times the birth rate when relative) or, in the R0 parameterisation,
     *         {@code recoveryRate * (1 - samplingProbability)}
     */
    public double death() {
        if (mask != null) return mask.death();

        if (mu != null) {
            return relativeDeath ? mu.getValue(0) * birth() : mu.getValue(0);
        }

        double rr = recoveryRate.getValue(0);
        double sp = samplingProbability.getValue(0);
        return rr * (1.0 - sp);
    }

    /**
     * @return the serial sampling rate: the mask's rate if masked, otherwise {@code psi} or,
     *         in the R0 parameterisation, {@code recoveryRate * samplingProbability}
     */
    public double psi() {
        if (mask != null) return mask.psi();

        if (psi != null) {
            return psi.getValue(0);
        }

        double rr = recoveryRate.getValue(0);
        double sp = samplingProbability.getValue(0);
        return rr * sp;
    }

    /**
     * @return the proportion of population sampled at final sample, or zero if there is no final sample
     */
    public double p() {
        if (mask != null) {
            // A mask created through the R0 constructor never assigns p (and has no final
            // sample), so dereferencing it blindly would throw a NullPointerException.
            return mask.p != null ? mask.p.getValue(0) : 0.0;
        }
        return hasFinalSample ? p.getValue(0) : 0;
    }

    // The mask does not affect the following two methods

    /**
     * @return true if an origin variable was supplied to the constructor
     */
    public boolean isSamplingOrigin() {
        return origin != null;
    }

    /**
     * @return the current value of the origin; only call when {@link #isSamplingOrigin()} is true
     */
    public double x0() {
        return origin.getValue(0);
    }

    /**
     * Generic likelihood calculation
     *
     * @param tree the tree to calculate likelihood of
     * @return log-likelihood of density, or negative infinity if the origin is below the root height
     * @throws RuntimeException if there is no final sample and no tip at time zero, or if the
     *                          origin is not sampled (integrating it out is not implemented)
     */
    public final double calculateTreeLogLikelihood(Tree tree) {

        if (isSamplingOrigin() && x0() < tree.getNodeHeight(tree.getRoot())) {
            // Origin value cannot be < tree root height
            return Double.NEGATIVE_INFINITY;
        }

        final int tipCount = tree.getExternalNodeCount();

        // number of tips at time zero (extant leaves)
        int n = 0;
        for (int i = 0; i < tipCount; i++) {
            NodeRef node = tree.getExternalNode(i);
            if (tree.getNodeHeight(node) == 0.0) {
                n += 1;
            }
        }

        if (!hasFinalSample && n < 1) {
            throw new RuntimeException(
                    "For sampling-through-time model there must be at least one tip at time zero.");
        }

        final double b = birth();
        final double p = p();

        if (!isSamplingOrigin()) {
            // integrating out the time between the origin and the root of the tree would go here
            throw new RuntimeException(
                    "The origin must be sampled, as integrating it out is not implemented!");
        }

        // q(.) already returns a log value, hence the subtraction rather than log(1 / q)
        double logL = -q(x0());

        if (hasFinalSample) {
            logL += n * Math.log(4.0 * p);
        }

        // loop-invariant: every internal node contributes log(b) - q(height)
        final double logBirth = Math.log(b);
        final int internalCount = tree.getInternalNodeCount();
        for (int i = 0; i < internalCount; i++) {
            double x = tree.getNodeHeight(tree.getInternalNode(i));
            logL += logBirth - q(x);
        }

        // loop-invariant: every serially sampled tip contributes log(psi) + q(height)
        final double logPsi = Math.log(psi());
        for (int i = 0; i < tipCount; i++) {
            double y = tree.getNodeHeight(tree.getExternalNode(i));

            // Tips above time zero are always serial samples. Without a final sample, tips at
            // time zero are serial samples too (sampling-through-time-only situation).
            if (y > 0.0 || !hasFinalSample) {
                logL += logPsi + q(y);
            }
        }

        return logL;
    }

    /**
     * Likelihood calculation with excluded taxa; only an empty exclusion set is supported.
     *
     * @param tree    the tree to calculate likelihood of
     * @param exclude taxa to exclude; must be empty
     * @return log-likelihood of density
     * @throws RuntimeException if {@code exclude} is not empty
     */
    public double calculateTreeLogLikelihood(Tree tree, Set<Taxon> exclude) {
        if (exclude.isEmpty()) return calculateTreeLogLikelihood(tree);
        throw new RuntimeException("Not implemented!");
    }

    /**
     * Installs a mask whose rate parameters are used in place of this model's own.
     *
     * @param mask the masking model; must be a {@code BirthDeathSerialSamplingModel}
     * @throws IllegalArgumentException if {@code mask} is null or of any other type
     */
    public void mask(SpeciationModel mask) {
        if (mask instanceof BirthDeathSerialSamplingModel) {
            this.mask = (BirthDeathSerialSamplingModel) mask;
        } else {
            throw new IllegalArgumentException(
                    "Mask must be a BirthDeathSerialSamplingModel but was "
                            + (mask == null ? "null" : mask.getClass().getName()));
        }
    }

    /**
     * Removes any installed mask so that this model's own parameters are used again.
     */
    public void unmask() {
        mask = null;
    }

    // if a mask exists then use the mask's parameters instead (except for origin and finalTimeInterval)
    BirthDeathSerialSamplingModel mask = null;

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TREE_PRIORS;
    }

    // NOTE(review): the description and citation below refer to Gernhard (2008), while the class
    // comment describes a serial-sampling model -- confirm this is the intended reference.
    @Override
    public String getDescription() {
        return "Gernhard 2008 Birth Death Tree Model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(new Citation(
                new Author[]{
                        new Author("T", "Gernhard"),
                },
                "The conditioned reconstructed process",
                2008,
                "Journal of Theoretical Biology",
                253,
                769, 778,
                "10.1016/j.jtbi.2008.04.005"
        ));
    }
}