/*
* DesignMatrix.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.model;
import dr.stats.DiscreteStatistics;
import dr.xml.*;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

import java.util.Locale;
/**
* @author Marc Suchard
*/
public class DesignMatrix extends MatrixParameter {
public static final String DESIGN_MATRIX = "designMatrix";
public static final String ADD_INTERCEPT = "addIntercept";
public static final String FORM = "form";
public static final String ROW_DIMENSION = "rowDimension";
public static final String COL_DIMENSION = "colDimension";
public static final String CHECK_IDENTIFABILITY = "checkIdentifiability";
public static final String STANDARDIZE = "standardize";
public static final String DYNAMIC_STANDARDIZATION = "dynamicStandardization";
public static final String INTERCEPT = "intercept";
public DesignMatrix(String name, boolean dynamicStandardization) {
super(name);
this.dynamicStandardization = dynamicStandardization;
init();
}
public void variableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
super.variableChangedEvent(variable, index, type);
standardizationKnown = false;
}
public double getParameterValue(int row, int col) {
double value = super.getParameterValue(row, col);
if (dynamicStandardization) {
if (!standardizationKnown) {
computeStandarization();
standardizationKnown = true;
}
value = (value - standardizationMean[col]) / standardizationStDev[col];
}
return value;
}
// public double getParameterValue(int index) {
// throw new RuntimeException("Univariate value from a design matrix");
// }
public void addParameter(Parameter param) {
super.addParameter(param);
clearCache(); // Changed size
}
public void removeParameter(Parameter param) {
super.removeParameter(param);
clearCache(); // Changed size
}
private void clearCache() {
standardizationMean = null;
standardizationStDev = null;
storedStandardizationMean = null;
storedStandardizationStDev = null;
}
private void computeStandarization() {
if (standardizationMean == null) {
standardizationMean = new double[getColumnDimension()];
}
if (standardizationStDev == null) {
standardizationStDev = new double[getColumnDimension()];
}
for (int col = 0; col < getColumnDimension(); col++) {
if ((getParameter(col).getId()).toLowerCase().indexOf(INTERCEPT) >= 0) {
standardizationMean[col] = 0.0;
standardizationStDev[col] = 1.0;
} else {
double[] vector = getParameter(col).getParameterValues();
standardizationMean[col] = DiscreteStatistics.mean(vector);
standardizationStDev[col] = Math.sqrt(DiscreteStatistics.variance(vector, standardizationMean[col]));
}
}
}
protected void storeValues() {
super.storeValues();
if (dynamicStandardization) {
if (storedStandardizationMean == null) {
storedStandardizationMean = new double[standardizationMean.length];
}
System.arraycopy(standardizationMean, 0, storedStandardizationMean, 0, standardizationMean.length);
if (storedStandardizationStDev == null) {
storedStandardizationStDev = new double[standardizationStDev.length];
}
System.arraycopy(standardizationStDev, 0, storedStandardizationStDev, 0, standardizationStDev.length);
}
}
protected void restoreValues() {
super.restoreValues();
if (dynamicStandardization) {
double[] tmp = standardizationMean;
standardizationMean = storedStandardizationMean;
storedStandardizationMean = tmp;
tmp = standardizationStDev;
standardizationStDev = storedStandardizationStDev;
storedStandardizationStDev = tmp;
}
}
public DesignMatrix(String name, Parameter[] parameters, boolean dynamicStandardization) {
super(name, parameters);
this.dynamicStandardization = dynamicStandardization;
init();
}
private void init() {
standardizationKnown = false;
}
// **************************************************************
// XMLElement IMPLEMENTATION
// **************************************************************
public Element createElement(Document d) {
throw new RuntimeException("Not implemented yet!");
}
public static void standardize(double[] vector) {
double mean = DiscreteStatistics.mean(vector);
double stDev = Math.sqrt(DiscreteStatistics.variance(vector, mean));
for (int i = 0; i < vector.length; ++i) {
vector[i] = (vector[i] - mean) / stDev;
}
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return DESIGN_MATRIX;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
boolean dynamicStandardization = xo.getAttribute(DYNAMIC_STANDARDIZATION, false);
String name = (xo.hasId() ? xo.getId() : DESIGN_MATRIX);
DesignMatrix designMatrix = new DesignMatrix(name, dynamicStandardization);
boolean addIntercept = xo.getAttribute(ADD_INTERCEPT, false);
boolean standardize = xo.getAttribute(STANDARDIZE, false);
int dim = 0;
if (xo.hasAttribute(FORM)) {
String type = xo.getStringAttribute(FORM);
if (type.compareTo("J") == 0) {
int rowDim = xo.getAttribute(ROW_DIMENSION, 1);
int colDim = xo.getAttribute(COL_DIMENSION, 1);
for (int i = 0; i < colDim; i++) {
Parameter parameter = new Parameter.Default(rowDim);
designMatrix.addParameter(parameter);
}
} else
throw new XMLParseException("Unknown designMatrix form.");
} else {
for (int i = 0; i < xo.getChildCount(); i++) {
Parameter parameter = (Parameter) xo.getChild(i);
designMatrix.addParameter(parameter);
if (i == 0)
dim = parameter.getDimension();
else if (dim != parameter.getDimension())
throw new XMLParseException("Parameter " + (i+1) +" has dimension "+ parameter.getDimension()+ " and not "+dim+". "+
"All parameters must have the same dimension to construct a rectangular design matrix");
}
}
if (standardize) {
// Standardize all covariates except intercept
for (int j = 0; j < designMatrix.getColumnDimension(); ++j) {
Parameter columnParameter = designMatrix.getParameter(j);
double[] column = columnParameter.getParameterValues();
standardize(column);
for (int i = 0; i < column.length; ++i) {
columnParameter.setParameterValueQuietly(i, column[i]);
}
columnParameter.setParameterValueNotifyChangedAll(0, columnParameter.getParameterValue(0));
}
}
if (addIntercept) {
Parameter intercept = new Parameter.Default(dim);
intercept.setId(INTERCEPT);
designMatrix.addParameter(intercept);
}
return designMatrix;
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "A matrix parameter constructed from its component parameters.";
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
AttributeRule.newBooleanRule(ADD_INTERCEPT, true),
AttributeRule.newBooleanRule(CHECK_IDENTIFABILITY, true),
new ElementRule(Parameter.class, 0, Integer.MAX_VALUE), // TODO or have the following
AttributeRule.newStringRule(FORM, true), // TODO Should have to include both FORM and DIMENSION at the same time
AttributeRule.newIntegerRule(COL_DIMENSION, true),
AttributeRule.newIntegerRule(ROW_DIMENSION, true),
AttributeRule.newBooleanRule(STANDARDIZE, true),
};
public Class getReturnType() {
return DesignMatrix.class;
}
};
private final boolean dynamicStandardization;
private boolean standardizationKnown = false;
private double[] standardizationMean = null;
private double[] standardizationStDev = null;
private double[] storedStandardizationMean = null;
private double[] storedStandardizationStDev = null;
}