/*
* StandardizeTraits.java
*
* Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.util.Taxa;
import dr.evolution.util.Taxon;
import dr.inference.model.MatrixParameterInterface;
import dr.math.Polynomial;
import dr.xml.*;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
/**
 * Standardizes the entries of a {@link MatrixParameterInterface} in place.
 * <p>
 * Each row (or each column) of the matrix is treated as one slice. For every slice the mean
 * and the population variance (divisor = number of non-NaN entries) of its non-NaN entries
 * are computed, and every non-NaN entry {@code x} is replaced by {@code (x - mean) / sd}.
 * NaN entries are ignored when computing the statistics and are left untouched.
 * <p>
 * Created by msuchard on 2/3/17.
 */
public class StandardizeTraits {

    public static final String STANDARDIZE_TRAITS = "standardizeTraits";
    public static final String TRAIT_NAME = dr.evomodelxml.treelikelihood.TreeTraitParserUtilities.TRAIT_NAME;

    private final MatrixParameterInterface matrix;

    /**
     * @param matrix the matrix whose values will be overwritten by {@link #doStandardization(boolean)}
     */
    public StandardizeTraits(MatrixParameterInterface matrix) {
        this.matrix = matrix;
    }

    /**
     * Standardizes every slice of the wrapped matrix in place and reports the statistics of
     * each slice before and after the transformation.
     *
     * @param byColumn if {@code true} each column is standardized on its own; otherwise each row is
     * @return a human-readable report listing, per slice, the 1-based index, mean, variance and
     *         number of non-NaN entries before and after standardization
     */
    public String doStandardization(boolean byColumn) {
        StringBuilder sb = new StringBuilder();
        sb.append("Trait standardization report:\n");

        final int dim = byColumn ? matrix.getColumnDimension() : matrix.getRowDimension();

        for (int index = 0; index < dim; ++index) {
            MeanVariance before = getStatistics(matrix, index, byColumn);
            appendReport(sb, "BEFORE", index, before);

            updateValues(matrix, before, index, byColumn);

            // Recompute from the matrix so the report reflects what was actually stored.
            MeanVariance after = getStatistics(matrix, index, byColumn);
            appendReport(sb, "AFTER", index, after);
            sb.append("\n"); // blank line between slices
        }

        return sb.toString();
    }

    /**
     * Appends one labelled statistics section (label, 1-based index, mean, variance, count) to the report.
     */
    private static void appendReport(StringBuilder sb, String label, int index, MeanVariance mv) {
        sb.append("\t").append(label).append("\n");
        sb.append("\tindex: ").append(index + 1).append("\n");
        sb.append("\tmean : ").append(mv.mean).append("\n");
        sb.append("\tvar : ").append(mv.variance).append("\n");
        sb.append("\tcnt : ").append(mv.count).append("\n");
    }

    /**
     * Rewrites the non-NaN entries of one slice as {@code (x - mean) / sd} and then fires a single
     * parameter-changed event on the matrix.
     *
     * @param matrix   the matrix to modify
     * @param mv       statistics of the slice, as returned by {@code getStatistics}
     * @param major    index of the slice (column index if {@code byColumn}, row index otherwise)
     * @param byColumn whether the slice is a column ({@code true}) or a row ({@code false})
     */
    private void updateValues(MatrixParameterInterface matrix, final MeanVariance mv, int major, boolean byColumn) {
        final int dim = byColumn ? matrix.getRowDimension() : matrix.getColumnDimension();

        final double sd = Math.sqrt(mv.variance);
        // A slice with no spread has sd == 0; dividing by it would turn every observed value
        // into NaN or +/-Infinity. In that case only centre the values. (sd is NaN only when the
        // slice has no non-NaN entries, in which case the loop below changes nothing anyway.)
        final double scale = (sd > 0.0) ? sd : 1.0;

        for (int index = 0; index < dim; ++index) {
            final int row = byColumn ? index : major;
            final int col = byColumn ? major : index;

            final double x = matrix.getParameterValue(row, col);
            if (Double.isNaN(x)) {
                continue; // leave NaN entries untouched
            }
            // Set quietly so that only one change event is fired for the whole slice.
            matrix.setParameterValueQuietly(row, col, (x - mv.mean) / scale);
        }

        matrix.fireParameterChangedEvent();
    }

    /**
     * Computes mean, population variance and count over the non-NaN entries of one slice.
     * If the slice has no non-NaN entries, mean and variance are NaN and count is 0.
     *
     * @param matrix   the matrix to read
     * @param major    index of the slice (column index if {@code byColumn}, row index otherwise)
     * @param byColumn whether the slice is a column ({@code true}) or a row ({@code false})
     * @return the statistics of the slice
     */
    private MeanVariance getStatistics(MatrixParameterInterface matrix, int major, boolean byColumn) {
        double sum = 0.0;
        double sumOfSquares = 0.0;
        int count = 0;

        final int dim = byColumn ? matrix.getRowDimension() : matrix.getColumnDimension();

        for (int index = 0; index < dim; ++index) {
            final double x = byColumn
                    ? matrix.getParameterValue(index, major)
                    : matrix.getParameterValue(major, index);
            if (!Double.isNaN(x)) {
                sum += x;
                sumOfSquares += x * x;
                ++count;
            }
        }

        final double mean = sum / count;
        // E[x^2] - E[x]^2 can come out slightly negative through floating-point cancellation,
        // which would make the standard deviation NaN; clamp at zero. Math.max propagates NaN,
        // so an empty slice still yields a NaN variance.
        final double variance = Math.max(0.0, sumOfSquares / count - mean * mean);

        return new MeanVariance(mean, variance, count);
    }

    /**
     * Immutable holder for the statistics of one slice.
     */
    private static final class MeanVariance {
        final double mean;
        final double variance;
        final int count;

        MeanVariance(double mean, double variance, int count) {
            this.mean = mean;
            this.variance = variance;
            this.count = count;
        }
    }

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        /**
         * Reads the child matrix parameter, standardizes each of its rows in place and returns
         * the {@link StandardizeTraits} instance that performed the work.
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            MatrixParameterInterface matrix = (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class);

            StandardizeTraits st = new StandardizeTraits(matrix);
            // byColumn == false: every row is standardized. The textual report is not used here.
            st.doStandardization(false);

            return st;
        }

        /**
         * @return an array of syntax rules required by this element.
         * Order is not important.
         */
        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        @Override
        public String getParserDescription() {
            return "Standardizes each row of a matrix parameter in place to zero mean and unit variance, "
                    + "ignoring NaN entries.";
        }

        @Override
        public Class getReturnType() {
            return StandardizeTraits.class;
        }

        /**
         * @return Parser name, which is identical to name of xml element parsed by it.
         */
        @Override
        public String getParserName() {
            return STANDARDIZE_TRAITS;
        }

        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[] {
                new ElementRule(MatrixParameterInterface.class),
        };
    };
}