/*
* SubstitutionEpochModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.substmodel;
import dr.evolution.datatype.DataType;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.Vector;
import java.util.List;
/**
 * A substitution model assembled from a list of per-epoch substitution models.
 * <p>
 * With {@code K} models there are {@code K - 1} transition times read from a
 * {@link Parameter}: model {@code 0} applies up to {@code transitionTimes[0]}, model
 * {@code i} applies between {@code transitionTimes[i - 1]} and {@code transitionTimes[i]},
 * and model {@code K - 1} applies after {@code transitionTimes[K - 2]}. The transition
 * probabilities for a branch are the ordered matrix product of each overlapping epoch's
 * probabilities, with the branch distance split in proportion to the time spent in each epoch.
 * <p>
 * The code assumes the transition times are in increasing order and that
 * {@code startTime <= endTime}; neither is checked.
 *
 * @author Marc A. Suchard
 */
public class SubstitutionEpochModel extends AbstractSubstitutionModel {

    /** When true, {@link #getEpochWeights} prints the weights and exits if they do not sum to ~1. */
    public static final boolean DEBUG = false;

    /**
     * Creates an epoch model and registers the transition-time parameter and every
     * epoch model with this model so that their change events are received here.
     *
     * @param name            model name, passed to the superclass
     * @param modelList       one substitution model per epoch, in epoch order
     * @param transitionTimes parameter holding the epoch boundaries
     * @param dataType        data type; supplies the state count used to size the matrices
     * @param freqModel       frequency model, passed to the superclass
     */
    public SubstitutionEpochModel(String name,
                                  List<SubstitutionModel> modelList,
                                  Parameter transitionTimes,
                                  DataType dataType, FrequencyModel freqModel) {

        super(name, dataType, freqModel);

        this.modelList = modelList;
        this.transitionTimesParameter = transitionTimes;
        this.transitionTimes = transitionTimesParameter.getParameterValues();
        addVariable(transitionTimes);

        for (SubstitutionModel model : modelList) {
            addModel(model);
        }

        numberModels = modelList.size();
        weight = new double[numberModels];

        // Flattened (row-major) stateCount x stateCount work buffers.
        stateCount = dataType.getStateCount();
        stepMatrix = new double[stateCount * stateCount];
        productMatrix = new double[stateCount * stateCount];
        resultMatrix = new double[stateCount * stateCount];
    }

    /** No-op in this model. */
    protected void frequenciesChanged() {
    }

    /**
     * Dispatches a change in a registered model: a change in the frequency model calls
     * {@link #frequenciesChanged()}; any other change is forwarded to this model's listeners.
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == freqModel) {
            frequenciesChanged();
        } else {
            // This is an epoch model and I need to pass the info on
            fireModelChanged(object, index);
        }
    }

    /** No-op in this model. */
    protected void ratesChanged() {
    }

    /** No-op in this model. */
    protected void setupRelativeRates() {
    }

    /**
     * Computes the transition probabilities for a branch spanning
     * {@code [startTime, endTime]} and writes them into {@code matrix}.
     * <p>
     * If the branch lies in a single epoch, that epoch's model is asked for the
     * probabilities over the full {@code distance} directly. Otherwise each epoch with a
     * positive weight contributes its probabilities over {@code distance * weight}, and the
     * contributions are multiplied together in epoch order.
     *
     * @param startTime time at which the branch starts
     * @param endTime   time at which the branch ends
     * @param distance  total branch distance to be split across the epochs
     * @param matrix    output; receives {@code stateCount * stateCount} values (row-major)
     */
    public void getTransitionProbabilities(double startTime, double endTime, double distance, double[] matrix) {

        final boolean oneMatrix = (getEpochWeights(startTime, endTime, weight) == 1);

        int matrixCount = 0;
        for (int m = 0; m < numberModels; m++) {
            // Deliberately "> 0": this also skips NaN weights.
            if (weight[m] > 0) {
                final SubstitutionModel model = modelList.get(m);

                if (matrixCount == 0) {
                    if (oneMatrix) {
                        // Single epoch: write straight into the caller's matrix and stop.
                        model.getTransitionProbabilities(distance, matrix);
                        break;
                    }
                    model.getTransitionProbabilities(distance * weight[m], resultMatrix);
                    matrixCount++;
                } else {
                    model.getTransitionProbabilities(distance * weight[m], stepMatrix);

                    // Sum over unobserved state: productMatrix = resultMatrix * stepMatrix
                    int index = 0;
                    for (int i = 0; i < stateCount; i++) {
                        for (int j = 0; j < stateCount; j++) {
                            double sum = 0;
                            for (int k = 0; k < stateCount; k++) {
                                sum += resultMatrix[i * stateCount + k] * stepMatrix[k * stateCount + j];
                            }
                            productMatrix[index] = sum;
                            index++;
                        }
                    }

                    // Swap buffers: resultMatrix now holds the running product,
                    // productMatrix holds the previous (stale) value and is scratch space.
                    final double[] tmpMatrix = resultMatrix;
                    resultMatrix = productMatrix;
                    productMatrix = tmpMatrix;
                }
            }
        }

        if (!oneMatrix) {
            // The accumulated product lives in resultMatrix (see the swap above);
            // productMatrix only holds the previous intermediate value.
            System.arraycopy(resultMatrix, 0, matrix, 0, stateCount * stateCount);
        }
    }

    /**
     * Fills {@code weights} with, for each epoch, the fraction of the branch
     * {@code [startTime, endTime]} that overlaps that epoch, and returns how many
     * epochs were counted as overlapping the branch.
     * <p>
     * NOTE(review): when {@code startTime == endTime} the branch length is zero, so any
     * weight computed by division (i.e. outside the "entirely in first epoch" case) is NaN.
     *
     * @param startTime time at which the branch starts
     * @param endTime   time at which the branch ends
     * @param weights   output array of length {@code numberModels}
     * @return number of epochs counted as overlapping the branch
     */
    private int getEpochWeights(double startTime, double endTime, double[] weights) {

        int matrixCount = 0;
        final double lengthTime = endTime - startTime;
        final int lastTime = numberModels - 2;

        // model 0, 1, 2, ..., K-2, K-1
        // times   0, 1, ..., K-2,
        // where K = numberModels

        // First epoch: 0 -> transitionTimes[0];
        if (startTime <= transitionTimes[0]) {
            if (endTime <= transitionTimes[0]) {
                weights[0] = 1;
            } else {
                weights[0] = (transitionTimes[0] - startTime) / lengthTime;
            }
            matrixCount++;
        } else {
            weights[0] = 0;
        }

        // Middle epochs: transitionTimes[i-1] -> transitionTimes[i]
        for (int i = 1; i <= lastTime; i++) {
            // The branch must start no later than the epoch's end AND end after the
            // epoch's start; without the second test an epoch lying entirely after the
            // branch would receive a negative weight and be counted.
            if (startTime <= transitionTimes[i] && endTime > transitionTimes[i - 1]) {
                final double start = Math.max(startTime, transitionTimes[i - 1]);
                final double end = Math.min(endTime, transitionTimes[i]);
                weights[i] = (end - start) / lengthTime;
                matrixCount++;
            } else {
                weights[i] = 0;
            }
        }

        // Last epoch: transitionTimes[K-2] -> Infinity
        if (lastTime >= 0) {
            if (endTime > transitionTimes[lastTime]) {
                final double start = Math.max(startTime, transitionTimes[lastTime]);
                weights[lastTime + 1] = (endTime - start) / lengthTime;
                matrixCount++;
            } else {
                weights[lastTime + 1] = 0;
            }
        }

        if (DEBUG) {
            double totalWeight = 0;
            for (int i = 0; i < numberModels; i++) {
                totalWeight += weights[i];
            }
            System.err.println("Start: " + startTime + " End: " + endTime + " Count: " + matrixCount + " Weight: " + totalWeight + " - " + new Vector(weights));
            if (totalWeight > 1.001) System.exit(-1);
            if (totalWeight < 0.999) System.exit(-1);
        }

        return matrixCount;
    }

    /**
     * Not supported: this model needs the branch's start and end times to pick epochs.
     *
     * @throws RuntimeException always
     */
    public void getTransitionProbabilities(double distance, double[] matrix) {
        throw new RuntimeException("Should not get here in a substitution epoch model.");
    }

    /**
     * After the superclass has handled the event, refreshes the cached transition times
     * and notifies listeners when the changed variable is the transition-time parameter.
     */
    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        super.handleVariableChangedEvent(variable, index, type);
        if (variable == transitionTimesParameter) {
            transitionTimes = transitionTimesParameter.getParameterValues();
            fireModelChanged(variable, index);
        }
    }

    /** No state is stored by this class. */
    protected void storeState() {
    }

    /** No state is restored by this class. */
    protected void restoreState() {
    }

    // One model per epoch, in epoch order.
    private final List<SubstitutionModel> modelList;
    // Source of the epoch boundaries; transitionTimes caches its current values.
    private final Parameter transitionTimesParameter;
    private double[] transitionTimes;
    // Per-epoch fraction of the current branch; refilled on every call.
    private final double[] weight;
    // Work buffers (row-major stateCount x stateCount). resultMatrix and productMatrix
    // are swapped during accumulation, so they cannot be final.
    private final double[] stepMatrix;
    private double[] productMatrix;
    private double[] resultMatrix;
    private final int numberModels;
    private final int stateCount;
}