/*
* VDdemographicFunction.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.coalescent;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervals;
import dr.evolution.tree.Tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * A piecewise demographic function whose candidate change-points are the coalescent times of one
 * or more trees.
 * <p>
 * The coalescent times of all trees are merged into a single increasing array
 * ({@link #allTimePoints()}); an indicator vector then selects which of those candidate times act
 * as change-points. Between change-points the population size is constant ({@code STEPWISE}),
 * linearly interpolated ({@code LINEAR}), or linearly interpolated in log space
 * ({@code EXPONENTIAL}; for this type the stored values are log population sizes). After the last
 * change-point the population size stays constant.
 *
 * @author Joseph Heled
 */
public class VDdemographicFunction extends DemographicFunction.Abstract {

    /** Population size of each piece (log size for EXPONENTIAL); one more entry than change-points. */
    private double[] values;
    /** Piece boundaries: times[0] == 0 and the last entry is a +infinity sentinel. */
    private double[] times;
    /** Lengths of the finite pieces: intervals[n] == times[n + 1] - times[n]. */
    private double[] intervals;
    /** Per tree, cumulative time of each coalescent event (null after freeze() or for direct construction). */
    private double[][] ttimes;
    /** Coalescent times of all trees merged in increasing order (null after freeze() or for direct construction). */
    private double[] alltimes;
    /** Per tree, whether its coalescent times must be recomputed (null after freeze() or for direct construction). */
    private boolean[] dirtyTrees;
    /** Whether setup() has to rebuild the piecewise function. */
    boolean dirty;
    /** Interpolation scheme between change-points. */
    private final VariableDemographicModel.Type type;
    /** Cached intervals of each tree (null after freeze() or for direct construction). */
    TreeIntervals[] ti;

    /**
     * Builds the function from the coalescent times of {@code trees}.
     *
     * @param trees              trees providing the candidate change-points; units come from the first tree
     * @param type               interpolation scheme between change-points
     * @param indicatorParameter one entry per candidate change-point; an entry &gt; 0 activates it
     * @param popSizeParameter   population sizes: entry 0 for the first piece, entry k + 1 for the
     *                           piece starting at candidate k
     * @param logSpace           whether population sizes are given as logs (required for EXPONENTIAL)
     * @param mid                place each change-point midway between its coalescent time and the
     *                           previous one (or 0) instead of at the coalescent time itself
     */
    public VDdemographicFunction(Tree[] trees, VariableDemographicModel.Type type,
                                 double[] indicatorParameter, double[] popSizeParameter, boolean logSpace,
                                 boolean mid) {
        super(trees[0].getUnits());
        this.type = type;

        ti = new TreeIntervals[trees.length];
        dirtyTrees = new boolean[trees.length];
        Arrays.fill(dirtyTrees, true);

        // a tree with n taxa has n - 1 coalescent events
        ttimes = new double[ti.length][];
        int tot = 0;
        for (int k = 0; k < ti.length; ++k) {
            ttimes[k] = new double[trees[k].getTaxonCount() - 1];
            tot += ttimes[k].length;
        }
        alltimes = new double[tot];

        setDirty();
        assert !(type == VariableDemographicModel.Type.EXPONENTIAL && !logSpace);
        setup(trees, indicatorParameter, popSizeParameter, logSpace, mid);
    }

    /**
     * Reduce memory footprint of object. After a call to freeze only population/intensity
     * are allowed.
     */
    public void freeze() {
        ttimes = null;
        alltimes = null;
        dirtyTrees = null;
        ti = null;
    }

    /**
     * Copy constructor. Numeric arrays are deep-copied; the {@link TreeIntervals} objects are
     * shared with {@code demoFunction} (only the array holding them is copied).
     * <p>
     * Works for frozen functions and for functions built directly from times and values, whose
     * tree-related state is {@code null}; that state stays {@code null} in the copy.
     *
     * @param demoFunction function to copy
     */
    public VDdemographicFunction(VDdemographicFunction demoFunction) {
        super(demoFunction.getUnits());
        type = demoFunction.type;

        this.values = demoFunction.values.clone();
        this.times = demoFunction.times.clone();
        this.intervals = demoFunction.intervals.clone();
        this.dirty = demoFunction.dirty;

        // tree-related state may be absent (freeze() or direct construction)
        this.ti = demoFunction.ti == null ? null : demoFunction.ti.clone();
        if (demoFunction.ttimes != null) {
            this.ttimes = demoFunction.ttimes.clone();
            for (int k = 0; k < ttimes.length; ++k) {
                ttimes[k] = ttimes[k].clone();
            }
        }
        this.alltimes = demoFunction.alltimes == null ? null : demoFunction.alltimes.clone();
        this.dirtyTrees = demoFunction.dirtyTrees == null ? null : demoFunction.dirtyTrees.clone();
    }

    // Hack so that VDdemo can be used as just a linear piecewise demography (the BEAST one is broken)
    // Alexei fixed PiecewiseLinearPopulation, but did not say yet if it is tested or not.

    /**
     * Builds a piecewise-linear function directly from change-point times and population sizes.
     *
     * @param t     change-point times
     * @param p     population sizes
     * @param units time units
     */
    public VDdemographicFunction(double[] t, double[] p, Type units) {
        this(t, p, units, VariableDemographicModel.Type.LINEAR);
    }

    /**
     * Builds a piecewise function directly from change-point times and population sizes.
     * <p>
     * Expected (but not checked) shape: {@code t} non-decreasing and non-negative, and
     * {@code t.length + 1 == p.length}. Functions built this way have no tree state, so
     * {@link #treeChanged(int)}, {@code setup} and {@link #getTreeIntervals(int)} cannot be used.
     *
     * @param t     change-point times
     * @param p     population size of each piece (log sizes for EXPONENTIAL); copied
     * @param units time units
     * @param type  interpolation scheme between change-points
     */
    public VDdemographicFunction(double[] t, double[] p, Type units, VariableDemographicModel.Type type) {
        super(units);
        this.type = type;

        final int tot = p.length;
        times = new double[tot + 1];
        // copy so later changes to the caller's array do not alter this function
        values = p.clone();
        intervals = new double[tot - 1];

        times[0] = 0.0;
        times[tot] = Double.POSITIVE_INFINITY;
        System.arraycopy(t, 0, times, 1, t.length);

        for (int n = 0; n < intervals.length; ++n) {
            intervals[n] = times[n + 1] - times[n];
        }
        dirty = false;
    }

    /**
     * @return number of active change-points (pieces minus one)
     */
    public int numberOfChanges() {
        return values.length - 1;
    }

    /**
     * Marks tree {@code nt} as changed so its coalescent times are recomputed on the next setup.
     *
     * @param nt index of the changed tree
     */
    public void treeChanged(int nt) {
        dirtyTrees[nt] = true;
        setDirty();
    }

    /**
     * Forces the next call to {@code setup} to rebuild the function.
     */
    public void setDirty() {
        dirty = true;
    }

    /**
     * Recomputes the cumulative coalescent times of tree {@code nt} if it is marked dirty.
     *
     * @param nt    tree index
     * @param trees all trees
     * @return true if the times were recomputed, false if the tree was clean
     */
    private boolean setTreeTimes(int nt, Tree[] trees) {
        if (dirtyTrees[nt]) {
            ti[nt] = new TreeIntervals(trees[nt]);

            TreeIntervals nti = ti[nt];
            // make sure we get each coalescent event individually
            nti.setMultifurcationLimit(0);

            // code probably incorrect for serial samples
            final int nLineages = nti.getIntervalCount();
            assert nLineages >= ttimes[nt].length : nLineages + " " + ttimes[nt].length;

            int iCount = 0;
            for (int k = 0; k < ttimes[nt].length; ++k) {
                // accumulate interval lengths up to and including the next coalescent interval
                double timeToCoal = nti.getInterval(iCount);
                while (nti.getIntervalType(iCount) != IntervalType.COALESCENT) {
                    ++iCount;
                    timeToCoal += nti.getInterval(iCount);
                }

                int linAtStart = nti.getLineageCount(iCount);
                ++iCount;

                assert !(iCount == nLineages && linAtStart != 2);

                int linAtEnd = (iCount == nLineages) ? 1 : nti.getLineageCount(iCount);

                // NOTE(review): the lineage count did not drop after this interval; presumably
                // another event coincides with it, so keep absorbing intervals until it does.
                while (linAtStart <= linAtEnd) {
                    ++iCount;
                    timeToCoal += nti.getInterval(iCount);

                    linAtStart = linAtEnd;
                    ++iCount;
                    linAtEnd = nti.getLineageCount(iCount);
                }
                // store cumulative time
                ttimes[nt][k] = timeToCoal + (k == 0 ? 0 : ttimes[nt][k - 1]);
            }

            dirtyTrees[nt] = false;
            return true;
        }
        return false;
    }

    /**
     * Rebuilds {@code times}, {@code values} and {@code intervals} if the function is dirty.
     *
     * @param trees              all trees
     * @param indicatorParameter one entry per candidate change-point; an entry &gt; 0 activates it
     * @param popSizes           population sizes: entry 0 for the first piece, entry k + 1 for the
     *                           piece starting at candidate k
     * @param logSpace           whether population sizes are logs (ignored for EXPONENTIAL, which
     *                           keeps log values and exponentiates on evaluation)
     * @param mid                place change-points midway between consecutive coalescent times
     */
    void setup(Tree[] trees, double[] indicatorParameter, double[] popSizes, boolean logSpace, boolean mid) {
        if (dirty) {
            // for exponential we do the exp in the code
            if (type == VariableDemographicModel.Type.EXPONENTIAL) logSpace = false;

            boolean any = false;
            for (int nt = 0; nt < ti.length; ++nt) {
                if (setTreeTimes(nt, trees)) {
                    any = true;
                }
            }

            final int nd = indicatorParameter.length;

            assert nd == alltimes.length + (type == VariableDemographicModel.Type.STEPWISE ? -1 : 0) :
                    " nd=" + nd + " alltimes.length=" + alltimes.length + " type=" + type;

            if (any) {
                // k-way merge of the per-tree time arrays into alltimes
                int[] inds = new int[ttimes.length];

                for (int k = 0; k < alltimes.length; ++k) {
                    // first tree that still has unmerged times
                    int j = 0;
                    while (inds[j] == ttimes[j].length) {
                        ++j;
                    }

                    // tree whose next unmerged time is smallest
                    for (int l = j + 1; l < inds.length; ++l) {
                        if (inds[l] < ttimes[l].length) {
                            if (ttimes[l][inds[l]] < ttimes[j][inds[j]]) {
                                j = l;
                            }
                        }
                    }
                    alltimes[k] = ttimes[j][inds[j]];
                    inds[j]++;
                }
            }

            // assumes lowest node has time 0. this is probably problematic when we come
            // to deal with multiple trees

            int tot = 1;
            for (int k = 0; k < nd; ++k) {
                if (indicatorParameter[k] > 0) {
                    ++tot;
                }
            }

            times = new double[tot + 1];
            values = new double[tot];
            intervals = new double[tot - 1];

            times[0] = 0.0;
            times[tot] = Double.POSITIVE_INFINITY;

            // the least-squares fit path is disabled: xx is always false
            final boolean xx = type == VariableDemographicModel.Type.LINEAR && !logSpace && false;
            if (xx) {
                double[] a = alltimes;
                if (mid) {
                    a = new double[alltimes.length];
                    for (int k = 0; k < a.length; ++k) {
                        a[k] = ((alltimes[k] + (k > 0 ? alltimes[k - 1] : 0)) / 2);
                    }
                }
                bestLinearFit(a, popSizes, indicatorParameter, times, values);
                for (int n = 0; n < intervals.length; ++n) {
                    intervals[n] = times[n + 1] - times[n];
                }
                for (int n = 0; n < values.length; ++n) {
                    if (values[n] <= 0) {
                        values[n] = 1e-30;
                    }
                }
            }

            if (!xx) {
                values[0] = logSpace ? Math.exp(popSizes[0]) : popSizes[0];

                int n = 0;
                for (int k = 0; k < nd && n + 1 < tot; ++k) {
                    if (indicatorParameter[k] > 0) {
                        times[n + 1] = mid ? ((alltimes[k] + (k > 0 ? alltimes[k - 1] : 0)) / 2) : alltimes[k];
                        values[n + 1] = logSpace ? Math.exp(popSizes[k + 1]) : popSizes[k + 1];
                        intervals[n] = times[n + 1] - times[n];
                        ++n;
                    }
                }
            }
            dirty = false;
        }
    }

    /**
     * Maps matrix entry (i, j) to an index in the packed coefficient array used by
     * {@link #bestLinearFit}.
     */
    private int ti2f(int i, int j) {
        return (i == 0) ? j : 2 * i + j + 1;
    }

    /**
     * Least-squares piecewise-linear fit of {@code ys} with knots at the selected points of
     * {@code xs}. Results are written into {@code ot} (knot times; first and last entries are
     * left untouched) and {@code oz} (values at the knots).
     * <p>
     * Currently unused: its only call site in {@code setup} is disabled.
     */
    private void
    bestLinearFit(double[] xs, double[] ys, double[] use, double[] ot, double[] oz) {
        assert (xs.length + 1) == ys.length;
        assert ys.length == use.length + 2 || ys.length == use.length + 1;
        int N = ys.length;
        if (N == 2) {
            // cheaper: two points are fitted exactly
            assert xs.length == ot.length;
            assert ys.length == oz.length;
            System.arraycopy(xs, 0, ot, 0, xs.length);
            System.arraycopy(ys, 0, oz, 0, ys.length);
            // without this return the general path below would overwrite the outputs
            return;
        }

        List<Integer> iv = new ArrayList<Integer>(2);
        iv.add(0);
        for (int k = 0; k < N - 1; ++k) {
            if (use[k] > 0) {
                iv.add(k + 1);
            }
        }

        double[] ati = new double[xs.length + 1];
        ati[0] = 0.0;
        System.arraycopy(xs, 0, ati, 1, xs.length);

        int n = iv.size();

        double[] a = new double[3 * n];
        double[] v = new double[n];

        // accumulate the normal equations, one linear segment at a time
        for (int k = 0; k < n - 1; ++k) {
            int i0 = iv.get(k);
            int i1 = iv.get(k + 1);

            double u0 = ati[i0];
            double u1 = ati[i1] - ati[i0];

            final int l = ti2f(k, k);
            final int l1 = ti2f(k + 1, k);

            for (int j = i0; j < i1; ++j) {
                double t = ati[j];
                double y = ys[j];
                double z = (t - u0) / u1;
                v[k] += y * (1 - z);
                a[l] += (1 - z) * (1 - z);
                a[l + 1] += z * (1 - z);
                a[l1] += z * (1 - z);
                a[l1 + 1] += z * z;
                v[k + 1] += y * z;
            }
        }

        // points after the last knot are fitted with a constant
        {
            int k = n - 1;
            int i0 = iv.get(k);

            int i1 = ys.length;
            final int l = ti2f(k, k);
            for (int j = i0; j < i1; ++j) {
                a[l] += 1;
                v[k] += ys[j];
            }
        }

        // forward elimination
        for (int k = 0; k < n - 1; ++k) {
            final double r = a[ti2f(k + 1, k)] / a[ti2f(k, k)];
            for (int j = k; j < k + 3; ++j) {
                a[ti2f((k + 1), j)] -= a[ti2f(k, j)] * r;
            }
            v[k + 1] -= v[k] * r;
        }

        if (oz.length != n) {
            n = 3;
        }
        assert oz.length == n;

        // back substitution
        for (int k = n - 1; k > 0; --k) {
            oz[k] = v[k] / a[ti2f(k, k)];
            v[k - 1] -= a[ti2f((k - 1), k)] * oz[k];
        }

        oz[0] = v[0] / a[ti2f(0, 0)];

        // first and last in ot are reserved
        assert ot.length - 2 == iv.size() - 1;
        for (int j = 1; j < ot.length - 1; ++j) {
            ot[j] = ati[iv.get(j)];
        }
    }

    /**
     * Index of the piece containing {@code t} for the stepwise model (boundaries belong to the
     * earlier piece).
     */
    private int getIntervalIndexStep(final double t) {
        int j = 0;
        // ugly hack,
        // when doubles are added in a different order and compared later, they can be a tiny bit off. With a
        // stepwise model this creates a "one off" situation here, which is unpleasant.
        // use float comparison here to avoid it

        final float tf = (float) t;
        while (tf > (float) times[j + 1]) ++j;
        return j;
    }

    /**
     * Index of the piece containing {@code t} for the linear/exponential models (boundaries belong
     * to the earlier piece). Terminates because the last boundary is +infinity.
     */
    private int getIntervalIndexLin(final double t) {
        int j = 0;
        while (t > times[j + 1]) ++j;
        return j;
    }

    /**
     * Linear interpolation of {@code values} at time {@code t}; constant after the last
     * change-point.
     */
    private double linPop(double t) {
        final int j = getIntervalIndexLin(t);
        if (j == values.length - 1) {
            return values[j];
        }

        final double a = (t - times[j]) / (intervals[j]);
        return a * values[j + 1] + (1 - a) * values[j];
    }

    /**
     * Population size at time {@code t}.
     *
     * @throws IllegalArgumentException if the interpolation type is not handled
     */
    public double getDemographic(double t) {
        double p;
        switch (type) {
            case STEPWISE: {
                final int j = getIntervalIndexStep(t);
                p = values[j];
                break;
            }
            case LINEAR: {
                p = linPop(t);
                break;
            }
            case EXPONENTIAL: {
                // values hold log population sizes
                p = Math.exp(linPop(t));
                break;
            }
            default:
                throw new IllegalArgumentException("Unknown demographic type: " + type);
        }
        return p;
    }

    /**
     * Integral of 1 / N(s) over [0, t].
     */
    public double getIntensity(double t) {
        return getIntegral(0, t);
    }

    /**
     * Not supported. Fails an assertion when assertions are enabled; otherwise returns 0.
     */
    public double getInverseIntensity(double x) {
        assert false;
        return 0;
    }

    /**
     * Integral of 1 / N(s) over [start, end], both inside piece {@code index}, for a population
     * that is linear in time within the piece (constant on the last piece).
     */
    private double intensityLinInterval(double start, double end, int index) {
        final double dx = end - start;
        if (dx == 0) {
            return 0;
        }

        final double popStart = values[index];
        final double popDiff = (index < values.length - 1) ? values[index + 1] - popStart : 0.0;
        if (popDiff == 0.0) {
            return dx / popStart;
        }
        final double time0 = times[index];
        final double interval = intervals[index];

        assert (float) start <= (float) (time0 + interval) && start >= time0 && (float) end <= (float) (time0 + interval) && end >= time0;

        // Closed form: dx * log(pop1 / pop0) / (pop1 - pop0), where pop0 and pop1 are the sizes at
        // start and end. The difference and ratio are computed directly from the piece parameters
        // for better (though not perfect) numerical stability.
        final double p1minusp0 = ((end - start) / interval) * popDiff;

        final double v = interval * (popStart / popDiff);
        final double p1overp0 = (v + (end - time0)) / (v + (start - time0));
        if (p1minusp0 == 0.0 || p1overp0 <= 0) {
            // either dx == 0 or is very small (numerical inaccuracy)
            final double pop0 = popStart + ((start - time0) / interval) * popDiff;
            return dx / pop0;
        }

        return dx * Math.log(p1overp0) / p1minusp0;
    }

    /**
     * Integral of 1 / N(s) over the whole (finite) piece {@code index} for a linear population.
     */
    private double intensityLinInterval(int index) {
        final double interval = intervals[index];
        final double pop0 = values[index];
        final double pop1 = values[index + 1];
        if (pop0 == pop1) {
            return interval / pop0;
        }
        return interval * Math.log(pop1 / pop0) / (pop1 - pop0);
    }

    /**
     * Integral of 1 / N(s) over [start, end] inside piece {@code index}, where log N is linear in
     * time within the piece (constant on the last piece).
     */
    private double intensityExpInterval(double start, double end, int index) {
        final double pop0 = values[index];
        if (index == intervals.length) {
            // on last interval
            return (end - start) / Math.exp(pop0);
        }
        final double interval = intervals[index];
        final double pop1 = values[index + 1];
        final double time0 = times[index];
        assert start >= time0 && (float) start <= (float) (time0 + interval) && (float) end <= (float) (time0 + interval) && end >= time0;

        // 1 / N(s) = exp(a * (s - time0) - pop0), with a = (pop0 - pop1) / interval
        final double a = (pop0 - pop1) / interval;
        if (a == 0) {
            return (end - start) / Math.exp(pop0);
        }
        return (Math.exp((end - time0) * a - pop0) - Math.exp((start - time0) * a - pop0)) / a;
    }

    /**
     * Integral of 1 / N(s) over the whole (finite) piece {@code index} for a log-linear population.
     */
    private double intensityExpInterval(int index) {
        final double interval = intervals[index];
        final double pop0 = values[index];
        final double pop1 = values[index + 1];
        final double a = (pop0 - pop1) / interval;
        if (a == 0) {
            return interval / Math.exp(pop0);
        }
        return Math.exp(-pop0) / a * (Math.exp(interval * a) - 1);
    }

    /**
     * Integral of 1 / N(s) over [start, finish], summed piece by piece.
     *
     * @throws IllegalArgumentException if the interpolation type is not handled
     */
    public double getIntegral(double start, double finish) {
        double intensity = 0.0;

        switch (type) {
            case STEPWISE: {
                final int first = getIntervalIndexStep(start);
                final int last = getIntervalIndexStep(finish);

                final double popStart = values[first];
                if (first == last) {
                    intensity = (finish - start) / popStart;
                } else {
                    intensity = (times[first + 1] - start) / popStart;

                    for (int k = first + 1; k < last; ++k) {
                        intensity += intervals[k] / values[k];
                    }
                    intensity += (finish - times[last]) / values[last];
                }
                break;
            }
            case LINEAR: {
                final int first = getIntervalIndexLin(start);
                final int last = getIntervalIndexLin(finish);

                if (first == last) {
                    intensity += intensityLinInterval(start, finish, first);
                } else {
                    // from first to end of interval
                    intensity += intensityLinInterval(start, times[first + 1], first);
                    // intervals until (not including) last
                    for (int k = first + 1; k < last; ++k) {
                        intensity += intensityLinInterval(k);
                    }
                    // last interval
                    intensity += intensityLinInterval(times[last], finish, last);
                }
                break;
            }
            case EXPONENTIAL: {
                final int first = getIntervalIndexLin(start);
                final int last = getIntervalIndexLin(finish);

                if (first == last) {
                    intensity += intensityExpInterval(start, finish, first);
                } else {
                    // from first to end of interval
                    intensity += intensityExpInterval(start, times[first + 1], first);
                    // intervals until (not including) last
                    for (int k = first + 1; k < last; ++k) {
                        intensity += intensityExpInterval(k);
                    }
                    // last interval
                    intensity += intensityExpInterval(times[last], finish, last);
                }
                break;
            }
            default:
                throw new IllegalArgumentException("Unknown demographic type: " + type);
        }
        return intensity;
    }

    /** Not supported. Fails an assertion when assertions are enabled; otherwise returns 0. */
    public int getNumArguments() {
        assert false;
        return 0;
    }

    /** Not supported. Fails an assertion when assertions are enabled; otherwise returns null. */
    public String getArgumentName(int n) {
        assert false;
        return null;
    }

    /** Not supported. Fails an assertion when assertions are enabled; otherwise returns 0. */
    public double getArgument(int n) {
        assert false;
        return 0;
    }

    /** Not supported. Fails an assertion when assertions are enabled; otherwise does nothing. */
    public void setArgument(int n, double value) {
        assert false;
    }

    public double getLowerBound(int n) {
        return 0.0;
    }

    public double getUpperBound(int n) {
        return Double.POSITIVE_INFINITY;
    }

    /**
     * @return a copy of this function made with the copy constructor
     */
    public DemographicFunction getCopy() {
        return new VDdemographicFunction(this);
    }

    public double getThreshold() {
        return 1E-12;
    }

    /**
     * Returns {@code 1 / getDemographic(x)}.
     * NOTE(review): original author was unsure why this is needed here.
     */
    public double value(double x) {
        return 1.0 / getDemographic(x);
    }

    /**
     * @param nt tree index
     * @return the cached intervals of tree {@code nt} (unavailable after {@link #freeze()})
     */
    public TreeIntervals getTreeIntervals(int nt) {
        return ti[nt];
    }

    /**
     * @return the merged coalescent times of all trees; this is the internal array, not a copy
     */
    public double[] allTimePoints() {
        return alltimes;
    }

    /**
     * @return a copy of the change-point times, excluding the leading 0 and trailing +infinity
     */
    public double[] times() {
        // defensive copy
        double[] timesCopy = new double[times.length - 2];
        System.arraycopy(times, 1, timesCopy, 0, times.length - 2);
        return timesCopy;
    }

    /**
     * @return population values transformed depending on type (i.e. exp(value) for Type.EXPONENTIAL)
     */
    public double[] values() {
        double[] valuesCopy = new double[values.length];
        if (type == VariableDemographicModel.Type.EXPONENTIAL) {
            for (int i = 0; i < values.length; i++) {
                valuesCopy[i] = Math.exp(values[i]);
            }
        } else {
            System.arraycopy(values, 0, valuesCopy, 0, values.length);
        }
        return valuesCopy;
    }

    /**
     * @return the last finite change-point time (0 when there is a single piece)
     */
    public double naturalLimit() {
        return times[times.length - 2];
    }
}