/*
* ConditionalCladeFrequency.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
/**
*
*/
package dr.evomodel.tree;
import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.io.NexusImporter;
import dr.evolution.io.TreeTrace;
import dr.evolution.tree.Clade;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.inference.model.Likelihood;
import dr.math.MathUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.*;
/**
 * Calculates conditional clade frequencies for a sample of trees and uses
 * them to estimate the marginal posterior probability of a given tree.
 * <p/>
 * For every observed clade the number of occurrences is recorded, and for
 * every observed parent clade the occurrences of each of its child clades.
 * The log probability of a tree is the sum over its splits of
 * {@code log((childCount + epsilon) / (parentCount + epsilon * (2^(n-1) - 1)))},
 * where {@code n} is the number of taxa in the parent clade.
 *
 * @author Sebastian Hoehna
 */
public class ConditionalCladeFrequency extends
        AbstractCladeImportanceDistribution {

    /** Pseudo-count given to every clade so that unobserved clades get a non-zero probability. */
    private final double epsilon;

    /** Number of trees added through {@code addTree}. */
    private long samples = 0;

    /** Observed clades, keyed by their taxon bit set. */
    private final HashMap<BitSet, Clade> cladeProbabilities;

    /** For every observed parent clade (keyed by its bits), its observed child clades keyed by their bits. */
    private final HashMap<BitSet, HashMap<BitSet, Clade>> cladeCoProbabilities;

    private TreeTrace[] traces;

    private int burnin;

    /**
     * Creates an empty estimator; trees can be added with {@link #addTree(Tree)}.
     *
     * @param tree    - not used by this constructor
     * @param epsilon - the default number of occurrences for each clade which wasn't
     *                observed to guarantee non-zero probabilities
     */
    public ConditionalCladeFrequency(Tree tree, double epsilon) {
        cladeProbabilities = new HashMap<BitSet, Clade>();
        cladeCoProbabilities = new HashMap<BitSet, HashMap<BitSet, Clade>>();
        this.epsilon = epsilon;
    }

    /**
     * Creates an estimator and immediately analyzes all trees of the given traces.
     *
     * @param traces  - samples of trees in a tree traces array.
     * @param epsilon - the default number of occurrences for each clade which wasn't
     *                observed to guarantee non-zero probabilities
     * @param burnIn  - number of trees discarded from the start of each trace (it is
     *                multiplied by the trace step size before use). Values that are
     *                negative or not smaller than the smallest maximum state fall back
     *                to 10% of the smallest trace.
     *                NOTE(review): the range check compares against states while the
     *                value is used as a tree count - confirm the intended unit.
     * @param verbose - if true, runtime status and progress are printed to stdout
     */
    public ConditionalCladeFrequency(TreeTrace[] traces, double epsilon,
                                     int burnIn, boolean verbose) {
        cladeProbabilities = new HashMap<BitSet, Clade>();
        cladeCoProbabilities = new HashMap<BitSet, HashMap<BitSet, Clade>>();
        this.epsilon = epsilon;
        this.traces = traces;

        // the burn-in must lie within the shortest trace, otherwise use 10% of it
        int minMaxState = Integer.MAX_VALUE;
        for (TreeTrace trace : traces) {
            minMaxState = Math.min(minMaxState, trace.getMaximumState());
        }
        if (burnIn < 0 || burnIn >= minMaxState) {
            this.burnin = minMaxState / (10 * traces[0].getStepSize());
            if (verbose) {
                System.out
                        .println("WARNING: Burn-in larger than total number of states - using 10% of smallest trace");
            }
        } else {
            this.burnin = burnIn;
        }

        // analyzing the whole trace -> reading the trees
        analyzeTrace(verbose);
    }

    /**
     * Actually analyzes the trace given the burn-in. Each tree after the burn-in
     * is read from every trace and the conditional clade frequencies incremented.
     *
     * @param verbose if true then progress is logged to stdout
     */
    public void analyzeTrace(boolean verbose) {
        if (verbose && traces.length > 1) {
            System.out.println("Combining " + traces.length + " traces.");
        }

        // fails fast (getTree throws a RuntimeException) if no tree remains after burn-in
        getTree(0);

        for (TreeTrace trace : traces) {
            final int burninStates = burnin * trace.getStepSize();
            final int treeCount = trace.getTreeCount(burninStates);
            final double stepSize = treeCount / 60.0;
            int counter = 1;
            if (verbose) {
                System.out.println("Analyzing " + treeCount + " trees...");
                System.out
                        .println("0 25 50 75 100");
                System.out
                        .println("|--------------|--------------|--------------|--------------|");
                System.out.print("*");
            }
            // start at index 0 so that the first tree after the burn-in is counted too
            for (int i = 0; i < treeCount; i++) {
                addTree(trace.getTree(i, burninStates));

                // progress bar: at most 60 additional stars per trace
                if (i >= (int) Math.round(counter * stepSize) && counter <= 60) {
                    if (verbose) {
                        System.out.print("*");
                        System.out.flush();
                    }
                    counter += 1;
                }
            }
            if (verbose) {
                System.out.println("*");
            }
        }
    }

    /**
     * Creates the report. For every reference tree read from {@code r} the
     * estimated log posterior and the tree itself are printed to stdout.
     *
     * @param r reader providing the reference trees in Nexus format
     * @throws IOException              if general I/O error occurs
     * @throws Importer.ImportException if the Nexus trees cannot be parsed
     * @throws RuntimeException         if the input is empty or not in Nexus format
     */
    public void report(Reader r) throws IOException, Importer.ImportException {
        System.err.println("making report");
        BufferedReader reader = new BufferedReader(r);
        String line = reader.readLine();
        // readLine() returns null for an empty input; treat it like an unsupported format
        if (line == null || !line.toUpperCase(Locale.ROOT).startsWith("#NEXUS")) {
            throw new RuntimeException("Could not read reference tree. Only Nexus format is supported.");
        }
        NexusImporter importer = new NexusImporter(reader);
        Tree[] trees = importer.importTrees(null);
        for (Tree tree : trees) {
            SimpleTree sTree = new SimpleTree(tree);
            System.out.println("Estimated marginal posterior by condiational clade frequencies:");
            System.out.println(getTreeProbability(sTree) + "\t\t" + sTree);
        }
        System.out.flush();
    }

    /**
     * Calculates the probability of a given tree.
     *
     * @param tree - the tree to be analyzed
     * @return estimated posterior probability in log
     */
    public double getTreeProbability(Tree tree) {
        List<Clade> clades = new ArrayList<Clade>();
        List<Clade> parentClades = new ArrayList<Clade>();
        getNonComplementaryClades(tree, tree.getRoot(), parentClades, clades);
        return logProbabilityOfSplits(clades, parentClades);
    }

    /**
     * Calculates the probability of a given tree, mapping taxa to bit indices
     * with the given taxon map.
     *
     * @param tree     - the tree to be analyzed
     * @param taxonMap - mapping from taxon names to bit indices
     * @return estimated posterior probability in log
     */
    public double getTreeProbability(Tree tree,
                                     HashMap<String, Integer> taxonMap) {
        List<Clade> clades = new ArrayList<Clade>();
        List<Clade> parentClades = new ArrayList<Clade>();
        getNonComplementaryClades(tree, tree.getRoot(), parentClades, clades,
                taxonMap);
        return logProbabilityOfSplits(clades, parentClades);
    }

    /**
     * Sums the log conditional clade probabilities of the given (child, parent) pairs.
     * Shared by both {@code getTreeProbability} overloads.
     *
     * @param clades       child clades
     * @param parentClades parent clade of the child at the same index
     * @return the summed log conditional probabilities
     */
    private double logProbabilityOfSplits(List<Clade> clades, List<Clade> parentClades) {
        double logProb = 0.0;
        for (int i = 0; i < clades.size(); i++) {
            Clade child = clades.get(i);
            Clade parent = parentClades.get(i);
            BitSet parentBits = parent.getBits();

            // every count starts with epsilon so unobserved clades keep a non-zero probability
            double childOccurrences = epsilon;
            double parentOccurrences = 0.0;

            Clade observedParent = cladeProbabilities.get(parentBits);
            if (observedParent != null) {
                parentOccurrences += observedParent.getSampleCount();
            }
            HashMap<BitSet, Clade> conditionalProbs = cladeCoProbabilities.get(parentBits);
            if (conditionalProbs != null) {
                Clade observedChild = conditionalProbs.get(child.getBits());
                if (observedChild != null) {
                    childOccurrences += observedChild.getSampleCount();
                }
            }

            // add epsilon for each of the 2^(n-1) - 1 possible splits of the parent
            final double splits = Math.pow(2, parent.getSize() - 1) - 1;
            parentOccurrences += epsilon * splits;

            logProb += Math.log(childOccurrences / parentOccurrences);
        }
        return logProb;
    }

    /**
     * Randomly splits {@code parent} into two child clades, proportionally to the
     * observed conditional clade frequencies (plus epsilon for unobserved splits).
     *
     * @param parent   the clade to split
     * @param children output array of length at least 2 receiving the two children
     * @return the log probability of the chosen split
     */
    public double splitClade(Clade parent, Clade[] children) {
        // the number of all possible clades is 2^n with n the number of tips,
        // reduced by 2 because we won't consider the clades with all or no tips
        // contained; divided by 2 because every clade has a matching clade to
        // form the split: #splits = 2^(n-1) - 1
        final double splits = Math.pow(2, parent.getSize() - 1) - 1;
        final BitSet parentBits = parent.getBits();
        final double childHeight = 0.9999 * parent.getHeight();
        double prob = 0;

        HashMap<BitSet, Clade> childClades = cladeCoProbabilities.get(parentBits);
        if (childClades != null) {
            double noChildClades = 0.0;
            double sum = 0.0;
            for (Clade observed : childClades.values()) {
                sum += splitWeight(parent, observed);
                noChildClades += (parent.getSize() > observed.getSize() + 1) ? 0.5 : 1.0;
            }
            // add epsilon for each not observed split
            sum += epsilon * (splits - noChildClades);

            // roulette wheel; MathUtils keeps the draw under BEAST's seeded RNG
            double randomNumber = MathUtils.nextDouble() * sum;
            for (Clade observed : childClades.values()) {
                randomNumber -= splitWeight(parent, observed);
                if (randomNumber < 0) {
                    children[0] = observed;
                    prob = (observed.getSampleCount() + epsilon) / sum;
                    break;
                }
            }

            if (randomNumber >= 0) {
                // an unobserved split was chosen: draw subsets until neither half has been observed
                prob = epsilon / sum;
                BitSet newChild;
                BitSet inverseBits;
                do {
                    newChild = randomProperSubset(parentBits, parent.getSize());
                    inverseBits = (BitSet) newChild.clone();
                    inverseBits.xor(parentBits);
                } while (childClades.containsKey(newChild)
                        || childClades.containsKey(inverseBits));
                children[0] = new Clade(newChild, childHeight);
                children[1] = new Clade(inverseBits, childHeight);
            } else {
                BitSet secondChild = (BitSet) children[0].getBits().clone();
                secondChild.xor(parentBits);
                children[1] = childClades.get(secondChild);
                if (children[1] == null) {
                    children[1] = new Clade(secondChild, childHeight);
                }
            }
        } else {
            // parent never observed: every split is equally likely
            prob = 1.0 / splits;
            BitSet newChild = randomProperSubset(parentBits, parent.getSize());
            Clade randomClade = new Clade(newChild, childHeight);
            randomClade.addHeight(childHeight);
            children[0] = randomClade;

            BitSet secondChild = (BitSet) newChild.clone();
            secondChild.xor(parentBits);
            children[1] = new Clade(secondChild, childHeight);
            // the second child gets its own height (previously added to the first child twice)
            children[1].addHeight(childHeight);
        }
        return Math.log(prob);
    }

    /**
     * Roulette-wheel weight of an observed child clade. When the complementary half
     * contains more than one taxon the weight is halved, presumably because that
     * complement is recorded under the same parent as well (verify against getClades).
     */
    private double splitWeight(Clade parent, Clade child) {
        final double weight = child.getSampleCount() + epsilon;
        return (parent.getSize() > child.getSize() + 1) ? weight / 2.0 : weight;
    }

    /**
     * Draws a random non-empty proper subset of {@code bits}: each set bit is
     * cleared with probability 1/2, repeating until the result is neither empty
     * nor of cardinality {@code size}.
     */
    private static BitSet randomProperSubset(BitSet bits, int size) {
        BitSet subset;
        do {
            subset = (BitSet) bits.clone();
            for (int index = subset.nextSetBit(0); index > -1; index = subset.nextSetBit(index + 1)) {
                if (MathUtils.nextBoolean()) {
                    subset.clear(index);
                }
            }
        } while (subset.cardinality() == 0 || subset.cardinality() == size);
        return subset;
    }

    /**
     * Placeholder: the recursive computation is commented out, so this
     * currently always returns 0.0.
     */
    public double getChanceForNodeHeights(TreeModel tree, Likelihood likelihood) {
        double prob = 0.0;
        NodeRef node = tree.getRoot();
        Clade currentClade = getClade(tree, node);
        int childcount = tree.getChildCount(node);
        for (int i = 0; i < childcount; i++) {
            NodeRef child = tree.getChild(node, i);
            if (!tree.isExternal(child)) {
                // prob += getChanceForNodeheights(tree, child, currentClade,
                // likelihood, prior);
            }
        }
        return prob;
    }

    /**
     * Placeholder: the recursive height setting is commented out, so this
     * currently changes no heights and always returns 0.0.
     */
    public double setNodeHeights(TreeModel tree, Likelihood likelihood) {
        double prob = 0.0;
        NodeRef node = tree.getRoot();
        Clade currentClade = getClade(tree, node);
        int childcount = tree.getChildCount(node);
        for (int i = 0; i < childcount; i++) {
            NodeRef child = tree.getChild(node, i);
            if (!tree.isExternal(child)) {
                // prob += setNodeHeights(tree, child, currentClade, likelihood,
                // prior);
            }
        }
        return prob;
    }

    /**
     * Gets the i'th tree after burn-in, counting across all traces in order.
     *
     * @param index global index of the tree
     * @return the i'th tree of the trace
     * @throws RuntimeException if fewer than {@code index + 1} trees are available
     */
    public final Tree getTree(int index) {
        int oldTreeCount = 0;
        int newTreeCount = 0;
        for (TreeTrace trace : traces) {
            newTreeCount += trace.getTreeCount(burnin * trace.getStepSize());
            if (index < newTreeCount) {
                return trace.getTree(index - oldTreeCount, burnin
                        * trace.getStepSize());
            }
            oldTreeCount = newTreeCount;
        }
        throw new RuntimeException("Couldn't find tree " + index);
    }

    /**
     * Increments the number of occurrences for all conditional clades of a tree.
     *
     * @param tree - the tree to be added
     */
    public void addTree(Tree tree) {
        samples++;
        List<Clade> clades = new ArrayList<Clade>();
        List<Clade> parentClades = new ArrayList<Clade>();
        getClades(tree, tree.getRoot(), parentClades, clades);
        addClades(clades, parentClades);
    }

    /**
     * Increments the number of occurrences for all conditional clades of a tree,
     * mapping taxa to bit indices with the given taxon map.
     *
     * @param tree     - the tree to be added
     * @param taxonMap - mapping from taxon names to bit indices
     */
    public void addTree(Tree tree, HashMap<String, Integer> taxonMap) {
        samples++;
        List<Clade> clades = new ArrayList<Clade>();
        List<Clade> parentClades = new ArrayList<Clade>();
        getClades(tree, tree.getRoot(), parentClades, clades, taxonMap);
        addClades(clades, parentClades);
    }

    /**
     * Records the clade and conditional clade occurrences for one tree.
     * Shared by both {@code addTree} overloads; mutates the given lists.
     *
     * @param clades       child clades of the tree
     * @param parentClades parent clade of the child at the same index
     */
    private void addClades(List<Clade> clades, List<Clade> parentClades) {
        // add the clade containing all taxa as well so that it gets counted;
        // it is its own parent, so it only enters cladeProbabilities
        Clade lastParent = parentClades.get(parentClades.size() - 1);
        clades.add(lastParent);
        parentClades.add(lastParent);

        for (int i = 0; i < clades.size(); i++) {
            Clade c = clades.get(i);
            Clade parent = parentClades.get(i);
            BitSet bits = c.getBits();

            // increment the clade occurrences
            Clade observed = cladeProbabilities.get(bits);
            if (observed != null) {
                observed.addHeight(c.getHeight());
            } else {
                // first observation: store the clade itself and seed its height list
                c.addHeight(c.getHeight());
                cladeProbabilities.put(bits, c);
            }

            // increment the conditional clade occurrences
            if (!parent.equals(c)) {
                HashMap<BitSet, Clade> coFreqs = cladeCoProbabilities.get(parent.getBits());
                if (coFreqs == null) {
                    // first time this parent is observed: new map for its child clades
                    coFreqs = new HashMap<BitSet, Clade>();
                    cladeCoProbabilities.put(parent.getBits(), coFreqs);
                }
                Clade observedChild = coFreqs.get(bits);
                if (observedChild != null) {
                    observedChild.addHeight(c.getHeight());
                } else {
                    // store a separate copy for this parent
                    // TODO (from original) check whether the cloning is needed
                    Clade copy = new Clade((BitSet) bits.clone(), c.getHeight());
                    copy.addHeight(c.getHeight());
                    coFreqs.put(bits, copy);
                }
            }
        }
    }

    /**
     * Loads one tree trace per reader and analyzes them together. Every reader
     * is closed, even if loading its trace fails.
     *
     * @param reader  the readers to be analyzed
     * @param e       the epsilon pseudo-count for unobserved clades
     * @param burnin  the burn-in passed to the constructor
     *                (NOTE(review): documented as states, but used as a tree count)
     * @param verbose true if progress should be logged to stdout
     * @return an analysis of the trees in the log files.
     * @throws java.io.IOException if general I/O error occurs
     */
    public static ConditionalCladeFrequency analyzeLogFile(Reader[] reader,
                                                           double e, int burnin, boolean verbose) throws IOException {
        TreeTrace[] trace = new TreeTrace[reader.length];
        for (int i = 0; i < reader.length; i++) {
            try {
                trace[i] = TreeTrace.loadTreeTrace(reader[i]);
            } catch (Importer.ImportException ie) {
                // keep the original message but preserve the cause
                throw new RuntimeException(ie.toString(), ie);
            } finally {
                reader[i].close();
            }
        }
        return new ConditionalCladeFrequency(trace, e, burnin, verbose);
    }
}