/*
* TreeTraceAnalysis.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.tree;
import dr.app.tools.NexusExporter;
import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.io.TreeTrace;
import dr.evolution.tree.*;
import dr.util.FrequencySet;
import dr.util.NumberFormatter;
import jebl.evolution.treemetrics.RobinsonsFouldMetric;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
* @author Alexei Drummond
* @author Andrew Rambaut
* @version $Id: TreeTraceAnalysis.java,v 1.20 2005/06/07 16:28:18 alexei Exp $
*/
public class TreeTraceAnalysis {

    /**
     * Builds the analysis for one or more tree traces and immediately runs {@link #analyze(boolean)}.
     *
     * @param traces  the tree traces to combine; must contain at least one trace
     * @param burnIn  requested burn-in; a negative value, or one not smaller than the smallest
     *                maximum state over all traces, falls back to 10% of the smallest trace
     * @param verbose if true, the burn-in choice and analysis progress are logged to stdout
     * @throws IllegalArgumentException if {@code traces} is empty
     */
    private TreeTraceAnalysis(TreeTrace[] traces, int burnIn, boolean verbose) {
        if (traces.length == 0) {
            throw new IllegalArgumentException("No tree traces to analyze");
        }
        this.traces = traces;

        int minMaxState = Integer.MAX_VALUE;
        for (TreeTrace trace : traces) {
            minMaxState = Math.min(minMaxState, trace.getMaximumState());
        }

        // NOTE(review): burnIn is range-checked against a state count here, but this.burnin is later
        // multiplied by each trace's step size (i.e. used as a sample count) - confirm intended units.
        if (burnIn < 0 || burnIn >= minMaxState) {
            // 10% of the smallest trace, expressed in samples using the first trace's step size.
            this.burnin = minMaxState / (10 * traces[0].getStepSize());
            if (verbose) {
                String reason = burnIn < 0 ? "Default burn-in" : "WARNING: Burn-in larger than total number of states";
                System.out.println(reason + " - using 10% of smallest trace");
            }
        } else {
            this.burnin = burnIn;
        }

        analyze(verbose);
    }

    /**
     * Computes, for every tree in a trace (no burn-in applied), its distance to a fixed target tree
     * using jebl's {@code RobinsonsFouldMetric}.
     *
     * @param treeTrace  the trace whose trees are compared
     * @param targetTree the reference tree
     * @return one distance per tree in the trace, in trace order
     */
    public static double[] getSymmetricTreeDistanceTrace(TreeTrace treeTrace, Tree targetTree) {
        double[] symDistance = new double[treeTrace.getTreeCount(0)];
        RobinsonsFouldMetric metric = new RobinsonsFouldMetric();
        jebl.evolution.trees.RootedTree jreference = TreeUtils.asJeblTree(targetTree);
        for (int i = 0; i < symDistance.length; i++) {
            final jebl.evolution.trees.RootedTree tree = TreeUtils.asJeblTree(treeTrace.getTree(i, 0));
            symDistance[i] = metric.getMetric(jreference, tree);
        }
        return symDistance;
    }

    /**
     * Actually analyzes the trace given the burnin.
     *
     * <p>Every post-burn-in tree of every trace is added to the clade set and, by its unique Newick
     * string, to the tree-topology frequency set. When trees carry a "changed" node attribute, the
     * per-node rate sums and co-change counts are also tabulated and printed to stdout.
     *
     * @param verbose if true then progress is logged to stdout
     */
    void analyze(boolean verbose) {
        if (verbose && traces.length > 1) {
            System.out.println("Combining " + traces.length + " traces.");
        }

        final Tree tree0 = getTree(0);
        final int nodeCount = tree0.getNodeCount();
        double[][] changed = new double[nodeCount][nodeCount];
        double[] rateConditionalOnChange = new double[nodeCount];
        boolean changesFound = false;

        // tree0 (the first post-burn-in tree overall) seeds both frequency sets.
        cladeSet = new CladeSet(tree0);
        treeSet = new FrequencySet<String>();
        treeSet.add(TreeUtils.uniqueNewick(tree0, tree0.getRoot()));

        final int reportRate = 60;
        // tree0 comes from the first trace that has any post-burn-in trees (see getTree); only that
        // trace's tree 0 must be skipped below, every other trace starts at its own tree 0.
        boolean tree0Skipped = false;
        for (TreeTrace trace : traces) {
            final int traceBurnin = burnin * trace.getStepSize();
            final int treeCount = trace.getTreeCount(traceBurnin);
            final double stepSize = treeCount / (double) reportRate;
            int firstTree = 0;
            if (!tree0Skipped && treeCount > 0) {
                firstTree = 1;
                tree0Skipped = true;
            }
            int counter = 1;

            if (verbose) {
                printProgressHeader("Analyzing", treeCount);
            }

            for (int i = firstTree; i < treeCount; i++) {
                Tree tree = trace.getTree(i, traceBurnin);
                if (accumulateChanges(tree, changed, rateConditionalOnChange)) {
                    changesFound = true;
                }
                cladeSet.add(tree);
                treeSet.add(TreeUtils.uniqueNewick(tree, tree.getRoot()));
                if (verbose && i >= (int) Math.round(counter * stepSize) && counter <= reportRate) {
                    System.out.print("*");
                    System.out.flush();
                    counter += 1;
                }
            }

            if (verbose) {
                System.out.println("*");
            }
        }

        if (changesFound) {
            for (int j = 0; j < nodeCount; j++) {
                System.out.println(j + "\t" + rateConditionalOnChange[j]);
            }
            System.out.println();
            for (int j = 0; j < nodeCount; j++) {
                for (int k = 0; k < nodeCount; k++) {
                    System.out.print(changed[j][k] + "\t");
                }
                System.out.println();
            }
        }
    }

    /**
     * Adds one tree's "changed"/"rate" node annotations to the running totals.
     *
     * <p>For each non-root node {@code j} that has a "changed" attribute: if it changed, its "rate"
     * attribute is added to {@code rateConditionalOnChange[j]}; and {@code changed[j][k]} is
     * incremented for every non-root node {@code k} that changed together with {@code j}.
     *
     * @return true if any non-root node of {@code tree} carries a "changed" attribute
     */
    private static boolean accumulateChanges(Tree tree, double[][] changed, double[] rateConditionalOnChange) {
        boolean found = false;
        final int nodeCount = tree.getNodeCount();
        for (int j = 0; j < nodeCount; j++) {
            final NodeRef node = tree.getNode(j);
            if (node == tree.getRoot() || tree.getNodeAttribute(node, "changed") == null) {
                continue;
            }
            found = true;
            final boolean ch = getChanged(tree, j);
            if (ch) {
                rateConditionalOnChange[j] += (Double) tree.getNodeAttribute(node, "rate");
            }
            for (int k = 0; k < nodeCount; k++) {
                if (tree.getNode(k) != tree.getRoot()) {
                    changed[j][k] += (ch && getChanged(tree, k)) ? 1 : 0;
                }
            }
        }
        return found;
    }

    /**
     * Reads the "changed" attribute of node {@code j}. The attribute may be an Integer (1 means
     * changed) or a Boolean; a missing attribute is treated as "not changed".
     */
    private static boolean getChanged(Tree tree, int j) {
        final Object o = tree.getNodeAttribute(tree.getNode(j), "changed");
        if (o == null) {
            return false;
        }
        if (o instanceof Integer) {
            return (Integer) o == 1;
        }
        return (Boolean) o;
    }

    /**
     * Prints the progress-bar header shared by {@link #analyze(boolean)} and {@code export}, ending
     * with the first progress star (no newline).
     */
    private static void printProgressHeader(String action, int treeCount) {
        System.out.println(action + " " + treeCount + " trees...");
        System.out.println("0 25 50 75 100");
        System.out.println("|--------------|--------------|--------------|--------------|");
        System.out.print("*");
    }

    /**
     * Actually analyzes a particular tree using the trace given the burnin.
     *
     * <p>The first post-burn-in tree whose unique Newick equals {@code target} is copied. Each of its
     * internal nodes then gets, as height, the mean height of the common-ancestor node of the same
     * leaf set over all post-burn-in trees. It is also annotated with "lower"/"upper" (2.5% / 97.5%
     * quantiles of those heights) and "range" ({lower, upper}, for FigTree bars).
     *
     * @param target a tree in uniqueNewick format
     * @return a tree with mean node heights
     * @throws RuntimeException if no post-burn-in tree matches {@code target}
     */
    final MutableTree analyzeTree(String target) {
        final int n = getTreeCount();

        FlexibleTree meanTree = null;
        // todo using CladeSet may probably speed this a lot
        for (int i = 0; i < n; i++) {
            final Tree tree = getTree(i);
            if (TreeUtils.uniqueNewick(tree, tree.getRoot()).equals(target)) {
                meanTree = new FlexibleTree(tree);
                break;
            }
        }
        if (meanTree == null) {
            throw new RuntimeException("No target tree in trace");
        }

        final int inc = meanTree.getInternalNodeCount();
        for (int j = 0; j < inc; j++) {
            double[] heights = new double[n];
            NodeRef nodej = meanTree.getInternalNode(j);
            Set<String> leafSet = TreeUtils.getDescendantLeaves(meanTree, nodej);

            for (int i = 0; i < n; i++) {
                final Tree tree = getTree(i);
                NodeRef can = TreeUtils.getCommonAncestorNode(tree, leafSet);
                heights[i] = tree.getNodeHeight(can);
            }

            meanTree.setNodeHeight(nodej, dr.stats.DiscreteStatistics.mean(heights));
            final double upper = dr.stats.DiscreteStatistics.quantile(0.975, heights);
            meanTree.setNodeAttribute(nodej, "upper", upper);
            final double lower = dr.stats.DiscreteStatistics.quantile(0.025, heights);
            meanTree.setNodeAttribute(nodej, "lower", lower);
            // Make it possible to display bars in figtree
            meanTree.setNodeAttribute(nodej, "range", new Double[]{lower, upper});
        }
        return meanTree;
    }

    /**
     * @return the total number of post-burn-in trees over all traces
     */
    final int getTreeCount() {
        int treeCount = 0;
        for (TreeTrace trace : traces) {
            treeCount += trace.getTreeCount(burnin * trace.getStepSize());
        }
        return treeCount;
    }

    /**
     * Returns the {@code index}-th post-burn-in tree, counting through the traces in order.
     *
     * @param index zero-based index over the concatenated post-burn-in trees of all traces
     * @return the requested tree
     * @throws RuntimeException if {@code index} is not smaller than {@link #getTreeCount()}
     */
    final Tree getTree(int index) {
        int oldTreeCount = 0;
        int newTreeCount = 0;
        for (TreeTrace trace : traces) {
            final int br = burnin * trace.getStepSize();
            newTreeCount += trace.getTreeCount(br);
            if (index < newTreeCount) {
                return trace.getTree(index - oldTreeCount, br);
            }
            oldTreeCount = newTreeCount;
        }
        throw new RuntimeException("Couldn't find tree " + index);
    }

    /**
     * Reports using a 50% clade threshold and a 95% credible set.
     *
     * @param minNT trees seen at most this many times are not listed individually (disabled if &lt;= 0)
     * @throws IOException if general I/O error occurs
     */
    public void report(int minNT) throws IOException {
        report(0.5, 0.95, minNT);
    }

    /**
     * Reports using a 95% credible set.
     *
     * @param minCladeProbability clades with at least this posterior probability will be included in report.
     * @param minNT               trees seen at most this many times are not listed individually (disabled if &lt;= 0)
     * @throws IOException if general I/O error occurs
     */
    public void report(double minCladeProbability, int minNT) throws IOException {
        report(minCladeProbability, 0.95, minNT);
    }

    /**
     * Prints the tree credible set, the frequent clades and the 5%/50% clade credible-set sizes to stdout.
     *
     * @param minCladeProbability clades with at least this posterior probability will be included in report.
     * @param credSetProbability  cumulative probability of the tree credible set to list (e.g. 0.95)
     * @param minNT               trees seen at most this many times are not listed individually (disabled if &lt;= 0)
     * @throws IOException if general I/O error occurs
     */
    public void report(double minCladeProbability, double credSetProbability, int minNT) throws IOException {
        System.err.println("making report");

        final int totalTrees = treeSet.getSumFrequency();
        final NumberFormatter nf = new NumberFormatter(8);

        System.out.println();
        System.out.println("burnIn=" + burnin);
        System.out.println("total trees used =" + totalTrees);
        System.out.println();

        printCredibleTreeSet(credSetProbability, minNT, totalTrees, nf);

        System.out.println();
        printFrequentClades(minCladeProbability, totalTrees, nf);
        System.out.flush();

        printCladeCredibleSets(totalTrees);
        System.out.flush();
    }

    /**
     * Lists unique trees in frequency-set order until their cumulative frequency reaches the requested
     * credible-set probability. Trees seen more than 100 times are printed with mean node heights.
     */
    private void printCredibleTreeSet(double credSetProbability, int minNT, int totalTrees, NumberFormatter nf) {
        final int nTreeSet = treeSet.size();
        final long percent = Math.round(credSetProbability * 100.0);
        System.out.println(percent + "% credible set (" + nTreeSet + " unique trees, " + totalTrees + " total):");
        System.out.println("Count\tPercent\tTree");

        final int credSet = (int) (credSetProbability * totalTrees);
        int sumFreq = 0;
        int skipped = 0;
        for (int i = 0; i < nTreeSet; i++) {
            final int freq = treeSet.getFrequency(i);
            sumFreq += freq;
            // Rare trees still count towards the credible set, they are just not listed.
            final boolean show = !(minNT > 0 && freq <= minNT);
            if (show) {
                final double prop = ((double) freq) / totalTrees;
                final double sumProp = ((double) sumFreq) / totalTrees;
                System.out.print(freq);
                System.out.print("\t" + nf.formatDecimal(prop * 100.0, 2) + "%");
                System.out.print("\t" + nf.formatDecimal(sumProp * 100.0, 2) + "%");
                final String newickTree = treeSet.get(i);
                if (freq > 100) {
                    // calculate conditional average node heights
                    Tree meanTree = analyzeTree(newickTree);
                    System.out.println("\t" + TreeUtils.newick(meanTree));
                } else {
                    System.out.println("\t" + newickTree);
                }
            } else {
                skipped += 1;
            }

            if (sumFreq >= credSet) {
                if (skipped > 0) {
                    System.out.println();
                    System.out.println("... (" + skipped + ") trees.");
                }
                System.out.println();
                System.out.println(percent + "% credible set has " + (i + 1) + " trees.");
                break;
            }
        }
    }

    /** Lists every clade whose frequency over all trees is at least {@code minCladeProbability}. */
    private void printFrequentClades(double minCladeProbability, int totalTrees, NumberFormatter nf) {
        final int nCladeSet = cladeSet.size();
        System.out.println(Math.round(minCladeProbability * 100.0) +
                "%-rule clades (" + nCladeSet + " unique clades):");

        for (int i = 0; i < nCladeSet; i++) {
            final int freq = cladeSet.getFrequency(i);
            final double prop = ((double) freq) / totalTrees;
            if (prop >= minCladeProbability) {
                System.out.print(freq);
                System.out.print("\t" + nf.formatDecimal(prop * 100.0, 2) + "%");
                System.out.print("\t" + cladeSet.getMeanNodeHeight(i));
                System.out.println("\t" + cladeSet.getClade(i));
            }
        }
    }

    /**
     * Reports how many clades the trees making up the 5% and 50% credible sets contain, by re-importing
     * unique trees in frequency-set order into a fresh clade set.
     */
    private void printCladeCredibleSets(int totalTrees) throws IOException {
        System.out.println("Clade credible sets:");

        final int nTreeSet = treeSet.size();
        final int fiveCredSet = (5 * totalTrees) / 100;
        final int halfCredSet = (50 * totalTrees) / 100;
        boolean fiveReported = false;
        boolean halfReported = false;

        int sumFreq = 0;
        final CladeSet tempCladeSet = new CladeSet();
        for (int nt = 0; nt < nTreeSet && !(fiveReported && halfReported); nt++) {
            sumFreq += treeSet.getFrequency(nt);
            final String newickTree = treeSet.get(nt);
            NewickImporter importer = new NewickImporter(new StringReader(newickTree));
            try {
                tempCladeSet.add(importer.importNextTree());
            } catch (Importer.ImportException e) {
                // Best effort: a tree that cannot be re-imported is left out of the clade count.
                System.err.println("Err: could not re-import tree " + newickTree + " (" + e.getMessage() + ")");
            }

            if (!fiveReported && sumFreq >= fiveCredSet) {
                System.out.println();
                System.out.println("5% credible set has " + tempCladeSet.getCladeCount() + " clades.");
                fiveReported = true;
            }
            if (!halfReported && sumFreq >= halfCredSet) {
                System.out.println();
                System.out.println("50% credible set has " + tempCladeSet.getCladeCount() + " clades.");
                halfReported = true;
            }
        }
    }

    /**
     * Short report using a 95% credible set.
     *
     * @see #shortReport(String, Tree, boolean, double)
     */
    public void shortReport(String name, Tree tree, boolean drawHeader) {
        shortReport(name, tree, drawHeader, 0.95);
    }

    /**
     * Prints one tab-separated summary line: name, tree count, number of unique trees, MAP probability,
     * MAP tree, credible-set size, and the 1-based rank, probability and cumulative probability of
     * {@code tree} (rank -1 if absent).
     *
     * @param name               label for the first column
     * @param tree               the "true" tree to look up, or null
     * @param drawHeader         if true, a column header line is printed first
     * @param credSetProbability cumulative probability defining the credible set
     */
    public void shortReport(String name, Tree tree, boolean drawHeader, double credSetProbability) {
        String targetTree = "";
        if (tree != null) {
            targetTree = TreeUtils.uniqueNewick(tree, tree.getRoot());
        }

        final int n = treeSet.size();
        final int totalTrees = treeSet.getSumFrequency();
        final double highestProp = ((double) treeSet.getFrequency(0)) / totalTrees;
        String mapTree = treeSet.get(0);

        if (drawHeader) {
            // Round the full product; "(int) p * 100" would cast p to 0 before multiplying.
            System.out.println("file\ttrees\tuniqueTrees\tp(MAP)\tMAP tree\t" + Math.round(credSetProbability * 100) + "credSize\ttrue_I\tp(true)\tcum(true)");
        }

        System.out.print(name + "\t");
        System.out.print(totalTrees + "\t");
        System.out.print(n + "\t");
        System.out.print(highestProp + "\t");
        System.out.print(mapTree + "\t");

        int credSet = (int) (credSetProbability * totalTrees);
        int sumFreq = 0;

        int credSetSize = -1;
        int targetTreeIndex = -1;
        double targetTreeProb = 0.0;
        double targetTreeCum = 1.0;
        for (int i = 0; i < n; i++) {
            final int freq = treeSet.getFrequency(i);
            final double prop = ((double) freq) / totalTrees;

            sumFreq += freq;
            final double sumProp = ((double) sumFreq) / totalTrees;

            String newickTree = treeSet.get(i);

            if (newickTree.equals(targetTree)) {
                targetTreeIndex = i + 1;
                targetTreeProb = prop;
                targetTreeCum = sumProp;
            }

            if (sumFreq >= credSet && credSetSize == -1) {
                credSetSize = i + 1;
            }
        }

        System.out.print(credSetSize + "\t");
        System.out.print(targetTreeIndex + "\t");
        System.out.print(targetTreeProb + "\t");
        System.out.println(targetTreeCum);
    }

    /**
     * Exports unique trees (in frequency-set order) as NEXUS, each with mean node heights, a "weight"
     * attribute (its frequency) and clade "posterior" annotations.
     *
     * @param out                destination stream
     * @param minTreeProbability trees with a lower frequency are skipped
     * @param max                maximum number of trees to export; negative means no limit
     * @param verbose            if true and more than 60 trees are to be exported, a progress bar is printed
     */
    public void export(PrintStream out, double minTreeProbability, int max, boolean verbose) {
        NexusExporter exporter = new NexusExporter(out);

        final int n = treeSet.size();
        if (max < 0) {
            max = n;
        }
        final int totalTrees = treeSet.getSumFrequency();
        final int totExport = Math.min(max, n);
        List<Tree> trees = new ArrayList<Tree>(totExport);

        final boolean progress = verbose && totExport > 60;
        if (progress) {
            printProgressHeader("Exporting", totExport);
        }

        // todo have an option for threshold and sort by NCP.
        // Checking the limit before adding means max == 0 exports nothing.
        for (int i = 0; i < n && trees.size() < max; i++) {
            int freq = treeSet.getFrequency(i);
            double prop = ((double) freq) / totalTrees;
            if (prop < minTreeProbability) {
                continue;
            }

            final String newickTree = treeSet.get(i);
            // calculate conditional average node heights
            final MutableTree tree = analyzeTree(newickTree);
            tree.setAttribute("weight", prop);

            double p = cladeSet.annotate(tree, "posterior");
            tree.setNodeAttribute(tree.getRoot(), "posterior", Math.exp(p / tree.getInternalNodeCount()));

            trees.add(tree);

            // progress implies totExport > 60, so the divisor is at least 1
            if (progress && ((i + 1) % (totExport / 60)) == 0) {
                System.out.print("*");
            }
        }
        if (progress) {
            System.out.println();
        }

        if (!trees.isEmpty()) {
            exporter.exportTrees(trees.toArray(new Tree[trees.size()]), true);
        }
    }

    /**
     * @return the burn-in in use; it is multiplied by each trace's step size before being passed to the trace
     */
    public int getBurnin() {
        return burnin;
    }

    /**
     * @param reader  the readers to be analyzed; each is closed once read, even if parsing fails
     * @param burnin  the burnin in states
     * @param verbose true if progress should be logged to stdout
     * @return an analyses of the trees in a log file.
     * @throws java.io.IOException if general I/O error occurs
     */
    public static TreeTraceAnalysis analyzeLogFile(Reader[] reader, int burnin, boolean verbose) throws IOException {
        TreeTrace[] trace = new TreeTrace[reader.length];
        for (int i = 0; i < reader.length; i++) {
            try {
                trace[i] = TreeTrace.loadTreeTrace(reader[i]);
            } catch (Importer.ImportException ie) {
                throw new RuntimeException(ie.toString(), ie);
            } finally {
                reader[i].close();
            }
        }
        return new TreeTraceAnalysis(trace, burnin, verbose);
    }

    // Burn-in in samples; multiplied by each trace's step size before querying that trace.
    private int burnin = -1;
    private final TreeTrace[] traces;
    private CladeSet cladeSet;
    private FrequencySet<String> treeSet;
}