/*
* ARGTraceAnalysis.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.arg;
import dr.evolution.io.Importer;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evoxml.util.GraphMLUtils;
import dr.util.FrequencySet;
import dr.util.NumberFormatter;
import java.io.IOException;
import java.io.Reader;
import java.util.Set;
/**
 * Summarises one or more {@link ARGTrace}s: every post-burn-in ARG is counted in a
 * {@link FrequencySet}, from which a credible-set report can be printed.
 *
 * @author Marc Suchard
 */
public class ARGTraceAnalysis {

    /**
     * Stores the traces, settles on a burn-in and immediately runs {@link #analyze(boolean)}.
     *
     * @param trace   the traces to combine
     * @param burnIn  requested burn-in; a negative value, or a value not smaller than the
     *                smallest maximum state over all traces, is replaced by a default
     * @param verbose whether to print warnings and progress to standard output
     */
    public ARGTraceAnalysis(ARGTrace[] trace, int burnIn, boolean verbose) {
        this.trace = trace;

        // The shortest trace bounds the burn-in that can be applied to all of them.
        int minMaxState = Integer.MAX_VALUE;
        for (ARGTrace aTrace : trace) {
            if (aTrace.getMaximumState() < minMaxState) {
                minMaxState = aTrace.getMaximumState();
            }
        }

        // NOTE(review): burnIn is compared against a state count here, but the stored value is
        // multiplied by the step size everywhere it is used - confirm which unit callers pass.
        if (burnIn < 0 || burnIn >= minMaxState) {
            // Default: a tenth of the smallest trace, divided by the first trace's step size.
            this.burnin = minMaxState / (10 * trace[0].getStepSize());
            if (verbose) {
                // This message is also printed when the requested burn-in was negative.
                System.out.println("WARNING: Burn-in larger than total number of states - using 10% of smallest trace");
            }
        } else {
            this.burnin = burnIn;
        }

        analyze(verbose);
    }

    /**
     * Rebuilds the frequency set from every post-burn-in ARG of every trace.
     *
     * @param verbose whether to print a progress bar to standard output
     * @throws RuntimeException if no ARG is left in any trace after the burn-in
     */
    public void analyze(boolean verbose) {
        if (verbose && trace.length > 1) {
            System.out.println("Combining " + trace.length + " traces.");
        }

        // Fail fast on an empty sample; report() relies on at least one entry being present.
        if (getTreeCount() == 0) {
            throw new RuntimeException("Couldn't find any ARG after burn-in (burnIn=" + burnin + ")");
        }

        argSet = new FrequencySet<ARGModel>();

        for (ARGTrace aTrace : trace) {
            int burninState = burnin * aTrace.getStepSize();
            int treeCount = aTrace.getTreeCount(burninState);

            // The progress bar is 60 characters wide.
            double stepSize = treeCount / 60.0;
            int counter = 1;

            if (verbose) {
                System.out.println("Analyzing " + treeCount + " ARGs...");
                System.out.println("0              25             50             75            100");
                System.out.println("|--------------|--------------|--------------|--------------|");
                System.out.print("*");
            }

            // Start at index 0 for every trace so that no trace loses its first ARG.
            for (int i = 0; i < treeCount; i++) {
                argSet.add(aTrace.getARG(i, burninState));

                // Index 0 is represented by the '*' already printed above.
                if (i > 0 && i >= (int) Math.round(counter * stepSize) && counter <= 60) {
                    if (verbose) {
                        System.out.print("*");
                        System.out.flush();
                    }
                    counter += 1;
                }
            }

            if (verbose) {
                System.out.println("*");
            }
        }
    }

    /**
     * Builds a summary tree for the topology identified by {@code target}.
     * <p>
     * The first sampled ARG whose unique Newick string equals {@code target} is copied. For each
     * internal node of the copy, the height of the common ancestor of that node's descendant
     * leaves is collected from every sampled ARG (not only those matching {@code target}). The
     * node receives the mean as its height and the 2.5% and 97.5% quantiles as the
     * {@code "lower"} and {@code "upper"} attributes.
     *
     * @param target unique Newick string of the topology to summarise
     * @return the annotated copy of the first matching ARG
     * @throws RuntimeException if no sampled ARG matches {@code target}
     */
    public final Tree analyzeARG(String target) {
        int n = getTreeCount();

        FlexibleTree meanTree = null;
        for (int i = 0; i < n; i++) {
            Tree tree = getARG(i);
            if (TreeUtils.uniqueNewick(tree, tree.getRoot()).equals(target)) {
                meanTree = new FlexibleTree(tree);
                break;
            }
        }
        if (meanTree == null) {
            throw new RuntimeException("No target tree in trace");
        }

        int m = meanTree.getInternalNodeCount();
        for (int j = 0; j < m; j++) {
            double[] heights = new double[n];
            NodeRef node1 = meanTree.getInternalNode(j);
            Set<String> leafSet = TreeUtils.getDescendantLeaves(meanTree, node1);

            for (int i = 0; i < n; i++) {
                Tree tree = getARG(i);
                NodeRef node2 = TreeUtils.getCommonAncestorNode(tree, leafSet);
                heights[i] = tree.getNodeHeight(node2);
            }

            meanTree.setNodeHeight(node1, dr.stats.DiscreteStatistics.mean(heights));
            meanTree.setNodeAttribute(node1, "upper",
                    Double.valueOf(dr.stats.DiscreteStatistics.quantile(0.975, heights)));
            meanTree.setNodeAttribute(node1, "lower",
                    Double.valueOf(dr.stats.DiscreteStatistics.quantile(0.025, heights)));
        }

        return meanTree;
    }

    /**
     * @return the number of post-burn-in ARGs summed over all traces
     */
    public final int getTreeCount() {
        int treeCount = 0;
        for (ARGTrace aTrace : trace) {
            treeCount += aTrace.getTreeCount(burnin * aTrace.getStepSize());
        }
        return treeCount;
    }

    /**
     * Looks up an ARG by its position in the concatenation of all post-burn-in traces.
     *
     * @param index position in the combined sample, starting at 0
     * @return the ARG at that position
     * @throws RuntimeException if {@code index} is not smaller than {@link #getTreeCount()}
     */
    public final ARGModel getARG(int index) {
        int oldTreeCount = 0;
        int newTreeCount = 0;
        for (ARGTrace aTrace : trace) {
            int burninState = burnin * aTrace.getStepSize();
            newTreeCount += aTrace.getTreeCount(burninState);
            if (index < newTreeCount) {
                // Translate the combined index into an index local to this trace.
                return aTrace.getARG(index - oldTreeCount, burninState);
            }
            oldTreeCount = newTreeCount;
        }
        throw new RuntimeException("Couldn't find ARG " + index);
    }

    /**
     * Prints the burn-in, the 95% credible set of ARGs and the DOT rendering of the first
     * entry of the frequency set to standard output.
     *
     * @throws IOException declared for callers; nothing in this method throws it directly
     */
    public void report() throws IOException {
        int n = argSet.size();
        int totalTrees = argSet.getSumFrequency();

        System.out.println();
        System.out.println("burnIn=" + burnin);
        System.out.println("total ARGs used =" + totalTrees);
        System.out.println();

        System.out.println("95% credible set (" + n + " unique ARGs, " + totalTrees + " total):");
        // NOTE(review): the header names three columns, but every row prints four:
        // count, percent, cumulative percent and the ARG.
        System.out.println("Count\tPercent\tARG");

        // Computed in long so that 95 * totalTrees cannot overflow for very large samples.
        int credSet = (int) ((95L * totalTrees) / 100);
        int sumFreq = 0;

        NumberFormatter nf = new NumberFormatter(8);

        for (int i = 0; i < n; i++) {
            int freq = argSet.getFrequency(i);
            double prop = ((double) freq) / totalTrees;
            System.out.print(freq);
            System.out.print("\t" + nf.formatDecimal(prop * 100.0, 2) + "%");

            sumFreq += freq;
            double sumProp = ((double) sumFreq) / totalTrees;
            System.out.print("\t" + nf.formatDecimal(sumProp * 100.0, 2) + "%");

            System.out.println("\t" + argSet.get(i).toString());

            if (sumFreq >= credSet) {
                System.out.println();
                System.out.println("95% credible set has " + (i + 1) + " ARGs.");
                break;
            }
        }

        System.out.println();
        // Presumably FrequencySet lists entries by decreasing frequency, making index 0 the
        // most probable ARG - verify against FrequencySet.
        System.out.println("DOT Format for most probable ARG:");
        System.out.println(GraphMLUtils.dotFormat(argSet.get(0).toXML()));

        System.out.flush();
    }

    /**
     * Currently does nothing: the method body is empty and all three arguments are ignored.
     *
     * @param name       unused
     * @param tree       unused
     * @param drawHeader unused
     * @throws IOException never thrown by the current implementation
     */
    public void shortReport(String name, Tree tree, boolean drawHeader) throws IOException {
    }

    /**
     * @return the burn-in settled on by the constructor
     */
    public int getBurnin() {
        return burnin;
    }

    /**
     * Loads one ARG trace from each reader and analyzes the combined traces. Each reader is
     * closed once its trace has been read, including when reading fails.
     *
     * @param reader  one reader per log file
     * @param burnin  requested burn-in, passed on to the constructor
     * @param verbose whether to print warnings and progress to standard output
     * @return an analysis of the ARGs in the log files
     * @throws IOException      if reading or closing a reader fails
     * @throws RuntimeException wrapping the {@link Importer.ImportException} if a trace cannot be parsed
     */
    public static ARGTraceAnalysis analyzeLogFile(Reader[] reader, int burnin, boolean verbose) throws IOException {
        ARGTrace[] trace = new ARGTrace[reader.length];
        for (int i = 0; i < reader.length; i++) {
            try {
                trace[i] = ARGTrace.loadARGTrace(reader[i]);
            } catch (Importer.ImportException ie) {
                // Same message as before, but keep the cause for the stack trace.
                throw new RuntimeException(ie.toString(), ie);
            } finally {
                reader[i].close();
            }
        }
        return new ARGTraceAnalysis(trace, burnin, verbose);
    }

    /** Burn-in; multiplied by a trace's step size wherever that trace is queried. */
    private final int burnin;

    /** The traces being combined. */
    private final ARGTrace[] trace;

    /** Frequency of each distinct ARG; rebuilt by {@link #analyze(boolean)}. */
    private FrequencySet<ARGModel> argSet;
}