/*
* BranchRatePlotter.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.app.tools;
import dr.evolution.io.Importer;
import dr.evolution.io.NexusImporter;
import dr.evolution.io.TreeImporter;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.*;
import dr.app.gui.tree.JTreeDisplay;
import dr.app.gui.tree.JTreePanel;
import dr.app.gui.tree.SquareTreePainter;
import dr.stats.DiscreteStatistics;
import javax.swing.*;
import java.awt.*;
import java.awt.print.PrinterJob;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Set;
import java.util.StringTokenizer;
/**
 * Summarizes per-branch rates from paired tree samples onto a single target tree, then prints the
 * annotated tree and shows it in a Swing window with an optional print dialog.
 *
 * <p>Each line of the control file names two NEXUS tree files that are read in lock-step: a "time"
 * tree file and a corresponding "mutation" tree file. For every target-tree clade that is
 * monophyletic in a sampled time tree, the branch length above that clade is accumulated from both
 * trees. Each length is the parent height minus the clade height. The summarized branch rate is
 * total mutations divided by total time.
 *
 * Date: Nov 30, 2004
 * Time: 5:29:36 PM
 *
 * @author Alexei Drummond
 *
 * @version $Id: BranchRatePlotter.java,v 1.6 2005/12/08 13:52:46 rambaut Exp $
 */
public class BranchRatePlotter {

    /** Printed when the required command-line arguments are missing. */
    private static final String USAGE =
            "Usage: BranchRatePlotter <controlFile> <targetTreeFile> [burnin]";

    /**
     * Command-line entry point.
     *
     * @param args {@code controlFile targetTreeFile [burnin]}. Each non-blank control-file line holds
     *             a time tree file and a mutation tree file, separated by whitespace. {@code burnin}
     *             (default 0, must be non-negative) is the number of leading tree pairs skipped in
     *             every pair of files.
     * @throws java.io.IOException if an input file cannot be read
     * @throws Importer.ImportException if a tree file cannot be parsed
     * @throws IllegalArgumentException on a malformed control-file line, a negative burnin, an empty
     *         target tree file, or a mutation tree file with fewer trees than its time tree file
     * @throws IllegalStateException if no trees remain after burnin
     */
    public static void main(String[] args) throws java.io.IOException, Importer.ImportException {
        if (args.length < 2) {
            System.err.println(USAGE);
            System.exit(1);
            return;
        }
        String controlFile = args[0];
        String targetTreeFile = args[1];
        int burnin = 0;
        if (args.length > 2) {
            burnin = Integer.parseInt(args[2]);
        }
        if (burnin < 0) {
            throw new IllegalArgumentException("burnin must be non-negative: " + burnin);
        }
        System.out.println("Ignoring first " + burnin + " trees as burnin.");

        MutableTree targetTree = readTargetTree(targetTreeFile);

        int totalTrees = 0;
        int totalTreesUsed = 0;
        BufferedReader reader = new BufferedReader(new FileReader(controlFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line);
                if (!tokens.hasMoreTokens()) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String timeTreeFile = tokens.nextToken();
                if (!tokens.hasMoreTokens()) {
                    throw new IllegalArgumentException(
                            "Control file line must name a time tree file and a mutation tree file: "
                                    + line);
                }
                String mutationTreeFile = tokens.nextToken();
                int fileTotalTrees =
                        annotateFromTreeFiles(targetTree, timeTreeFile, mutationTreeFile, burnin);
                totalTrees += fileTotalTrees;
                // pairs with index >= burnin were summarized (burnin is non-negative here)
                totalTreesUsed += Math.max(0, fileTotalTrees - burnin);
            }
        } finally {
            reader.close();
        }
        System.out.println("Total trees read: " + totalTrees);
        System.out.println("Total trees summarized: " + totalTreesUsed);
        if (totalTreesUsed == 0) {
            throw new IllegalStateException(
                    "No trees left to summarize after a burnin of " + burnin);
        }

        // Collect the summarized rate of every non-root branch.
        double mutations = 0.0;
        double time = 0.0;
        double[] rates = new double[targetTree.getNodeCount() - 1];
        int index = 0;
        for (int i = 0; i < targetTree.getNodeCount(); i++) {
            NodeRef node = targetTree.getNode(i);
            if (targetTree.isRoot(node)) {
                continue;
            }
            Integer count = (Integer) targetTree.getNodeAttribute(node, "count");
            if (count == null) {
                throw new RuntimeException("Count missing from node in target tree");
            }
            if (!targetTree.isExternal(node)) {
                // Fraction of summarized trees in which this clade was monophyletic in the time tree.
                double prob = count.doubleValue() / totalTreesUsed;
                if (prob >= 0.5) {
                    targetTree.setNodeAttribute(node, "label",
                            String.valueOf(Math.round(prob * 100) / 100.0));
                }
            }
            double branchMutations =
                    ((Number) targetTree.getNodeAttribute(node, "totalMutations")).doubleValue();
            double branchTime =
                    ((Number) targetTree.getNodeAttribute(node, "totalTime")).doubleValue();
            mutations += branchMutations;
            time += branchTime;
            rates[index] = branchMutations / branchTime;
            System.out.println(branchMutations + " / " + branchTime + " = " + rates[index]);
            targetTree.setNodeRate(node, rates[index]);
            index++;
        }

        double minRate = DiscreteStatistics.min(rates);
        double maxRate = DiscreteStatistics.max(rates);
        double medianRate = DiscreteStatistics.median(rates);
        // Overall rate: pooled mutations over pooled time, not the mean of per-branch rates.
        double meanRate = mutations / time;
        System.out.println(minRate + "\t" + maxRate + "\t" + medianRate + "\t" + meanRate);

        decorateBranches(targetTree, meanRate, maxRate);
        summarizeNodeHeights(targetTree);

        StringBuffer buffer = new StringBuffer();
        writeTree(targetTree, targetTree.getRoot(), buffer, true, false);
        buffer.append(";\n");
        writeTree(targetTree, targetTree.getRoot(), buffer, false, true);
        buffer.append(";\n");
        System.out.println(buffer.toString());

        displayAndPrint(targetTree);
    }

    /**
     * Reads the first tree of the target tree file and returns it, rotated by node density.
     * The file is parsed as NEXUS when its first line starts with "#NEXUS", otherwise as Newick.
     * All opened readers are closed.
     */
    private static MutableTree readTargetTree(String targetTreeFile)
            throws java.io.IOException, Importer.ImportException {
        String firstLine;
        BufferedReader probe = new BufferedReader(new FileReader(targetTreeFile));
        try {
            firstLine = probe.readLine();
        } finally {
            probe.close();
        }
        if (firstLine == null) {
            throw new IllegalArgumentException("Target tree file is empty: " + targetTreeFile);
        }

        FileReader treeReader = new FileReader(targetTreeFile);
        try {
            TreeImporter importer;
            if (firstLine.toUpperCase().startsWith("#NEXUS")) {
                importer = new NexusImporter(treeReader);
            } else {
                importer = new NewickImporter(treeReader);
            }
            MutableTree targetTree = new FlexibleTree(importer.importNextTree());
            return TreeUtils.rotateTreeByComparator(
                    targetTree, TreeUtils.createNodeDensityComparator(targetTree));
        } finally {
            treeReader.close();
        }
    }

    /**
     * Reads a time tree file and a mutation tree file in lock-step. Every pair whose index is at least
     * {@code burnin} is passed to {@link #annotateRates}. Both readers are closed on exit.
     *
     * @return the number of tree pairs read, including burnin
     * @throws IllegalArgumentException if the mutation file runs out of trees before the time file
     */
    private static int annotateFromTreeFiles(MutableTree targetTree, String timeTreeFile,
                                             String mutationTreeFile, int burnin)
            throws java.io.IOException, Importer.ImportException {
        FileReader timeReader = new FileReader(timeTreeFile);
        try {
            FileReader mutationReader = new FileReader(mutationTreeFile);
            try {
                NexusImporter timeImporter = new NexusImporter(timeReader);
                NexusImporter mutationImporter = new NexusImporter(mutationReader);
                int treesRead = 0;
                while (timeImporter.hasTree()) {
                    Tree timeTree = timeImporter.importNextTree();
                    if (!mutationImporter.hasTree()) {
                        throw new IllegalArgumentException("Mutation tree file " + mutationTreeFile
                                + " has fewer trees than time tree file " + timeTreeFile);
                    }
                    Tree mutationTree = mutationImporter.importNextTree();
                    if (treesRead >= burnin) {
                        annotateRates(targetTree, targetTree.getRoot(), timeTree, mutationTree);
                    }
                    treesRead++;
                }
                return treesRead;
            } finally {
                mutationReader.close();
            }
        } finally {
            timeReader.close();
        }
    }

    /**
     * Sets the "color", "line" and "shape" display attributes of every non-root node from its node rate.
     * Branches faster than {@code meanRate} are reddish and the rest are bluish. The shape is a circle
     * whose area is proportional to {@code rate / maxRate*, giving a radius of 6 at {@code maxRate}.
     */
    private static void decorateBranches(MutableTree tree, double meanRate, double maxRate) {
        for (int i = 0; i < tree.getNodeCount(); i++) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            double rate = tree.getNodeRate(node);
            float relativeRate = (float) (rate / maxRate);
            float radius = (float) Math.sqrt(relativeRate * 36.0);

            Color color = rate > meanRate
                    ? new Color(1.0f, 0.5f, 0.5f)
                    : new Color(0.5f, 0.5f, 1.0f);
            tree.setNodeAttribute(node, "color", color);
            tree.setNodeAttribute(node, "line", new BasicStroke(1.0f));
            tree.setNodeAttribute(node, "shape",
                    new java.awt.geom.Ellipse2D.Double(0, 0, radius * 2.0, radius * 2.0));
        }
    }

    /**
     * For every node with a "heightList" attribute, sets the node height to the mean of the collected
     * heights. Also records "nodeHeight.mean", "nodeHeight.hpdUpper" and "nodeHeight.hpdLower".
     *
     * <p>Despite their names, the hpd* attributes are the 97.5% and 2.5% quantiles, i.e. an
     * equal-tailed interval, not a highest-density interval.
     */
    private static void summarizeNodeHeights(MutableTree tree) {
        for (int i = 0; i < tree.getNodeCount(); i++) {
            NodeRef node = tree.getNode(i);
            java.util.List<?> heightList = (java.util.List<?>) tree.getNodeAttribute(node, "heightList");
            if (heightList == null) {
                continue;
            }
            double[] heights = new double[heightList.size()];
            for (int j = 0; j < heights.length; j++) {
                heights[j] = ((Number) heightList.get(j)).doubleValue();
            }
            double meanHeight = DiscreteStatistics.mean(heights);
            tree.setNodeHeight(node, meanHeight);
            tree.setNodeAttribute(node, "nodeHeight.mean", meanHeight);
            tree.setNodeAttribute(node, "nodeHeight.hpdUpper", DiscreteStatistics.quantile(0.975, heights));
            tree.setNodeAttribute(node, "nodeHeight.hpdLower", DiscreteStatistics.quantile(0.025, heights));
        }
    }

    /**
     * Shows the tree in an 800x600 window and, if the user confirms the system print dialog, prints it.
     * Only the "color" and "line" attributes are wired into the painter; "shape" and "label" are
     * computed elsewhere in this class but not displayed.
     */
    private static void displayAndPrint(MutableTree targetTree) {
        SquareTreePainter treePainter = new SquareTreePainter();
        treePainter.setColorAttribute("color");
        treePainter.setLineAttribute("line");

        JTreeDisplay treeDisplay = new JTreeDisplay(treePainter, targetTree);
        JTreePanel treePanel = new JTreePanel(treeDisplay);

        JFrame frame = new JFrame();
        frame.setSize(800, 600);
        frame.getContentPane().setLayout(new BorderLayout());
        frame.getContentPane().add(treePanel);
        frame.setVisible(true);

        PrinterJob printJob = PrinterJob.getPrinterJob();
        printJob.setPrintable(treeDisplay);
        if (printJob.printDialog()) {
            try {
                printJob.print();
            } catch (java.awt.print.PrinterException ex) {
                throw new RuntimeException(ex);
            }
        }
    }

    /**
     * Appends a Newick-style representation of the subtree rooted at {@code node} to {@code buffer}.
     * Each non-root node is followed by ":" and a value. With {@code rates} true, the value is its
     * summarized rate ("totalMutations" / "totalTime"). Otherwise it is its "count" attribute.
     *
     * @param labels currently unused; kept for signature compatibility
     */
    private static void writeTree(Tree tree, NodeRef node, StringBuffer buffer, boolean rates, boolean labels) {
        if (tree.isExternal(node)) {
            buffer.append(tree.getTaxonId(node.getNumber()));
        } else {
            buffer.append("(");
            for (int i = 0; i < tree.getChildCount(node); i++) {
                if (i > 0) {
                    buffer.append(",");
                }
                writeTree(tree, tree.getChild(node, i), buffer, rates, labels);
            }
            buffer.append(")");
        }

        if (tree.getParent(node) != null) {
            buffer.append(":");
            if (rates) {
                double totalMutations = ((Number) tree.getNodeAttribute(node, "totalMutations")).doubleValue();
                double totalTime = ((Number) tree.getNodeAttribute(node, "totalTime")).doubleValue();
                buffer.append(String.valueOf(totalMutations / totalTime));
            } else {
                int count = ((Number) tree.getNodeAttribute(node, "count")).intValue();
                buffer.append(String.valueOf(count));
            }
        }
    }

    /**
     * Recursively annotates the target tree from one time/mutation tree pair.
     *
     * <p>For each target node whose descendant taxa are monophyletic in {@code timeTree}:
     * <ul>
     * <li>if it is not the root, the branch length above the clade's common ancestor is added to
     *     "totalTime" (from the time tree) and "totalMutations" (from the mutation tree), and "count"
     *     is incremented;</li>
     * <li>if it is internal, the clade's height in the time tree is appended to "heightList".</li>
     * </ul>
     * Recursion continues into all children whether or not the clade matched. The mutation tree node is
     * the common ancestor of the same taxa, which presumably assumes both trees of a pair share a
     * topology. NOTE(review): confirm against the producer of these files.
     */
    private static void annotateRates(
            MutableTree targetTree, NodeRef node, Tree timeTree, Tree mutationTree) {
        Set<String> leafSet = TreeUtils.getDescendantLeaves(targetTree, node);
        if (TreeUtils.isMonophyletic(timeTree, leafSet)) {
            NodeRef timeNode = TreeUtils.getCommonAncestorNode(timeTree, leafSet);
            NodeRef mutationNode = TreeUtils.getCommonAncestorNode(mutationTree, leafSet);
            double height = timeTree.getNodeHeight(timeNode);

            if (!targetTree.isRoot(node)) {
                double time = timeTree.getNodeHeight(timeTree.getParent(timeNode)) - height;
                double mutations = mutationTree.getNodeHeight(mutationTree.getParent(mutationNode))
                        - mutationTree.getNodeHeight(mutationNode);

                Number totalMutations = (Number) targetTree.getNodeAttribute(node, "totalMutations");
                Number totalTime = (Number) targetTree.getNodeAttribute(node, "totalTime");
                if (totalMutations == null) {
                    targetTree.setNodeAttribute(node, "totalMutations", mutations);
                    targetTree.setNodeAttribute(node, "totalTime", time);
                    targetTree.setNodeAttribute(node, "count", 1);
                } else {
                    Integer count = (Integer) targetTree.getNodeAttribute(node, "count");
                    targetTree.setNodeAttribute(node, "totalMutations", totalMutations.doubleValue() + mutations);
                    targetTree.setNodeAttribute(node, "totalTime", totalTime.doubleValue() + time);
                    targetTree.setNodeAttribute(node, "count", count + 1);
                }
            }

            if (!targetTree.isExternal(node)) {
                // Within this class "heightList" is only ever stored here, always as a List<Double>.
                @SuppressWarnings("unchecked")
                java.util.List<Double> heightList =
                        (java.util.List<Double>) targetTree.getNodeAttribute(node, "heightList");
                if (heightList == null) {
                    heightList = new ArrayList<Double>();
                    targetTree.setNodeAttribute(node, "heightList", heightList);
                }
                heightList.add(height);
            }
        }

        for (int i = 0; i < targetTree.getChildCount(node); i++) {
            annotateRates(targetTree, targetTree.getChild(node, i), timeTree, mutationTree);
        }
    }
}