/*
* MarkovJumpsSubstitutionModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.substmodel;
import dr.evolution.datatype.DataType;
import dr.inference.markovjumps.MarkovJumpsCore;
import dr.inference.markovjumps.MarkovJumpsType;
import dr.inference.markovjumps.StateHistory;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
/**
* @author Marc Suchard
* @author Vladimir Minin
* <p/>
* A base class for implementing Markov chain-induced counting processes (markovjumps) in BEAST using BEAGLE
* This work is supported by NSF grant 0856099
* <p/>
* Minin VN and Suchard MA (2008) Counting labeled transitions in continuous-time Markov models of evolution.
* Journal of Mathematical Biology, 56, 391-412.
*/
public class MarkovJumpsSubstitutionModel extends AbstractModel {

    /**
     * Creates a model of type {@link MarkovJumpsType#COUNTS} wrapped around the given substitution model.
     *
     * @param substModel substitution model supplying rates, frequencies and the eigen decomposition
     */
    public MarkovJumpsSubstitutionModel(SubstitutionModel substModel) {
        this(substModel, MarkovJumpsType.COUNTS);
    }

    /**
     * Creates a model of the requested expectation type wrapped around the given substitution model.
     * The substitution model is registered as a sub-model so that its change events reach
     * {@link #handleModelChangedEvent(Model, Object, int)}.
     *
     * @param substModel substitution model supplying rates, frequencies and the eigen decomposition
     * @param type       which expectation is computed (counts, history or rewards)
     */
    public MarkovJumpsSubstitutionModel(SubstitutionModel substModel, MarkovJumpsType type) {
        super(substModel.getModelName());
        this.substModel = substModel;
        this.eigenDecomposition = substModel.getEigenDecomposition();
        stateCount = substModel.getDataType().getStateCount();
        markovJumpsCore = new MarkovJumpsCore(stateCount);
        this.type = type;
        // Statement order is kept as-is: setupStorage() is overridable and runs before
        // addModel(...) and before dataType is assigned.
        setupStorage();
        addModel(substModel);
        dataType = substModel.getDataType();
    }

    /**
     * Allocates the work arrays. Matrices are stored flattened, row by row, in arrays of
     * length {@code stateCount * stateCount}; {@code reward} has one entry per state.
     */
    protected void setupStorage() {
        final int matrixSize = stateCount * stateCount;
        rateMatrix = new double[matrixSize];
        transitionProbs = new double[matrixSize];
        rateReg = new double[matrixSize];
        if (PRECOMPUTE) {
            ievcRateRegEvec = new double[matrixSize];
            tmp1 = new double[matrixSize];
        }
        registration = new double[matrixSize];
        reward = new double[stateCount];
    }

    /**
     * @return the expectation type this model was constructed with
     */
    public MarkovJumpsType getType() {
        return type;
    }

    /**
     * Sets the registration (for counts/history) or the per-state rewards.
     * <p/>
     * For {@code COUNTS} and {@code HISTORY} the first {@code stateCount * stateCount} entries of
     * {@code inRegistration} are copied and the diagonal is forced to zero. For {@code REWARDS}
     * the first {@code stateCount} entries are taken as per-state rewards; they are stored in
     * {@code reward} and on the diagonal of {@code registration}, with all off-diagonals zero.
     * In every successful case the cached rate-registration product is marked as out of date.
     *
     * @param inRegistration registration matrix (flattened) or reward vector, depending on the type
     * @throws RuntimeException if the type is none of counts, history or rewards
     */
    public void setRegistration(double[] inRegistration) {
        if (type == MarkovJumpsType.COUNTS || type == MarkovJumpsType.HISTORY) {
            System.arraycopy(inRegistration, 0, registration, 0, stateCount * stateCount);
            for (int i = 0; i < stateCount; i++) {
                registration[i * stateCount + i] = 0; // diagonals are zero
            }
        } else if (type == MarkovJumpsType.REWARDS) {
            for (int i = 0; i < stateCount; i++) {
                final double stateReward = inRegistration[i];
                reward[i] = stateReward;
                final int rowStart = i * stateCount;
                for (int j = 0; j < stateCount; j++) {
                    // Rewards live on the diagonal; off-diagonals are zero
                    registration[rowStart + j] = (i == j) ? stateReward : 0;
                }
            }
        } else {
            throw new RuntimeException("Unknown expectation type in MarkovJumps");
        }
        regRateChanged = true;
    }

    /**
     * @return the internal registration array (not a copy)
     */
    public double[] getRegistration() {
        return registration;
    }

    /**
     * Rebuilds the cached quantities that depend on the substitution model and the registration.
     * <p/>
     * For counts/history, {@code rateReg} becomes the element-wise product of the current
     * infinitesimal matrix and the registration; for rewards it is a copy of the registration.
     * The eigen decomposition is re-fetched from the substitution model here so that it always
     * belongs to the same model state as the rates; previously the decomposition obtained in the
     * constructor was reused forever, even after the substitution model had reported a change.
     *
     * @param registration    flattened registration matrix (input)
     * @param rateReg         flattened rate-registration matrix (output)
     * @param ievcRateRegEvec product inverse-eigenvectors * rateReg * eigenvectors (output, only
     *                        filled when precomputation is enabled)
     * @throws RuntimeException if the type is none of counts, history or rewards
     */
    private void makeRateRegistrationMatrix(double[] registration,
                                            double[] rateReg,
                                            double[] ievcRateRegEvec) {
        final int matrixSize = stateCount * stateCount;
        if (type == MarkovJumpsType.COUNTS || type == MarkovJumpsType.HISTORY) {
            substModel.getInfinitesimalMatrix(rateMatrix);
            for (int k = 0; k < matrixSize; k++) {
                rateReg[k] = rateMatrix[k] * registration[k];
            }
        } else if (type == MarkovJumpsType.REWARDS) {
            System.arraycopy(registration, 0, rateReg, 0, matrixSize);
        } else {
            throw new RuntimeException("Unknown expectation type in MarkovJumps");
        }

        // Refresh before the precompute step below, which multiplies by the eigenvectors.
        eigenDecomposition = substModel.getEigenDecomposition();

        if (PRECOMPUTE) {
            // tmp1 = rateReg * evec;  ievcRateRegEvec = ievc * tmp1
            MarkovJumpsCore.matrixMultiply(rateReg, eigenDecomposition.getEigenVectors(),
                    stateCount, tmp1);
            MarkovJumpsCore.matrixMultiply(eigenDecomposition.getInverseEigenVectors(), tmp1,
                    stateCount, ievcRateRegEvec);
        }
        regRateChanged = false;
    }

    /**
     * Rebuilds the cached rate-registration quantities if they have been marked as out of date.
     */
    private void updateRateRegistrationIfNeeded() {
        if (regRateChanged) {
            makeRateRegistrationMatrix(registration, rateReg, ievcRateRegEvec);
        }
    }

    /**
     * Returns the frequency-weighted sum of all rate-registration entries, i.e.
     * {@code sum_i freq(i) * sum_j rateReg[i][j]}, using the substitution model's frequency model.
     *
     * @return the marginal rate
     */
    public double getMarginalRate() {
        updateRateRegistrationIfNeeded();
        FrequencyModel freqModel = substModel.getFrequencyModel();
        double rate = 0;
        int index = 0;
        for (int i = 0; i < stateCount; i++) {
            final double freq_i = freqModel.getFrequency(i);
            for (int j = 0; j < stateCount; j++) {
                rate += freq_i * rateReg[index++];
            }
        }
        return rate;
    }

    /**
     * Convenience overload: obtains the transition probabilities for {@code time} from the
     * substitution model (into an internal buffer) and then delegates to
     * {@link #computeCondStatMarkovJumps(double, double[], double[])}.
     *
     * @param time        elapsed time passed to the substitution model and the core routine
     * @param countMatrix output array filled by the core routine
     */
    public void computeCondStatMarkovJumps(double time,
                                           double[] countMatrix) {
        substModel.getTransitionProbabilities(time, transitionProbs);
        computeCondStatMarkovJumps(time, transitionProbs, countMatrix);
    }

    /**
     * Evaluates the registered process on a simulated state history.
     *
     * @param history simulated history
     * @return the total registered counts for counts/history types, otherwise the total reward
     */
    public double getProcessForSimulant(StateHistory history) {
        if (type == MarkovJumpsType.COUNTS || type == MarkovJumpsType.HISTORY) {
            return history.getTotalRegisteredCounts(registration);
        }
        return history.getTotalReward(reward);
    }

    /**
     * Computes the conditional statistic by delegating to {@code MarkovJumpsCore}, after making
     * sure the cached rate-registration quantities are up to date.
     *
     * @param time            elapsed time
     * @param transitionProbs flattened transition probabilities for {@code time}, passed through
     *                        to the core routine
     * @param countMatrix     output array filled by the core routine
     */
    public void computeCondStatMarkovJumps(double time,
                                           double[] transitionProbs,
                                           double[] countMatrix) {
        updateRateRegistrationIfNeeded();
        final double[] evec = eigenDecomposition.getEigenVectors();
        final double[] ievc = eigenDecomposition.getInverseEigenVectors();
        final double[] eval = eigenDecomposition.getEigenValues();
        if (PRECOMPUTE) {
            markovJumpsCore.computeCondStatMarkovJumpsPrecompute(
                    evec, ievc, eval, ievcRateRegEvec, time, transitionProbs, countMatrix);
        } else {
            markovJumpsCore.computeCondStatMarkovJumps(
                    evec, ievc, eval, rateReg, time, transitionProbs, countMatrix);
        }
    }

    /**
     * Computes the joint statistic by delegating to {@code MarkovJumpsCore}, after making sure
     * the cached rate-registration quantities are up to date.
     *
     * @param time        elapsed time
     * @param countMatrix output array filled by the core routine
     */
    public void computeJointStatMarkovJumps(double time,
                                            double[] countMatrix) {
        updateRateRegistrationIfNeeded();
        final double[] evec = eigenDecomposition.getEigenVectors();
        final double[] ievc = eigenDecomposition.getInverseEigenVectors();
        final double[] eval = eigenDecomposition.getEigenValues();
        if (PRECOMPUTE) {
            markovJumpsCore.computeJointStatMarkovJumpsPrecompute(
                    evec, ievc, eval, ievcRateRegEvec, time, countMatrix);
        } else {
            markovJumpsCore.computeJointStatMarkovJumps(evec, ievc, eval, rateReg, time, countMatrix);
        }
    }

    /**
     * @return the wrapped substitution model
     */
    public SubstitutionModel getSubstitutionModel() {
        return substModel;
    }

    /**
     * Marks the cached rate-registration quantities as out of date when the wrapped
     * substitution model reports a change; events from other models are ignored.
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == substModel) {
            regRateChanged = true;
        }
    }

    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        // Do nothing
    }

    protected void storeState() {
        // Do nothing: the caches are not stored, they are rebuilt on demand
    }

    /**
     * Invalidates the caches on restore. They may have been built from the state that is being
     * rolled back, and nothing was stored to restore them from, so they are rebuilt lazily.
     * NOTE(review): assumes the substitution model does not fire a model-changed event when it
     * is restored - confirm against the framework; the extra invalidation is harmless otherwise.
     */
    protected void restoreState() {
        regRateChanged = true;
    }

    protected void acceptState() {
        // Do nothing
    }

    /** Number of states of the substitution model's data type. */
    public int stateCount;

    /** Flattened rate-registration matrix; valid only while {@code regRateChanged} is false. */
    private double[] rateReg;
    /** Flattened product inverse-eigenvectors * rateReg * eigenvectors (precompute path only). */
    private double[] ievcRateRegEvec;
    /** Scratch buffer for the intermediate matrix product. */
    private double[] tmp1;
    /** Buffer for transition probabilities requested from the substitution model. */
    private double[] transitionProbs;
    /** Buffer for the infinitesimal matrix requested from the substitution model. */
    private double[] rateMatrix;

    /** Per-state rewards (filled for the rewards type). */
    protected double[] reward;
    /** Flattened registration matrix. */
    protected double[] registration;
    protected SubstitutionModel substModel;

    /** Eigen decomposition of the substitution model, refreshed whenever the caches are rebuilt. */
    private EigenDecomposition eigenDecomposition;
    private final MarkovJumpsCore markovJumpsCore;

    /** True when {@code rateReg}, {@code ievcRateRegEvec} and the eigen decomposition must be rebuilt. */
    private boolean regRateChanged = true;

    protected MarkovJumpsType type;
    protected DataType dataType;

    /** Compile-time switch selecting the precomputed-product code path. */
    private static final boolean PRECOMPUTE = true;
}