/*
* SkylineReconstructor.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.app.tools;
import dr.inference.trace.LogFileTraces;
import dr.inference.trace.TraceDistribution;
import dr.inference.trace.TraceException;
import dr.inference.trace.TraceList;
import dr.stats.Variate;
import jebl.evolution.coalescent.IntervalList;
import jebl.evolution.coalescent.Intervals;
import jebl.evolution.io.ImportException;
import jebl.evolution.io.NewickImporter;
import jebl.evolution.io.NexusImporter;
import jebl.evolution.io.TreeImporter;
import jebl.evolution.trees.RootedTree;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* @author Andrew Rambaut
* @version $Id$
*/
public class SkylineReconstructor {
private int binCount;
private double minTime;
private double maxTime;
private double ageOfYoungest;
private Variate xData = new Variate.D();
private Variate yDataMean = new Variate.D();
private Variate yDataMedian = new Variate.D();
private Variate yDataUpper = new Variate.D();
private Variate yDataLower = new Variate.D();
public SkylineReconstructor(File logFile, File treeFile, int burnin,
int binCount, double minTime, double maxTime, double ageOfYoungest)
throws IOException, ImportException, TraceException {
this.binCount = binCount;
this.minTime = minTime;
this.maxTime = maxTime;
this.ageOfYoungest = ageOfYoungest;
LogFileTraces traces = new LogFileTraces(logFile.getName(), logFile);
traces.loadTraces();
traces.setBurnIn(burnin);
int stateCount = traces.getStateCount();
int firstPopSize = findArgument(traces, "popSize");
int popSizeCount = getTraceRange(traces, firstPopSize);
int firstGroupSize = findArgument(traces, "groupSize");
int groupSizeCount = getTraceRange(traces, firstGroupSize);
boolean isLinear = (groupSizeCount == popSizeCount - 1);
if (!isLinear && groupSizeCount != popSizeCount) {
if (isLinear) {
if (groupSizeCount != popSizeCount - 1) {
throw new TraceException("For the stepwise (constant) Bayesian skyline model there should " +
"be the same number of group size as population size parameters.");
}
}
}
ArrayList<ArrayList> popSizes = new ArrayList<ArrayList>();
ArrayList<ArrayList> groupSizes = new ArrayList<ArrayList>();
for (int i = 0; i < popSizeCount; i++) {
popSizes.add(new ArrayList(traces.getValues(firstPopSize + i)));
}
for (int i = 0; i < groupSizeCount; i++) {
groupSizes.add(new ArrayList(traces.getValues(firstGroupSize + i)));
}
List heights = traces.getValues(traces.getTraceIndex("treeModel.rootHeight"));
TraceDistribution distribution = new TraceDistribution(heights,
traces.getTrace(traces.getTraceIndex("treeModel.rootHeight")).getTraceType());
double timeMean = distribution.getMean();
double timeMedian = distribution.getMedian();
double timeUpper = distribution.getUpperHPD();
double timeLower = distribution.getLowerHPD();
// double maxHeight = timeLower;
double maxHeight = maxTime;
// switch () {
// // setting a timeXXXX to -1 means that it won't be displayed...
// case 0:
// maxHeight = timeLower;
// break;
// case 1:
// maxHeight = timeMedian;
// break;
// case 2:
// maxHeight = timeMean;
// break;
// case 3:
// maxHeight = timeUpper;
// break;
// }
BufferedReader reader = new BufferedReader(new FileReader(treeFile));
String line = reader.readLine();
TreeImporter importer;
if (line.toUpperCase().startsWith("#NEXUS")) {
importer = new NexusImporter(reader);
} else {
importer = new NewickImporter(reader, false);
}
double delta = maxHeight / (binCount - 1);
int skip = (int) (burnin / traces.getStepSize());
int state = 0;
while (importer.hasTree() && state < skip) {
importer.importNextTree();
state += 1;
}
// the age of the end of this group
double[][] groupTimes = new double[stateCount][];
//int tips = 0;
state = 0;
while (importer.hasTree()) {
RootedTree tree = (RootedTree) importer.importNextTree();
IntervalList intervals = new Intervals(tree);
int intervalCount = intervals.getIntervalCount();
//tips = tree.getExternalNodes().size();
// get the coalescent intervales only
groupTimes[state] = new double[groupSizeCount];
double totalTime = 0.0;
int groupSize = 1;
int groupIndex = 0;
int subIndex = 0;
if (firstGroupSize > 0) {
double g = (Double) groupSizes.get(groupIndex).get(state);
if (g != Math.round(g)) {
throw new RuntimeException("Group size " + groupIndex + " should be integer but found:" + g);
} else groupSize = (int) Math.round(g);
}
for (int j = 0; j < intervalCount; j++) {
totalTime += intervals.getInterval(j);
if (intervals.getIntervalType(j) == IntervalList.IntervalType.COALESCENT) {
subIndex += 1;
if (subIndex == groupSize) {
groupTimes[state][groupIndex] = totalTime;
subIndex = 0;
groupIndex += 1;
if (groupIndex < groupSizeCount) {
double g = (Double) groupSizes.get(groupIndex).get(state);
if (g != Math.round(g)) {
throw new RuntimeException("Group size " + groupIndex + " should be integer but found:" + g);
} else groupSize = (int) Math.round(g);
}
}
}
// insert zero-length coalescent intervals
int diff = intervals.getCoalescentEvents(j) - 1;
if (diff > 0)
throw new RuntimeException("Don't handle multifurcations!");
}
state += 1;
}
Variate[] bins = new Variate[binCount];
double height = 0.0;
for (int k = 0; k < binCount; k++) {
bins[k] = new Variate.D();
if (height >= 0.0 && height <= maxHeight) {
for (state = 0; state < stateCount; state++) {
if (isLinear) {
double lastGroupTime = 0.0;
int index = 0;
while (index < groupTimes[state].length && groupTimes[state][index] < height) {
lastGroupTime = groupTimes[state][index];
index += 1;
}
if (index < groupTimes[state].length - 1) {
double t = (height - lastGroupTime) / (groupTimes[state][index] - lastGroupTime);
double p1 = (Double) groupSizes.get(index).get(state);
double p2 = (Double) groupSizes.get(index + 1).get(state);
double popsize = p1 + ((p2 - p1) * t);
bins[k].add(popsize);
}
} else {
int index = 0;
while (index < groupTimes[state].length && groupTimes[state][index] < height) {
index += 1;
}
if (index < groupTimes[state].length) {
double popSize = (Double) groupSizes.get(index).get(state);
if (popSize == 0.0) {
throw new RuntimeException("Zero pop size");
}
bins[k].add(popSize);
} else {
// Do we really want to do this?
// bins[k].add(getPopSize(popSizeCount - 1,state));
}
}
}
}
height += delta;
}
double t = 0.0;
for (Variate bin : bins) {
xData.add(t);
if (bin.getCount() > 0) {
yDataMean.add(bin.getMean());
yDataMedian.add(bin.getQuantile(0.5));
yDataLower.add(bin.getQuantile(0.025));
yDataUpper.add(bin.getQuantile(0.975));
} else {
yDataMean.add(Double.NaN);
yDataMedian.add(Double.NaN);
yDataLower.add(Double.NaN);
yDataUpper.add(Double.NaN);
}
t += delta;
}
}
public Variate getXData() {
return xData;
}
public Variate getYDataMean() {
return yDataMean;
}
public Variate getYDataUpper() {
return yDataUpper;
}
public Variate getYDataMedian() {
return yDataMedian;
}
public Variate getYDataLower() {
return yDataLower;
}
private int findArgument(TraceList traceList, String argument) {
for (int j = 0; j < traceList.getTraceCount(); j++) {
String statistic = traceList.getTraceName(j);
String suffix = getNumericalSuffix(statistic);
if ((suffix.length() == 0 || suffix.equals("1")) && statistic.substring(0, statistic.length() - 1).contains(argument)) {
return j;
}
}
return -1;
}
private String getNumericalSuffix(String argument) {
int i = argument.length() - 1;
if (i < 0) return "";
char ch = argument.charAt(i);
if (!Character.isDigit(ch)) return "";
while (i > 0 && Character.isDigit(ch)) {
i -= 1;
ch = argument.charAt(i);
}
return argument.substring(i + 1, argument.length());
}
private int getTraceRange(TraceList traceList, int first) {
int i = 1;
int k = first;
String name = traceList.getTraceName(first);
String root = name.substring(0, name.length() - 1);
while (k < traceList.getTraceCount() && traceList.getTraceName(k).equals(root + i)) {
i++;
k++;
}
return i - 1;
}
public static void main(String[] argv) {
Variate x = null;
List<Variate> plots = new ArrayList<Variate>();
for (int i = 1; i <= 200; i++) {
String stem = "sim" + (i < 10 ? "00" : (i < 100 ? "0" : "")) + i;
try {
SkylineReconstructor skyline = new SkylineReconstructor(
new File(stem + ".log"),
new File(stem + ".trees"),
1000000,
200,
0.0,
150000,
0.0
);
if (x == null) {
x = skyline.getXData();
}
plots.add(skyline.getYDataMean());
} catch (IOException e) {
e.printStackTrace();
} catch (ImportException e) {
e.printStackTrace();
} catch (TraceException e) {
e.printStackTrace();
}
if (i % 10 == 0) {
System.err.println("Read " + i);
}
}
for (int i = 0; i < x.getCount(); i++) {
System.out.print(x.get(i));
for (Variate y : plots) {
System.out.print("\t" + y.get(i));
}
System.out.println();
}
}
}