/*
* DiffusionRateStatistic.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.app.util.Arguments;
import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeStatistic;
import dr.geo.math.SphericalPolarCoordinates;
import dr.inference.model.Statistic;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.stats.DiscreteStatistics;
import dr.xml.*;
import java.util.*;
import java.io.IOException;
/**
* @author Marc Suchard
* @author Philippe Lemey
* @author Andrew Rambaut
*/
@Deprecated
public class DiffusionRateStatistic extends Statistic.Abstract {
public static final String DIFFUSION_RATE_STATISTIC = "diffusionRateStatistic";
public static final String TREE_DISPERSION_STATISTIC = "treeDispersionStatistic";
public static final String BOOLEAN_DIS_OPTION = "greatCircleDistance";
public static final String MODE = "mode";
public static final String MEDIAN = "median";
public static final String AVERAGE = "average"; // average over all branches
public static final String WEIGHTED_AVERAGE = "weightedAverage"; // weighted average (=total distance/total time)
public static final String COEFFICIENT_OF_VARIATION = "coefficientOfVariation"; // weighted average (=total distance/total time)
public static final String STATISTIC = "statistic";
public static final String DIFFUSION_RATE = "diffusionRate"; // weighted average (=total distance/total time)
public static final String WAVEFRONT_DISTANCE = "wavefrontDistance"; // weighted average (=total distance/total time)
public static final String WAVEFRONT_RATE = "wavefrontRate"; // weighted average (=total distance/total time)
public static final String DIFFUSION_COEFFICIENT = "diffusionCoefficient";
// public static final String DIFFUSIONCOEFFICIENT = "diffusionCoefficient"; // weighted average (=total distance/total time)
// public static final String BOOLEAN_DC_OPTION = "diffusionCoefficient";
public static final String HEIGHT_UPPER = "heightUpper";
public static final String HEIGHT_LOWER = "heightLower";
public static final String HEIGHT_LOWER_SERIE = "heightLowerSerie";
public static final String CUMULATIVE = "cumulative";
public DiffusionRateStatistic(String name, List<AbstractMultivariateTraitLikelihood> traitLikelihoods,
boolean option, Mode mode,
summaryStatistic statistic, double heightUpper, double heightLower,
double[] lowerHeights, boolean cumulative) {
super(name);
this.traitLikelihoods = traitLikelihoods;
this.useGreatCircleDistances = option;
summaryMode = mode;
summaryStat = statistic;
this.heightUpper = heightUpper;
if (lowerHeights == null){
heightLowers = new double[]{heightLower};
} else {
heightLowers = extractUnique(lowerHeights);
Arrays.sort(heightLowers);
reverse(heightLowers);
}
this.cumulative = cumulative;
}
public int getDimension() {
return heightLowers.length;
}
public double getStatisticValue(int dim) {
String traitName = traitLikelihoods.get(0).getTraitName();
double treelength = 0;
double treeDistance = 0;
double maxDistanceFromRoot = 0;
double maxDistanceOverTimeFromRoot = 0;
//double[] rates = null;
List<Double> rates = new ArrayList<Double>();
//double[] diffusionCoefficients = null;
List<Double> diffusionCoefficients = new ArrayList<Double>();
double waDiffusionCoefficient = 0;
double lowerHeight = heightLowers[dim];
double upperHeight = Double.MAX_VALUE;
if (heightLowers.length == 1){
upperHeight = heightUpper;
} else {
if (dim > 0) {
if (!cumulative) {
upperHeight = heightLowers[dim -1];
}
}
}
// System.out.println("dim = "+dim+", heightLower = "+lowerHeight+", heightUpper = "+upperHeight);
for (AbstractMultivariateTraitLikelihood traitLikelihood : traitLikelihoods) {
MultivariateTraitTree tree = traitLikelihood.getTreeModel();
BranchRateModel branchRates = traitLikelihood.getBranchRateModel();
for (int i = 0; i < tree.getNodeCount(); i++) {
NodeRef node = tree.getNode(i);
if (node != tree.getRoot()) {
NodeRef parentNode = tree.getParent(node);
if ((tree.getNodeHeight(parentNode) > lowerHeight) && (tree.getNodeHeight(node) < upperHeight)) {
double[] trait = traitLikelihood.getTraitForNode(tree, node, traitName);
double[] parentTrait = traitLikelihood.getTraitForNode(tree, parentNode, traitName);
double[] traitUp = parentTrait;
double[] traitLow = trait;
double timeUp = tree.getNodeHeight(parentNode);
double timeLow = tree.getNodeHeight(node);
double rate = (branchRates != null ? branchRates.getBranchRate(tree, node) : 1.0);
MultivariateDiffusionModel diffModel = traitLikelihood.diffusionModel;
double[] precision = diffModel.getPrecisionParameter().getParameterValues();
if (tree.getNodeHeight(parentNode) > upperHeight) {
timeUp = upperHeight;
//TODO: implement TrueNoise??
traitUp = imputeValue(trait, parentTrait, upperHeight, tree.getNodeHeight(node), tree.getNodeHeight(parentNode), precision, rate, false);
}
if (tree.getNodeHeight(node) < lowerHeight) {
timeLow = lowerHeight;
traitLow = imputeValue(trait, parentTrait, lowerHeight, tree.getNodeHeight(node), tree.getNodeHeight(parentNode), precision, rate, false);
}
double time = timeUp - timeLow;
treelength += time;
double[] rootTrait = traitLikelihood.getTraitForNode(tree, tree.getRoot(), traitName);
if (useGreatCircleDistances && (trait.length == 2)) { // Great Circle distance
SphericalPolarCoordinates coord1 = new SphericalPolarCoordinates(traitLow[0], traitLow[1]);
SphericalPolarCoordinates coord2 = new SphericalPolarCoordinates(traitUp[0], traitUp[1]);
double distance = coord1.distance(coord2);
treeDistance += distance;
double dc = Math.pow(distance,2)/(4*time);
diffusionCoefficients.add(dc);
waDiffusionCoefficient += dc*time;
rates.add(distance/time);
SphericalPolarCoordinates rootCoord = new SphericalPolarCoordinates(rootTrait[0], rootTrait[1]);
double tempDistanceFromRoot = rootCoord.distance(coord2);
if (tempDistanceFromRoot > maxDistanceFromRoot){
maxDistanceFromRoot = tempDistanceFromRoot;
maxDistanceOverTimeFromRoot = tempDistanceFromRoot/(tree.getNodeHeight(tree.getRoot()) - timeLow);
//distance between traitLow and traitUp for maxDistanceFromRoot
if (timeUp == upperHeight) {
maxDistanceFromRoot = distance;
maxDistanceOverTimeFromRoot = distance/time;
}
}
} else {
double distance = getNativeDistance(traitLow, traitUp);
treeDistance += distance;
double dc = Math.pow(distance,2)/(4*time);
diffusionCoefficients.add(dc);
waDiffusionCoefficient += dc*time;
rates.add(distance/time);
double tempDistanceFromRoot = getNativeDistance(traitLow, rootTrait);
if (tempDistanceFromRoot > maxDistanceFromRoot){
maxDistanceFromRoot = tempDistanceFromRoot;
maxDistanceOverTimeFromRoot = tempDistanceFromRoot/(tree.getNodeHeight(tree.getRoot()) - timeLow);
//distance between traitLow and traitUp for maxDistanceFromRoot
if (timeUp == upperHeight) {
maxDistanceFromRoot = distance;
maxDistanceOverTimeFromRoot = distance/time;
}
}
}
}
}
}
}
if (summaryStat == summaryStatistic.DIFFUSION_RATE){
if (summaryMode == Mode.AVERAGE) {
return DiscreteStatistics.mean(toArray(rates));
} else if (summaryMode == Mode.MEDIAN) {
return DiscreteStatistics.median(toArray(rates));
} else if (summaryMode == Mode.COEFFICIENT_OF_VARIATION) {
// don't compute mean twice
final double mean = DiscreteStatistics.mean(toArray(rates));
return Math.sqrt(DiscreteStatistics.variance(toArray(rates), mean)) / mean;
} else {
return treeDistance / treelength;
}
} else if (summaryStat == summaryStatistic.DIFFUSION_COEFFICIENT) {
if (summaryMode == Mode.AVERAGE) {
return DiscreteStatistics.mean(toArray(diffusionCoefficients));
} else if (summaryMode == Mode.MEDIAN) {
return DiscreteStatistics.median(toArray(diffusionCoefficients));
} else if (summaryMode == Mode.COEFFICIENT_OF_VARIATION) {
// don't compute mean twice
final double mean = DiscreteStatistics.mean(toArray(diffusionCoefficients));
return Math.sqrt(DiscreteStatistics.variance(toArray(diffusionCoefficients), mean)) / mean;
} else {
return waDiffusionCoefficient/treelength;
}
} else if (summaryStat == summaryStatistic.WAVEFRONT_DISTANCE) {
return maxDistanceFromRoot;
} else {
return maxDistanceOverTimeFromRoot;
}
}
// private double getNativeDistance(double[] location1, double[] location2) {
// return Math.sqrt(Math.pow((location2[0] - location1[0]), 2.0) + Math.pow((location2[1] - location1[1]), 2.0));
// }
private double getNativeDistance(double[] location1, double[] location2) {
int traitDimension = location1.length;
double sum = 0;
for (int i = 0; i < traitDimension; i++) {
sum += Math.pow((location2[i] - location1[i]),2);
// System.out.println(sum);
}
return Math.sqrt(sum);
}
private double[] toArray(List<Double> list) {
double[] returnArray = new double[list.size()];
for (int i = 0; i < list.size(); i++) {
returnArray[i] = Double.valueOf(list.get(i).toString());
}
return returnArray;
}
private double[] imputeValue(double[] nodeValue, double[] parentValue, double time, double nodeHeight, double parentHeight, double[] precisionArray, double rate, boolean trueNoise) {
final double scaledTimeChild = (time - nodeHeight) * rate;
final double scaledTimeParent = (parentHeight - time) * rate;
final double scaledWeightTotal = 1.0 / scaledTimeChild + 1.0 / scaledTimeParent;
final int dim = nodeValue.length;
double[][] precision = new double[dim][dim];
int counter = 0;
for (int a = 0; a < dim; a++){
for (int b = 0; b < dim; b++){
precision[a][b] = precisionArray[counter];
counter++ ;
}
}
if (scaledTimeChild == 0)
return nodeValue;
if (scaledTimeParent == 0)
return parentValue;
// Find mean value, weighted average
double[] mean = new double[dim];
double[][] scaledPrecision = new double[dim][dim];
for (int i = 0; i < dim; i++) {
mean[i] = (nodeValue[i] / scaledTimeChild + parentValue[i] / scaledTimeParent) / scaledWeightTotal;
if (trueNoise) {
for (int j = i; j < dim; j++)
scaledPrecision[j][i] = scaledPrecision[i][j] = precision[i][j] * scaledWeightTotal;
}
}
// System.out.print(time+"\t"+nodeHeight+"\t"+parentHeight+"\t"+scaledTimeChild+"\t"+scaledTimeParent+"\t"+scaledWeightTotal+"\t"+mean[0]+"\t"+mean[1]+"\t"+scaledPrecision[0][0]+"\t"+scaledPrecision[0][1]+"\t"+scaledPrecision[1][0]+"\t"+scaledPrecision[1][1]);
if (trueNoise) {
mean = MultivariateNormalDistribution.nextMultivariateNormalPrecision(mean, scaledPrecision);
}
// System.out.println("\t"+mean[0]+"\t"+mean[1]+"\r");
double[] result = new double[dim];
for (int i = 0; i < dim; i++)
result[i] = mean[i];
return result;
}
public static double[] parseVariableLengthDoubleArray(String inString) throws Arguments.ArgumentException {
List<Double> returnList = new ArrayList<Double>();
StringTokenizer st = new StringTokenizer(inString, ",");
while (st.hasMoreTokens()) {
try {
returnList.add(Double.parseDouble(st.nextToken()));
} catch (NumberFormatException e) {
throw new Arguments.ArgumentException();
}
}
if (returnList.size() > 0) {
double[] doubleArray = new double[returnList.size()];
for (int i = 0; i < doubleArray.length; i++)
doubleArray[i] = returnList.get(i);
return doubleArray;
}
return null;
}
@Override
public String getDimensionName(int dim) {
if (getDimension() == 1) {
return getStatisticName();
} else {
return getStatisticName() +".height"+ heightLowers[dim];
}
}
public static void reverse(double[] array) {
if (array == null) {
return;
}
int i = 0;
int j = array.length - 1;
double tmp;
while (j > i) {
tmp = array[j];
array[j] = array[i];
array[i] = tmp;
j--;
i++;
}
}
public static double[] extractUnique(double[] array){
Set<Double> tmp = new LinkedHashSet<Double>();
for (Double each : array) {
tmp.add(each);
}
double [] output = new double[tmp.size()];
int i = 0;
for (Double each : tmp) {
output[i++] = each;
}
return output;
}
enum Mode {
AVERAGE,
WEIGHTED_AVERAGE,
MEDIAN,
COEFFICIENT_OF_VARIATION
}
enum summaryStatistic {
DIFFUSION_RATE,
DIFFUSION_COEFFICIENT,
WAVEFRONT_DISTANCE,
WAVEFRONT_RATE,
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return DIFFUSION_RATE_STATISTIC;
}
@Override
public String[] getParserNames() {
return new String[]{getParserName(), TREE_DISPERSION_STATISTIC};
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
String name = xo.getAttribute(NAME, xo.getId());
boolean option = xo.getAttribute(BOOLEAN_DIS_OPTION, true); // Default value is true
Mode averageMode;
String mode = xo.getAttribute(MODE, WEIGHTED_AVERAGE);
if (mode.equals(AVERAGE)) {
averageMode = Mode.AVERAGE;
} else if (mode.equals(MEDIAN)) {
averageMode = Mode.MEDIAN;
} else if (mode.equals(COEFFICIENT_OF_VARIATION)) {
averageMode = Mode.COEFFICIENT_OF_VARIATION;
} else if (mode.equals(WEIGHTED_AVERAGE)) {
averageMode = Mode.WEIGHTED_AVERAGE;
} else {
System.err.println("Unknown mode: "+mode+". Reverting to weighted average");
averageMode = Mode.WEIGHTED_AVERAGE;
}
// boolean diffCoeff = xo.getAttribute(BOOLEAN_DC_OPTION, false); // Default value is false
summaryStatistic summaryStat;
String statistic = xo.getAttribute(STATISTIC, DIFFUSION_RATE);
if (statistic.equals(DIFFUSION_RATE)) {
summaryStat = summaryStatistic.DIFFUSION_RATE;
} else if (statistic.equals(WAVEFRONT_DISTANCE)) {
summaryStat = summaryStatistic.WAVEFRONT_DISTANCE;
} else if (statistic.equals(WAVEFRONT_RATE)) {
summaryStat = summaryStatistic.WAVEFRONT_RATE;
} else if (statistic.equals(DIFFUSION_COEFFICIENT)) {
summaryStat = summaryStatistic.DIFFUSION_COEFFICIENT;
} else {
System.err.println("Unknown statistic: "+statistic+". Reverting to diffusion rate");
summaryStat = summaryStatistic.DIFFUSION_COEFFICIENT;
}
final double upperHeight = xo.getAttribute(HEIGHT_UPPER, Double.MAX_VALUE);
final double lowerHeight = xo.getAttribute(HEIGHT_LOWER, 0.0);
double[] lowerHeights = null;
if (xo.hasAttribute(HEIGHT_LOWER_SERIE)){
String lowerHeightsString = xo.getStringAttribute(HEIGHT_LOWER_SERIE);
try {
lowerHeights = parseVariableLengthDoubleArray(lowerHeightsString);
} catch (Arguments.ArgumentException e) {
System.err.println("Error reading " + HEIGHT_LOWER_SERIE);
System.exit(1);
}
}
boolean cumulative = xo.getAttribute(CUMULATIVE, false);
List<AbstractMultivariateTraitLikelihood> traitLikelihoods = new ArrayList<AbstractMultivariateTraitLikelihood>();
for (int i = 0; i < xo.getChildCount(); i++) {
if (xo.getChild(i) instanceof AbstractMultivariateTraitLikelihood) {
AbstractMultivariateTraitLikelihood amtl = (AbstractMultivariateTraitLikelihood) xo.getChild(i);
traitLikelihoods.add(amtl);
}
}
return new DiffusionRateStatistic(name, traitLikelihoods, option, averageMode, summaryStat, upperHeight, lowerHeight, lowerHeights, cumulative);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "A statistic that returns the average of the branch diffusion rates";
}
public Class getReturnType() {
return TreeStatistic.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
AttributeRule.newStringRule(NAME, true),
AttributeRule.newBooleanRule(BOOLEAN_DIS_OPTION, true),
AttributeRule.newStringRule(MODE, true),
AttributeRule.newStringRule(STATISTIC,true),
AttributeRule.newDoubleRule(HEIGHT_UPPER, true),
AttributeRule.newDoubleRule(HEIGHT_LOWER, true),
AttributeRule.newStringRule(HEIGHT_LOWER_SERIE,true),
AttributeRule.newBooleanRule(CUMULATIVE, true),
new ElementRule(AbstractMultivariateTraitLikelihood.class, 1, Integer.MAX_VALUE),
};
};
private boolean useGreatCircleDistances;
private List<AbstractMultivariateTraitLikelihood> traitLikelihoods;
private Mode summaryMode;
private summaryStatistic summaryStat;
private double heightUpper;
private double[] heightLowers;
private boolean cumulative;
}