/*
* TreeTrait.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evolution.tree;
/**
* @author Andrew Rambaut
* @author Marc Suchard
* @author Alexei Drummond
* @version $Id$
*/
public interface TreeTrait<T> {
/**
* The part of the tree that a trait is a property of: the nodes, the
* branches, or the tree as a whole.
*/
public enum Intent {
/** The trait is a property of the nodes. */
NODE,
/** The trait is a property of the branches. */
BRANCH,
/** The trait is a property of the tree as a whole. */
WHOLE_TREE
}
/**
* The human readable name of this trait.
*
* @return the name
*/
String getTraitName();
/**
* Specifies whether this is a trait of the whole tree, of the nodes or of
* the branches.
*
* @return the {@link Intent} of this trait
*/
Intent getIntent();
/**
* Return a class object describing the type of the trait values, for
* example {@code Double.class} or {@code double[].class}.
*
* @return the class
*/
Class getTraitClass();
/**
* Returns the trait value for the given node. If this is a branch trait then
* it will be for the branch above the specified node (and may not be valid for
* the root). The array will be the length returned by getDimension().
* <p>
* NOTE(review): getDimension() is not declared on this interface; confirm
* where the array length contract is defined.
*
* @param tree a reference to a tree
* @param node a reference to a node
* @return the trait value; the formatting helpers and wrappers in this file
*         allow for a null value
*/
T getTrait(final Tree tree, final NodeRef node);
/**
* Get a string representation of the trait value.
*
* @param tree a reference to a tree
* @param node a reference to a node
* @return the trait string representation; the base classes in this file
*         return null when the trait value is null
*/
String getTraitString(final Tree tree, final NodeRef node);
/**
* Specifies whether this trait is loggable.
*
* @return true if this trait is loggable, false otherwise
*/
boolean getLoggable();
/**
* Default behavior shared by the abstract trait classes in this file: traits
* are loggable and are not forced into array formatting.
*/
class DefaultBehavior {
/**
* @return true; traits are loggable by default
*/
public boolean getLoggable() {
return true;
}
/**
* @return false by default. SA.getTraitString passes this flag to
* SA.formatTrait, where a true value makes a single-element array
* be written in brace notation as well.
*/
public boolean getFormatAsArray() {
return false;
}
}
/**
 * Abstract base class for traits whose value is a single {@link Double}.
 * The trait name, intent and value are left to subclasses.
 */
public abstract class D extends DefaultBehavior implements TreeTrait<Double> {

    /**
     * Formats a Double trait value.
     *
     * @param value the value to format, may be null
     * @return {@code value.toString()}, or null if value is null
     */
    public static String formatTrait(Double value) {
        return (value == null) ? null : value.toString();
    }

    /**
     * @return {@code Double.class}
     */
    public Class getTraitClass() {
        return Double.class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the trait value at the node formatted by {@link #formatTrait(Double)}
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final Double traitValue = getTrait(tree, node);
        return formatTrait(traitValue);
    }
}
/**
 * Abstract base class for traits whose value is a single {@link Integer}.
 * The trait name, intent and value are left to subclasses.
 */
public abstract class I extends DefaultBehavior implements TreeTrait<Integer> {

    /**
     * Formats an Integer trait value.
     *
     * @param value the value to format, may be null
     * @return {@code value.toString()}, or null if value is null
     */
    public static String formatTrait(Integer value) {
        return (value == null) ? null : value.toString();
    }

    /**
     * @return {@code Integer.class}
     */
    public Class getTraitClass() {
        return Integer.class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the trait value at the node formatted by {@link #formatTrait(Integer)}
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final Integer traitValue = getTrait(tree, node);
        return formatTrait(traitValue);
    }
}
/**
 * Abstract base class for traits whose value is a {@link String}.
 * The trait name, intent and value are left to subclasses.
 */
public abstract class S extends DefaultBehavior implements TreeTrait<String> {

    /**
     * @return {@code String.class}
     */
    public Class getTraitClass() {
        return String.class;
    }

    /**
     * The trait value is already a string, so it is returned unchanged.
     *
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the trait value itself (null if the value is null)
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final String traitValue = getTrait(tree, node);
        return traitValue;
    }
}
/**
 * Abstract base class for traits whose value is a {@code double[]}.
 * The trait name, intent and value are left to subclasses.
 */
public abstract class DA extends DefaultBehavior implements TreeTrait<double[]> {

    /**
     * Formats a double array trait value. A single-element array is written
     * as the bare number; longer arrays are written as a comma separated
     * list enclosed in braces, e.g. {@code {1.0,2.0}}.
     *
     * @param values the values to format, may be null
     * @return the formatted values, or null if values is null or empty
     */
    public static String formatTrait(double[] values) {
        if (values == null || values.length == 0) {
            return null;
        }
        if (values.length == 1) {
            return Double.toString(values[0]);
        }
        final StringBuilder buffer = new StringBuilder();
        buffer.append('{');
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                buffer.append(',');
            }
            buffer.append(values[i]);
        }
        buffer.append('}');
        return buffer.toString();
    }

    /**
     * @return {@code double[].class}
     */
    public Class getTraitClass() {
        return double[].class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the trait value at the node formatted by {@link #formatTrait(double[])}
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final double[] traitValues = getTrait(tree, node);
        return formatTrait(traitValues);
    }
}
/**
 * Abstract base class for traits whose value is a {@code String[]}.
 * The trait name, intent and value are left to subclasses.
 */
public abstract class SA extends DefaultBehavior implements TreeTrait<String[]> {

    /**
     * Formats a String array trait value. Arrays with more than one element
     * are always written as a comma separated list enclosed in braces; a
     * single-element array is written that way only when asArray is true,
     * otherwise its sole element is returned as is.
     *
     * @param values  the values to format, may be null
     * @param asArray whether a single-element array should still be written
     *                in brace notation
     * @return the formatted values, or null if values is null or empty
     */
    public static String formatTrait(String[] values, boolean asArray) {
        if (values == null || values.length == 0) {
            return null;
        }
        if (!asArray && values.length == 1) {
            return values[0];
        }
        final StringBuilder buffer = new StringBuilder();
        buffer.append('{');
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                buffer.append(',');
            }
            buffer.append(values[i]);
        }
        buffer.append('}');
        return buffer.toString();
    }

    /**
     * @return {@code String[].class}
     */
    public Class getTraitClass() {
        return String[].class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the trait value at the node formatted by
     *         {@link #formatTrait(String[], boolean)}, using
     *         getFormatAsArray() as the asArray flag
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final String[] traitValues = getTrait(tree, node);
        return formatTrait(traitValues, getFormatAsArray());
    }
}
/**
 * Abstract base class for traits whose value is an {@code int[]}.
 * The trait name, intent and value are left to subclasses.
 */
public abstract class IA extends DefaultBehavior implements TreeTrait<int[]> {

    /**
     * Formats an int array trait value. A single-element array is written
     * as the bare number; longer arrays are written as a comma separated
     * list enclosed in braces, e.g. {@code {1,2}}.
     *
     * @param values the values to format, may be null
     * @return the formatted values, or null if values is null or empty
     */
    public static String formatTrait(int[] values) {
        if (values == null || values.length == 0) {
            return null;
        }
        if (values.length == 1) {
            return Integer.toString(values[0]);
        }
        final StringBuilder buffer = new StringBuilder();
        buffer.append('{');
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                buffer.append(',');
            }
            buffer.append(values[i]);
        }
        buffer.append('}');
        return buffer.toString();
    }

    /**
     * @return {@code int[].class}
     */
    public Class getTraitClass() {
        return int[].class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the trait value at the node formatted by {@link #formatTrait(int[])}
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final int[] traitValues = getTrait(tree, node);
        return formatTrait(traitValues);
    }
}
/**
 * An abstract wrapper that accumulates the values of a base trait over the
 * external and/or internal nodes of a tree. How two values are combined is
 * left to subclasses via {@link #addToMatrix(Object, Object)}.
 */
public abstract class SumOverTree<T> extends DefaultBehavior implements TreeTrait<T> {

    private static final String NAME_PREFIX = "sumOverTree_";

    private final String name;
    private final TreeTrait<T> base;
    private final boolean includeExternalNodes;
    private final boolean includeInternalNodes;

    /**
     * Wraps base under the name "sumOverTree_" + base name, summing over
     * both external and internal nodes.
     *
     * @param base the trait to sum
     */
    public SumOverTree(TreeTrait<T> base) {
        this(NAME_PREFIX + base.getTraitName(), base);
    }

    /**
     * Wraps base under the given name, summing over both external and
     * internal nodes.
     *
     * @param name the name of this trait
     * @param base the trait to sum
     */
    public SumOverTree(String name, TreeTrait<T> base) {
        this(name, base, true, true);
    }

    /**
     * @param name                 the name of this trait
     * @param base                 the trait to sum
     * @param includeExternalNodes whether external nodes contribute to the sum
     * @param includeInternalNodes whether internal nodes contribute to the sum
     */
    public SumOverTree(String name, TreeTrait<T> base, boolean includeExternalNodes, boolean includeInternalNodes) {
        this.name = name;
        this.base = base;
        this.includeInternalNodes = includeInternalNodes;
        this.includeExternalNodes = includeExternalNodes;
    }

    public String getTraitName() {
        return name;
    }

    /**
     * @return always {@link Intent#WHOLE_TREE}
     */
    public Intent getIntent() {
        return Intent.WHOLE_TREE;
    }

    /**
     * Accumulates the base trait over the selected nodes, external nodes
     * first and then internal nodes. The node argument is not used.
     *
     * @param tree a reference to a tree
     * @param node ignored
     * @return the accumulated value; null if no node was visited (what is
     *         returned otherwise is decided by addToMatrix)
     */
    public T getTrait(Tree tree, NodeRef node) {
        T total = null;
        if (includeExternalNodes) {
            for (int i = 0; i < tree.getExternalNodeCount(); i++) {
                final NodeRef external = tree.getExternalNode(i);
                final T value = base.getTrait(tree, external);
                total = addToMatrix(total, value);
            }
        }
        if (includeInternalNodes) {
            for (int i = 0; i < tree.getInternalNodeCount(); i++) {
                final NodeRef internal = tree.getInternalNode(i);
                final T value = base.getTrait(tree, internal);
                total = addToMatrix(total, value);
            }
        }
        return total;
    }

    /**
     * @return whatever the wrapped trait reports
     */
    public boolean getLoggable() {
        return base.getLoggable();
    }

    /**
     * Combines a running total with one more value.
     *
     * @param total   the running total; null before the first value is added
     * @param summant the value to add
     * @return the new running total
     */
    protected abstract T addToMatrix(T total, T summant);
}
/**
 * A wrapper that sums a {@code double[]} valued trait element-wise over the
 * nodes of a tree.
 */
public class SumOverTreeDA extends SumOverTree<double[]> {

    public SumOverTreeDA(TreeTrait<double[]> base) {
        super(base);
    }

    public SumOverTreeDA(String name, TreeTrait<double[]> base) {
        super(name, base);
    }

    public SumOverTreeDA(String name, TreeTrait<double[]> base,
                         boolean includeExternalNodes, boolean includeInternalNodes) {
        super(name, base, includeExternalNodes, includeInternalNodes);
    }

    /**
     * Adds summant element-wise onto total.
     *
     * @param total   the running total, modified in place; a new zeroed array
     *                of summant's length is used when it is null
     * @param summant the values to add; when null, total is returned untouched
     * @return the running total
     */
    protected static double[] addToMatrixStatic(double[] total, double[] summant) {
        if (summant == null) {
            return total;
        }
        final int n = summant.length;
        // NOTE(review): an existing total is assumed to be at least as long as summant
        final double[] sum = (total != null) ? total : new double[n];
        for (int i = 0; i < n; i++) {
            sum[i] += summant[i];
        }
        return sum;
    }

    protected double[] addToMatrix(double[] total, double[] summant) {
        return addToMatrixStatic(total, summant);
    }

    /**
     * @return {@code double[].class}
     */
    public Class getTraitClass() {
        return double[].class;
    }

    /**
     * @return the summed values formatted by DA.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final double[] summed = getTrait(tree, node);
        return DA.formatTrait(summed);
    }
}
/**
 * A wrapper class that sums a {@code Double} valued trait over the nodes of
 * a tree.
 */
public class SumOverTreeD extends SumOverTree<Double> {

    public SumOverTreeD(String name, TreeTrait<Double> base,
                        boolean includeExternalNodes, boolean includeInternalNodes) {
        super(name, base, includeExternalNodes, includeInternalNodes);
    }

    public SumOverTreeD(String name, TreeTrait<Double> base) {
        super(name, base);
    }

    public SumOverTreeD(TreeTrait<Double> base) {
        super(base);
    }

    /**
     * @return the summed value formatted by D.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        return D.formatTrait(getTrait(tree, node));
    }

    /**
     * @return {@code Double.class}, the type of the value returned by getTrait
     */
    public Class getTraitClass() {
        // The trait value is a single Double, not a double[]
        return Double.class;
    }

    protected Double addToMatrix(Double total, Double summant) {
        return addToMatrixStatic(total, summant);
    }

    /**
     * Adds summant onto total.
     *
     * @param total   the running total; treated as 0.0 when null
     * @param summant the value to add; when null, total is returned unchanged
     * @return the new running total
     */
    protected static Double addToMatrixStatic(Double total, Double summant) {
        if (summant == null) {
            return total;
        }
        if (total == null) {
            total = 0.0;
        }
        total += summant;
        return total;
    }
}
/**
 * An abstract wrapper that collapses an array valued trait (type TA) into a
 * single value (type T) per node. The reduction itself is supplied by
 * subclasses via {@link #reduce(Object)}.
 */
public abstract class SumAcrossArray<T, TA> extends DefaultBehavior implements TreeTrait<T> {

    public static final String NAME_PREFIX = "sumAcrossArray_";

    private String name;
    private TreeTrait<TA> base;

    /**
     * Wraps base under the name "sumAcrossArray_" + base name.
     *
     * @param base the array valued trait to reduce
     */
    public SumAcrossArray(TreeTrait<TA> base) {
        this(NAME_PREFIX + base.getTraitName(), base);
    }

    /**
     * @param name the name of this trait
     * @param base the array valued trait to reduce
     */
    public SumAcrossArray(String name, TreeTrait<TA> base) {
        this.base = base;
        this.name = name;
    }

    public String getTraitName() {
        return name;
    }

    /**
     * @return the intent of the wrapped trait
     */
    public Intent getIntent() {
        return base.getIntent();
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the reduced base value, or null if the base value is null
     */
    public T getTrait(Tree tree, NodeRef node) {
        final TA baseValues = base.getTrait(tree, node);
        return (baseValues == null) ? null : reduce(baseValues);
    }

    /**
     * @return whatever the wrapped trait reports
     */
    public boolean getLoggable() {
        return base.getLoggable();
    }

    /**
     * Reduces the array valued base trait to a single value.
     *
     * @param values the base trait value, never null when called from getTrait
     * @return the reduced value
     */
    protected abstract T reduce(TA values);
}
/**
 * A wrapper that sums the entries of a {@code double[]} valued trait into a
 * single {@code Double} per node.
 */
public class SumAcrossArrayD extends SumAcrossArray<Double, double[]> {

    public SumAcrossArrayD(TreeTrait<double[]> base) {
        super(base);
    }

    public SumAcrossArrayD(String name, TreeTrait<double[]> base) {
        super(name, base);
    }

    /**
     * @param values the entries to add up
     * @return the sum of all entries, 0.0 for an empty array
     */
    protected Double reduce(double[] values) {
        double sum = 0.0;
        for (int i = 0; i < values.length; i++) {
            sum += values[i];
        }
        return sum;
    }

    /**
     * @return {@code Double.class}
     */
    public Class getTraitClass() {
        return Double.class;
    }

    /**
     * @return the summed value formatted by D.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final Double summed = getTrait(tree, node);
        return D.formatTrait(summed);
    }
}
/**
 * An abstract wrapper that exposes a single entry of an array valued trait
 * (type TA) as a trait of the element type T. Extracting the entry is left
 * to subclasses.
 */
public abstract class PickEntry<T, TA> extends DefaultBehavior implements TreeTrait<T> {

    private String name;
    protected TreeTrait<TA> base;
    protected int index;

    /**
     * Wraps base under the name base name + "_" + (index + 1), i.e. the
     * default name uses a one-based entry number.
     *
     * @param base  the array valued trait to pick from
     * @param index the zero-based index of the entry to pick
     */
    public PickEntry(TreeTrait<TA> base, int index) {
        this(base.getTraitName() + "_" + (index + 1), base, index);
    }

    /**
     * No check is made here that the base trait is actually array valued.
     *
     * @param name  the name of this trait
     * @param base  the array valued trait to pick from
     * @param index the zero-based index of the entry to pick
     */
    public PickEntry(String name, TreeTrait<TA> base, int index) {
        this.index = index;
        this.base = base;
        this.name = name;
    }

    public String getTraitName() {
        return name;
    }

    /**
     * @return the intent of the wrapped trait
     */
    public Intent getIntent() {
        return base.getIntent();
    }
}
/**
 * A wrapper class that picks one entry out of a {@code double[]} valued
 * trait and exposes it as a {@code Double} trait.
 */
public class PickEntryD extends PickEntry<Double, double[]> {

    public PickEntryD(TreeTrait<double[]> base, int index) {
        super(base, index);
    }

    public PickEntryD(String name, TreeTrait<double[]> base, int index) {
        super(name, base, index);
    }

    /**
     * @return {@code Double.class}
     */
    public Class getTraitClass() {
        return Double.class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the entry at the configured index of the base value, or null
     *         if the base trait has no value (null) for the node
     */
    public Double getTrait(Tree tree, NodeRef node) {
        final double[] values = base.getTrait(tree, node);
        // A missing base value yields a missing entry rather than a
        // NullPointerException; D.formatTrait already accepts null.
        if (values == null) {
            return null;
        }
        return values[index];
    }

    /**
     * @return the picked entry formatted by D.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        return D.formatTrait(getTrait(tree, node));
    }
}
/**
 * A wrapper class that picks one entry out of a {@code double[]} valued
 * trait and divides it by the branch length reported by the tree for the
 * node.
 */
public class PickEntryDAndScale extends PickEntryD {

    public PickEntryDAndScale(TreeTrait<double[]> base, int index) {
        super(base, index);
    }

    public PickEntryDAndScale(String name, TreeTrait<double[]> base, int index) {
        super(name, base, index);
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the entry at the configured index divided by
     *         {@code tree.getBranchLength(node)}, or null if the base trait
     *         has no value (null) for the node
     */
    public Double getTrait(Tree tree, NodeRef node) {
        final double[] values = base.getTrait(tree, node);
        // A missing base value yields a missing entry rather than a
        // NullPointerException; D.formatTrait already accepts null.
        if (values == null) {
            return null;
        }
        // NOTE(review): no guard against a zero branch length — confirm
        // how callers treat the root.
        return values[index] / tree.getBranchLength(node);
    }
}
/**
 * A wrapper class that picks one entry out of an {@code int[]} valued trait
 * and exposes it as an {@code Integer} trait.
 */
public class PickEntryI extends PickEntry<Integer, int[]> {

    public PickEntryI(TreeTrait<int[]> base, int index) {
        super(base, index);
    }

    public PickEntryI(String name, TreeTrait<int[]> base, int index) {
        super(name, base, index);
    }

    /**
     * @return {@code Integer.class}, the type of the value returned by getTrait
     */
    public Class getTraitClass() {
        // The picked entry is an Integer, not a Double
        return Integer.class;
    }

    /**
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the entry at the configured index of the base value, or null
     *         if the base trait has no value (null) for the node
     */
    public Integer getTrait(Tree tree, NodeRef node) {
        final int[] values = base.getTrait(tree, node);
        // A missing base value yields a missing entry rather than a
        // NullPointerException; I.formatTrait already accepts null.
        if (values == null) {
            return null;
        }
        return values[index];
    }

    /**
     * @return the picked entry formatted by I.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        return I.formatTrait(getTrait(tree, node));
    }
}
/**
 * An abstract wrapper that restricts a base trait to the nodes accepted by
 * a {@code TreeNodeFilter}. Combining and zeroing values is left to
 * subclasses.
 */
public abstract class Filtered<T> extends DefaultBehavior implements TreeTrait<T> {

    private static final String NAME_PREFIX = "filtered_";

    private final String name;
    private final TreeTrait<T> base;
    private final TreeNodeFilter treeNodeFilter;

    /**
     * Wraps base under the name "filtered_" + base name.
     *
     * @param base           the trait to filter
     * @param treeNodeFilter decides which nodes are included
     */
    public Filtered(TreeTrait<T> base, TreeNodeFilter treeNodeFilter) {
        this(NAME_PREFIX + base.getTraitName(), base, treeNodeFilter);
    }

    /**
     * @param name           the name of this trait
     * @param base           the trait to filter
     * @param treeNodeFilter decides which nodes are included
     */
    public Filtered(String name, TreeTrait<T> base, TreeNodeFilter treeNodeFilter) {
        this.name = name;
        this.base = base;
        this.treeNodeFilter = treeNodeFilter;
    }

    public String getTraitName() {
        return name;
    }

    /**
     * @return the intent of the wrapped trait
     */
    public Intent getIntent() {
        return base.getIntent();
    }

    /**
     * For a whole-tree trait, the base trait is evaluated at every node the
     * filter includes and the values are combined with addToMatrix (the node
     * argument is not used). Otherwise the base value at the given node is
     * returned, replaced by zero(value) when the filter excludes the node.
     *
     * @param tree a reference to a tree
     * @param node a reference to a node
     * @return the filtered trait value
     */
    public T getTrait(Tree tree, NodeRef node) {
        if (getIntent() != Intent.WHOLE_TREE) {
            final T value = base.getTrait(tree, node);
            return treeNodeFilter.includeNode(tree, node) ? value : zero(value);
        }
        T total = null;
        for (int i = 0; i < tree.getNodeCount(); i++) {
            final NodeRef candidate = tree.getNode(i);
            if (treeNodeFilter.includeNode(tree, candidate)) {
                total = addToMatrix(total, base.getTrait(tree, candidate));
            }
        }
        return total;
    }

    /**
     * @return whatever the wrapped trait reports
     */
    public boolean getLoggable() {
        return base.getLoggable();
    }

    /**
     * Combines a running total with one more value.
     *
     * @param total   the running total; null before the first value is added
     * @param summant the value to add
     * @return the new running total
     */
    protected abstract T addToMatrix(T total, T summant);

    /**
     * Produces the value reported for a node that the filter excludes.
     *
     * @param copy the base trait value at the excluded node
     * @return a zero value corresponding to copy
     */
    protected abstract T zero(T copy);
}
/**
 * A filtered wrapper around a {@code Double} valued trait.
 */
public class FilteredD extends Filtered<Double> {

    public FilteredD(String name, TreeTrait<Double> base, TreeNodeFilter treeNodeFilter) {
        super(name, base, treeNodeFilter);
    }

    public FilteredD(TreeTrait<Double> base, TreeNodeFilter treeNodeFilter) {
        super(base, treeNodeFilter);
    }

    /**
     * @return {@code Double.class}
     */
    public Class getTraitClass() {
        return Double.class;
    }

    /**
     * @return the filtered value formatted by D.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        final Double filtered = getTrait(tree, node);
        return D.formatTrait(filtered);
    }

    /**
     * Delegates to SumOverTreeD.addToMatrixStatic.
     */
    @Override
    protected Double addToMatrix(Double total, Double summant) {
        return SumOverTreeD.addToMatrixStatic(total, summant);
    }

    /**
     * @param copy ignored
     * @return always 0.0
     */
    @Override
    protected Double zero(Double copy) {
        return 0.0;
    }
}
/**
 * A filtered wrapper around a {@code double[]} valued trait.
 */
public class FilteredDA extends Filtered<double[]> {

    public FilteredDA(TreeTrait<double[]> base, TreeNodeFilter treeNodeFilter) {
        super(base, treeNodeFilter);
    }

    public FilteredDA(String name, TreeTrait<double[]> base, TreeNodeFilter treeNodeFilter) {
        super(name, base, treeNodeFilter);
    }

    /**
     * Delegates to SumOverTreeDA.addToMatrixStatic.
     */
    @Override
    protected double[] addToMatrix(double[] total, double[] summant) {
        return SumOverTreeDA.addToMatrixStatic(total, summant);
    }

    /**
     * @param copy the base trait value at an excluded node, may be null
     * @return a new all-zero array of the same length as copy, or null when
     *         copy is null (there is then no length to copy)
     */
    @Override
    protected double[] zero(double[] copy) {
        // Filtered.getTrait passes the raw base value, which may be null;
        // DA.formatTrait already accepts a null result.
        if (copy == null) {
            return null;
        }
        return new double[copy.length];
    }

    /**
     * @return {@code double[].class}
     */
    public Class getTraitClass() {
        return double[].class;
    }

    /**
     * @return the filtered values formatted by DA.formatTrait
     */
    public String getTraitString(Tree tree, NodeRef node) {
        return DA.formatTrait(getTrait(tree, node));
    }
}
}