/*
* MultivariateOUModel.java
*
* Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.distribution;
import dr.oldevomodel.substmodel.PositiveDefiniteSubstitutionModel;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Statistic;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
/**
* @author Marc Suchard
*/
@Deprecated // GLM stuff is now in inference.glm - this is here for backwards compatibility temporarily
public class MultivariateOUModel extends GeneralizedLinearModel implements Statistic {
private SubstitutionModel Q;
private MatrixParameter gamma;
private double[] time;
private double[] deltaTime;
private double[] design;
private double[] W;
private double[] initialPriorMean;
private int K;
private int Ksquared;
private int numTimeSteps;
private double[][] GminusWGWt;
private MultivariateNormalDistribution initialPrior;
private boolean likelihoodKnown = false;
private boolean storedLikelihoodKnown;
private double logLikelihood;
private double storedLogLikelihood;
private boolean conditionalPrecisionKnown = false;
private boolean storedConditionPrecisionKnown;
private double[] storedWt;
private double[] Wt;
private double[] conditionPrecisionVector;
private double[] storedConditionPrecisionVector;
private int[] mapTime;
private List<Double> deltaTimeList; // todo could just use a Map<Double,Integer>
public MultivariateOUModel(SubstitutionModel substitutionModel, Parameter dependentParam,
MatrixParameter gamma, double[] time, double[] design) {
super(dependentParam);
this.Q = substitutionModel;
this.time = time;
this.design = design;
this.gamma = gamma;
K = substitutionModel.getDataType().getStateCount();
Ksquared = K * K;
W = new double[Ksquared];
initialPriorMean = new double[K]; // todo send this mean in constructor
StringBuffer sb = new StringBuffer("Constructing a multivariate OU model:\n");
sb.append("\tOutcome dimension = ");
sb.append(K);
Logger.getLogger("dr.inference.distribution").info(sb.toString());
setupTimes();
addVariable(gamma);
addModel(substitutionModel);
}
private void setupTimes() {
deltaTimeList = new ArrayList<Double>();
numTimeSteps = time.length / K - 1;
deltaTime = new double[numTimeSteps];
mapTime = new int[numTimeSteps];
double currentTime = time[0];
int index = 0;
for (int i = 0; i < numTimeSteps; i++) {
index += K;
deltaTime[i] = time[index] - currentTime;
currentTime = time[index];
if (!deltaTimeList.contains(deltaTime[i])) {
deltaTimeList.add(deltaTime[i]);
}
mapTime[i] = deltaTimeList.indexOf(deltaTime[i]);
((PositiveDefiniteSubstitutionModel) Q).addPrecalculatedTime(-deltaTime[i]); // todo get rid of negative sign
}
Logger.getLogger("dr.inference.distribution").info(
"\tTime increments: " + new Vector(deltaTime)
);
Wt = new double[Ksquared * deltaTimeList.size()];
storedWt = new double[Ksquared * deltaTimeList.size()];
conditionPrecisionVector = new double[Ksquared * deltaTimeList.size()];
storedConditionPrecisionVector = new double[Ksquared * deltaTimeList.size()];
calculateConditionPrecision();
}
private void calculateConditionPrecision() {
int index = 0;
double[] tempW = new double[Ksquared];
double[][] G = gamma.getParameterAsMatrix();
for (double deltaTime : deltaTimeList) {
Q.getTransitionProbabilities(-deltaTime, tempW);
System.arraycopy(tempW, 0, Wt, Ksquared * index, Ksquared);
double[][] WG = new double[K][K]; // needs to start with zeros
for (int i = 0; i < K; i++) {
for (int j = 0; j < K; j++) {
for (int k = 0; k < K; k++)
WG[i][j] += tempW[i * K + k] * G[k][j];
}
}
double[][] WGWt = new double[K][K]; // needs to start with zeros
for (int i = 0; i < K; i++) {
for (int j = 0; j < K; j++) {
for (int k = 0; k < K; k++)
WGWt[i][j] += WG[i][k] * tempW[j * K + k];
}
}
for (int i = 0; i < K; i++) {
for (int j = 0; j < K; j++)
WGWt[i][j] = G[i][j] - WGWt[i][j];
}
WGWt = new Matrix(WGWt).inverse().toComponents();
for (int i = 0; i < K; i++)
System.arraycopy(WGWt[i], 0, conditionPrecisionVector, Ksquared * index + K * i, K);
index++;
}
conditionalPrecisionKnown = true;
}
public final double getLogLikelihood() {
if (!likelihoodKnown) {
logLikelihood = calculateLogLikelihood();
likelihoodKnown = true;
}
return logLikelihood;
}
public double calculateLogLikelihood(double[] x) {
return calculateLogLikelihood();
}
public double calculateLogLikelihood() {
double logLikelihood = 0;
double[] previous = new double[K];
double[] current = new double[K];
double[] tmpHolder;
double[][] G = gamma.getParameterAsMatrix();
double[] theta = dependentParam.getParameterValues();
double[] Xbeta = null;
boolean hasEffects = getNumberOfFixedEffects() > 0;
if (!conditionalPrecisionKnown)
calculateConditionPrecision();
// Prior on initial time-point
try {
if (new Matrix(G).determinant() < 0.01)
return Double.NEGATIVE_INFINITY;
} catch (IllegalDimension illegalDimension) {
illegalDimension.printStackTrace();
}
int index = 0;
if (!hasEffects) {
for (int i = 0; i < K; i++)
previous[i] = theta[index++];
} else {
Xbeta = getXBeta();
for (int i = 0; i < K; i++) {
previous[i] = theta[index] - Xbeta[index];
index++;
}
}
initialPrior = new MultivariateNormalDistribution(initialPriorMean, new Matrix(G).inverse().toComponents());
logLikelihood += initialPrior.logPdf(previous);
double save = logLikelihood;
double save2 = 0;
int oldMapTime = -1;
double[][] conditionalPrecision = new double[K][K];
for (int timeStep = 0; timeStep < numTimeSteps; timeStep++) {
int thisMapTime = mapTime[timeStep];
if (thisMapTime != oldMapTime) {
System.arraycopy(Wt, Ksquared * thisMapTime, W, 0, Ksquared);
for (int i = 0; i < K; i++)
System.arraycopy(conditionPrecisionVector, Ksquared * thisMapTime + K * i, conditionalPrecision[i], 0, K);
oldMapTime = thisMapTime;
}
double[] mean = new double[K];
int u = 0;
for (int i = 0; i < K; i++) {
for (int j = 0; j < K; j++)
mean[i] += W[u++] * previous[j];
}
// // start of removable part;
// double[][] WG = new double[K][K];
// for (int i = 0; i < K; i++) {
// for (int j = 0; j < K; j++) {
// for (int k = 0; k < K; k++)
// WG[i][j] += W[i * K + k] * G[k][j];
// }
// }
//
//
// double[][] WGWt = new double[K][K];
// for (int i = 0; i < K; i++) {
// for (int j = 0; j < K; j++) {
// for (int k = 0; k < K; k++)
// WGWt[i][j] += WG[i][k] * W[j * K + k];
// }
// }
//
//
// for (int i = 0; i < K; i++) {
// for (int j = 0; j < K; j++)
// WGWt[i][j] = G[i][j] - WGWt[i][j];
//
// }
//
// double[][] oldPrecision = new Matrix(WGWt).inverse().toComponents();
//
//
// GminusWGWt = WGWt;
// calculate density of current time step
// end of removable part;
if (!hasEffects) {
for (int i = 0; i < K; i++)
current[i] = theta[index++];
} else {
for (int i = 0; i < K; i++) {
current[i] = theta[index] - Xbeta[index];
index++;
}
}
MultivariateNormalDistribution density = new MultivariateNormalDistribution(
mean, conditionalPrecision);
double partialLogLikelihood = density.logPdf(current);
if (partialLogLikelihood > 10) {
return Double.NEGATIVE_INFINITY;
}
logLikelihood += partialLogLikelihood;
// move to next point
tmpHolder = previous;
previous = current;
current = tmpHolder;
}
if (logLikelihood > 100) {
System.err.println("got here end");
System.err.println("save1 = " + save);
System.err.println("save2 = " + save2);
System.exit(-1);
}
likelihoodKnown = true;
return logLikelihood;
}
protected boolean confirmIndependentParameters() {
return true;
}
public boolean requiresScale() {
return true;
}
protected void handleModelChangedEvent(Model model, Object object, int index) {
conditionalPrecisionKnown = false;
likelihoodKnown = false;
}
protected void handleParameterChangedEvent(Parameter parameter, int index) {
if (parameter == gamma) {
conditionalPrecisionKnown = false;
}
likelihoodKnown = false;
}
protected void storeState() {
System.arraycopy(Wt, 0, storedWt, 0, Wt.length);
System.arraycopy(conditionPrecisionVector, 0, storedConditionPrecisionVector, 0, conditionPrecisionVector.length);
storedLogLikelihood = logLikelihood;
storedLikelihoodKnown = likelihoodKnown;
storedConditionPrecisionKnown = conditionalPrecisionKnown;
}
protected void restoreState() {
double[] holder = Wt;
Wt = storedWt;
storedWt = holder;
holder = conditionPrecisionVector;
conditionPrecisionVector = storedConditionPrecisionVector;
storedConditionPrecisionVector = holder;
logLikelihood = storedLogLikelihood;
likelihoodKnown = storedLikelihoodKnown;
conditionalPrecisionKnown = storedConditionPrecisionKnown;
}
protected void acceptState() {
}
public String getStatisticName() {
return getId();
}
public String getDimensionName(int dim) {
return getId() + dim;
}
public void setDimensionNames(String[] names) {
// do nothing
}
public int getDimension() {
return W.length;
}
public double getStatisticValue(int dim) {
// int x = dim / K;
// int y = dim - x * K;
// if( GminusWGWt != null )
// return GminusWGWt[x][y];
if (W != null)
return W[dim];
return 0;
}
@Override
public double getValueSum() {
double sum = 0.0;
for (int i = 0; i < getDimension(); i++) {
sum += getStatisticValue(i);
}
return sum;
}
public String getAttributeName() {
return null; //To change body of implemented methods use File | Settings | File Templates.
}
public double[] getAttributeValue() {
return null; //To change body of implemented methods use File | Settings | File Templates.
}
}