/*
* BirthDeathSerialSkylineModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.speciation;
import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.io.IOException;
import java.util.Arrays;
import java.util.Set;
/**
 * Beginning of tree prior for birth-death + serial sampling + extant sample proportion.
 * <p/>
 * The rate parameters are looked up per time interval: each of lambda, mu and psi either has one
 * value per entry of the times parameter, or a single value that is shared by every interval.
 * <p/>
 * log:
 * 25 Mar 2011, Denise: added int i (index) for the Variables that change over time such as the methods p0(..), q(..); fixed some formulas (old versions commented out)
 * unclear marked with todo
 *
 * @author Alexei Drummond
 */
public class BirthDeathSerialSkylineModel extends SpeciationModel {

    // times
    Variable<Double> times;

    // birth rate
    Variable<Double> lambda;

    // death rate
    Variable<Double> mu;

    // serial sampling rate
    Variable<Double> psi;

    // extant sampling proportion
    Variable<Double> p;

    // added to the root height in preCalculation() to obtain x0
    Variable<Double> origin;

    // if true, mu is multiplied by the birth rate to obtain the death rate (see death(int))
    boolean relativeDeath = false;

    // the number of intervals (dimension of the times parameter)
    int size = 1;

    // height of the root of the tree most recently passed to preCalculation()
    double t_root;

    // t_root + origin, recomputed by preCalculation()
    double x0;

    // per-interval values cached by preCalculation()
    protected double[] p0_iMinus1;
    protected double[] Ai;
    protected double[] Bi;

    // Whether the corresponding rate has one value per interval (true) or a single shared value
    // (false). The constructor derives these from the parameter dimensions.
    protected boolean birthChanges = true;
    protected boolean deathChanges = true;
    protected boolean samplingChanges = true;

    // selects how preCalculation() turns the times parameter into timesFromTips
    protected boolean timesStartFromOrigin = true;

    // interval start times, as used by t(int) and index(double); filled by preCalculation()
    protected double[] timesFromTips;

    /**
     * Creates the model under the default name "birthDeathSerialSamplingModel".
     * See the constructor taking a model name for a description of the arguments.
     */
    public BirthDeathSerialSkylineModel(
            Variable<Double> times,
            Variable<Double> lambda,
            Variable<Double> mu,
            Variable<Double> psi,
            Variable<Double> p,
            Variable<Double> origin,
            boolean relativeDeath,
            boolean sampledIndividualsRemainInfectious,
            boolean timesStartFromOrigin,
            Type units) {

        this("birthDeathSerialSamplingModel", times, lambda, mu, psi, p, origin, relativeDeath,
                sampledIndividualsRemainInfectious, timesStartFromOrigin, units);
    }

    /**
     * Creates the model, registers every parameter as a model variable and attaches bounds:
     * [0, +infinity) for times, lambda, mu, psi and origin, and [0, 1] for p.
     *
     * @param modelName                          name passed to the superclass
     * @param times                              interval times; its dimension is the number of intervals
     * @param lambda                             birth rate, dimension 1 or the number of intervals
     * @param mu                                 death rate, dimension 1 or the number of intervals
     * @param psi                                serial sampling rate, dimension 1 or the number of intervals
     * @param p                                  extant sampling proportion
     * @param origin                             value added to the root height to obtain x0
     * @param relativeDeath                      if true, mu is multiplied by the birth rate
     * @param sampledIndividualsRemainInfectious accepted but not used by this class
     * @param timesStartFromOrigin               if true, times are subtracted from x0 to obtain
     *                                           the interval starts; otherwise they are used as given
     * @param units                              units passed to the superclass
     * @throws RuntimeException if lambda, mu or psi has a dimension that is neither one nor the
     *                          dimension of times
     */
    public BirthDeathSerialSkylineModel(
            String modelName,
            Variable<Double> times,
            Variable<Double> lambda,
            Variable<Double> mu,
            Variable<Double> psi,
            Variable<Double> p,
            Variable<Double> origin,
            boolean relativeDeath,
            boolean sampledIndividualsRemainInfectious,
            boolean timesStartFromOrigin,
            Type units) {

        super(modelName, units);

        this.size = times.getSize();

        checkRateDimension("Lambda", lambda, size);
        checkRateDimension("mu", mu, size);
        checkRateDimension("psi", psi, size);

        // A rate of dimension one is shared by all intervals and must always be read at index 0;
        // indexing it by interval would run past the end of the parameter.
        this.birthChanges = lambda.getSize() > 1;
        this.deathChanges = mu.getSize() > 1;
        this.samplingChanges = psi.getSize() > 1;

        this.timesStartFromOrigin = timesStartFromOrigin;

        this.times = times;
        addVariable(times);
        times.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, times.getSize()));

        this.lambda = lambda;
        addVariable(lambda);
        lambda.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, lambda.getSize()));

        this.mu = mu;
        addVariable(mu);
        mu.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, mu.getSize()));

        this.p = p;
        addVariable(p);
        p.addBounds(new Parameter.DefaultBounds(1.0, 0.0, p.getSize()));

        this.origin = origin;
        addVariable(origin);
        // These bounds belong to origin; they were previously attached to p by mistake.
        origin.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, origin.getSize()));

        this.psi = psi;
        addVariable(psi);
        psi.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, psi.getSize()));

        this.relativeDeath = relativeDeath;
    }

    /**
     * Rejects a rate parameter whose dimension is neither one nor the number of intervals.
     */
    private static void checkRateDimension(String name, Variable<Double> rate, int size) {
        if (rate.getSize() != 1 && rate.getSize() != size) {
            throw new RuntimeException("Length of " + name
                    + " parameter should be one or equal to the size of time parameter (size = " + size + ")");
        }
    }

    /**
     * @param time the time
     * @param tree the tree
     * @return the number of lineages that exist at the given time in the given tree.
     */
    public int lineageCountAtTime(double time, Tree tree) {

        // Start from one lineage; every internal node above the time adds a lineage and every
        // external node above the time removes one.
        int count = 1;
        for (int i = 0; i < tree.getInternalNodeCount(); i++) {
            if (tree.getNodeHeight(tree.getInternalNode(i)) > time) {
                count += 1;
            }
        }
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            if (tree.getNodeHeight(tree.getExternalNode(i)) > time) {
                count -= 1;
            }
        }
        return count;
    }

    /**
     * @param b   birth rate
     * @param g   death rate
     * @param psi serial sampling rate
     * @return sqrt((b - g - psi)^2 + 4 b psi)
     */
    public double Ai(double b, double g, double psi) {
        return Math.sqrt((b - g - psi) * (b - g - psi) + 4.0 * b * psi);
    }

    /**
     * @param b   birth rate
     * @param g   death rate
     * @param psi serial sampling rate
     * @param A   the value of {@link #Ai(double, double, double)} for the same rates
     * @param p0  the p0 value carried over from the preceding interval (1 for the first interval,
     *            see {@link #preCalculation(Tree)})
     * @return -((1 - 2 p0) b + g + psi) / A
     */
    public double Bi(double b, double g, double psi, double A, double p0) {
        return (-((1.0 - 2.0 * p0) * b + g + psi) / A);
    }

    /**
     * Evaluates p0 using the rates and the cached Ai/Bi values of interval {@code index}.
     * {@link #preCalculation(Tree)} must have been called first.
     *
     * @param index interval whose rates and cached constants are used
     * @param t     time at which to evaluate
     * @param ti    start time of the interval
     * @return the value of {@link #p0(double, double, double, double, double, double, double)}
     */
    public double p0(int index, double t, double ti) {
        return p0(birth(birthChanges ? index : 0), death(deathChanges ? index : 0),
                psi(samplingChanges ? index : 0), Ai[index], Bi[index], t, ti);
    }

    /**
     * @return (b + g + psi - A * (e^(A(t-ti)) (1-B) - (1+B)) / (e^(A(t-ti)) (1-B) + (1+B))) / (2b)
     */
    public double p0(double b, double g, double psi, double A, double B, double t, double ti) {
        final double growth = Math.exp(A * (t - ti));
        final double ratioNumerator = growth * (1.0 - B) - (1.0 + B);
        final double ratioDenominator = growth * (1.0 - B) + (1.0 + B);
        return (b + g + psi - A * ratioNumerator / ratioDenominator) / (2.0 * b);
    }

    /**
     * Evaluates 4 / (2(1 - B^2) + e^(A(t-ti)) (1-B)^2 + e^(-A(t-ti)) (1+B)^2) with the cached
     * A and B of interval {@code index}. {@link #preCalculation(Tree)} must have been called first.
     *
     * @param index interval whose cached Ai and Bi are used
     * @param t     time at which to evaluate
     * @param ti    start time of the interval
     */
    public double g(int index, double t, double ti) {

        double oneMinusBiSq = (1.0 - Bi[index]) * (1.0 - Bi[index]);
        double onePlusBiSq = (1.0 + Bi[index]) * (1.0 + Bi[index]);

        return 4.0 / (2.0 * (1.0 - Bi[index] * Bi[index]) + Math.exp(Ai[index] * (t - ti)) *
                oneMinusBiSq + Math.exp(-Ai[index] * (t - ti)) * onePlusBiSq);
    }

    /**
     * Returns the time at which epoch i begins, as stored by {@link #preCalculation(Tree)}.
     *
     * @param i index of the epoch
     * @return the time at which this epoch begins
     */
    public double t(int i) {
        return timesFromTips[i];
    }

    /**
     * @param i index into the lambda parameter
     * @return the birth rate stored at that index
     */
    public double birth(int i) {
        return lambda.getValue(i);
    }

    /**
     * @param i index into the mu parameter
     * @return the death rate: mu at that index, multiplied by the birth rate when relativeDeath is set
     */
    public double death(int i) {
        if (!relativeDeath) {
            return mu.getValue(i);
        }
        // A constant birth rate only has an entry at index 0.
        return mu.getValue(i) * birth(birthChanges ? i : 0);
    }

    /**
     * @param i index into the psi parameter
     * @return the serial sampling rate stored at that index
     */
    public double psi(int i) {
        return psi.getValue(i);
    }

    /**
     * @return the first value of the extant sampling proportion parameter
     */
    public double p() {
        return p.getValue(0);
    }

    /**
     * @param t a time on the scale of {@link #t(int)}
     * @return the birth parameter for the given time
     */
    public double lambda(double t) {
        return lambda.getValue(birthChanges ? index(t) : 0);
    }

    /**
     * @param t a time on the scale of {@link #t(int)}
     * @return the value of the mu parameter for the given time; unlike {@link #death(int)} this is
     *         not multiplied by the birth rate when relativeDeath is set
     */
    public double mu(double t) {
        return mu.getValue(deathChanges ? index(t) : 0);
    }

    /**
     * Finds the interval containing the given time by binary search over the interval start
     * times, which therefore have to be in ascending order. A time equal to an interval start is
     * assigned to the interval below it; times at or below the first start map to interval 0.
     * {@link #preCalculation(Tree)} must have been called first.
     *
     * @param t the time
     * @return the index of the interval
     */
    public int index(double t) {

        int epoch = Arrays.binarySearch(timesFromTips, t);

        if (epoch < 0) {
            // not found: binarySearch returns (-(insertion point) - 1)
            epoch = -epoch - 1;
        }

        return Math.max(epoch - 1, 0);
    }

    /* calculate and store Ai, Bi and p0_iMinus1 */
    public void preCalculation(Tree tree) {

        t_root = tree.getNodeHeight(tree.getRoot());
        x0 = t_root + origin.getValue(0);

        // set up timesFromTips array
        if (timesFromTips == null) {
            timesFromTips = new double[times.getSize()];
        }

        if (timesStartFromOrigin) {
            // Entries of times are read in reverse order and subtracted from x0, clamped at 0;
            // times.getValue(0) is not used on this path.
            timesFromTips[0] = 0;
            for (int i = 1; i < timesFromTips.length; i++) {
                timesFromTips[i] = Math.max(0, x0 - times.getValue(timesFromTips.length - i));
            }
        } else {
            for (int i = 0; i < timesFromTips.length; i++) {
                timesFromTips[i] = times.getValue(i);
            }
        }

        Ai = new double[size];
        Bi = new double[size];
        p0_iMinus1 = new double[size];

        for (int i = 0; i < size; i++) {
            Ai[i] = Ai(birth(birthChanges ? i : 0), death(deathChanges ? i : 0), psi(samplingChanges ? i : 0));
        }

        // the first interval is seeded with p0 = 1
        Bi[0] = Bi(birth(0), death(0), psi(0), Ai[0], 1);

        // each later interval needs p0 of the previous interval evaluated at its own start time
        for (int i = 1; i < size; i++) {
            p0_iMinus1[i - 1] = p0(birth(birthChanges ? (i - 1) : 0), death(deathChanges ? (i - 1) : 0),
                    psi(samplingChanges ? (i - 1) : 0), Ai[i - 1], Bi[i - 1], t(i), t(i - 1));
            Bi[i] = Bi(birth(birthChanges ? i : 0), death(deathChanges ? i : 0),
                    psi(samplingChanges ? i : 0), Ai[i], p0_iMinus1[i - 1]);
        }
    }

    /**
     * Generic likelihood calculation
     *
     * @param tree the tree to calculate likelihood of
     * @return log-likelihood of density
     */
    public final double calculateTreeLogLikelihood(Tree tree) {

        int nTips = tree.getExternalNodeCount();

        preCalculation(tree);

        int lastInterval = size - 1; // x0 must be in last interval
        double logP = Math.log(g(lastInterval, x0, t(lastInterval)));

        // first product term in f[T]
        for (int i = 0; i < tree.getInternalNodeCount(); i++) {
            double x = tree.getNodeHeight(tree.getInternalNode(i));
            int index = index(x);
            logP += Math.log(birth(birthChanges ? index : 0) * g(index, x, t(index)));
        }

        // middle product term in f[T]
        for (int i = 0; i < nTips; i++) {
            double y = tree.getNodeHeight(tree.getExternalNode(i));
            int index = index(y);
            logP += Math.log(psi(samplingChanges ? index : 0)) - Math.log(g(index, y, t(index)));
        }

        // last product term in f[T], factorizing from 1 to m
        for (int j = 0; j < size - 1; j++) {
            double time = t(j + 1);
            // number of lineages crossing the boundary between interval j and interval j + 1
            int lineages = lineageCountAtTime(time, tree);
            if (lineages > 0) {
                logP += lineages * Math.log(g(j, time, t(j)));
            }
        }

        return logP;
    }

    /**
     * Likelihood with excluded taxa; only an empty exclusion set is supported.
     *
     * @param tree    the tree to calculate likelihood of
     * @param exclude taxa to exclude, must be empty
     * @return log-likelihood of density
     * @throws RuntimeException if {@code exclude} is not empty
     */
    public double calculateTreeLogLikelihood(Tree tree, Set<Taxon> exclude) {
        if (exclude.size() == 0) {
            return calculateTreeLogLikelihood(tree);
        }
        throw new RuntimeException("Not implemented!");
    }

    public static void main(String[] args) throws IOException, Importer.ImportException {

        // test framework

        Variable<Double> times = new Variable.D(1, 10);
        Variable<Double> mu = new Variable.D(1, 10);

        for (int i = 0; i < mu.getSize(); i++) {
            times.setValue(i, (i + 1) * 2.0);
            mu.setValue(i, i + 1.0);
        }

        Variable<Double> lambda = new Variable.D(1, 10);
        Variable<Double> psi = new Variable.D(0.5, 1);
        Variable<Double> p = new Variable.D(0.5, 1);
        Variable<Double> origin = new Variable.D(0.5, 1);
        boolean relativeDeath = false;
        boolean sampledIndividualsRemainInfectious = false;
        boolean timesStartFromOrigin = false;

        BirthDeathSerialSkylineModel model =
                new BirthDeathSerialSkylineModel(times, lambda, mu, psi, p, origin, relativeDeath,
                        sampledIndividualsRemainInfectious, timesStartFromOrigin, Type.SUBSTITUTIONS);

        NewickImporter importer = new NewickImporter("((A:6,B:5):4,(C:3,D:2):1);");
        Tree tree = importer.importNextTree();

        model.calculateTreeLogLikelihood(tree);

        for (int i = 0; i < times.getSize(); i += 1) {
            System.out.println("mu at time " + i + " is " + model.mu(i));
            System.out.println("p0 at time " + i + " is " + model.p0(0, i, i));
        }
    }
}