/*
* SubordinatedProcess.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.markovjumps;
import dr.math.GammaFunction;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.Vector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * A class to represent a Poisson process and discrete-time Markov chain that subordinate
 * a continuous-time Markov chain in the interval [0,T]. The subordinator drives the Uniformization method
 * for simulating end-conditioned realizations.
 * <p/>
 * Matrices are stored as flat arrays in which entry (i, j) lives at index {@code i * stateCount + j}.
 * Powers of the one-step DTMC matrix are computed lazily and cached in this instance, and
 * {@link #drawNextChainState(int, int, int, int)} writes into a scratch buffer owned by the instance.
 * <p/>
 * This work is supported by NSF grant 0856099
 * <p/>
 * Hobolth A and Stone E (2009) Simulation from endpoint-conditioned, continuous-time Markov chains on a finite
 * state space, with applications to molecular evolution. Annals of Applied Statistics, 3, 1204-1231.
 *
 * @author Marc A. Suchard
 */
public class SubordinatedProcess {

    /**
     * Builds the subordinating process for a CTMC rate matrix.
     *
     * @param Q          CTMC rate matrix, flattened so that entry (i, j) is at index
     *                   {@code i * stateCount + j}; the array is kept by reference, not copied
     * @param stateCount dimension of the state space
     */
    public SubordinatedProcess(double[] Q, int stateCount) {
        this.stateCount = stateCount;
        this.Q = Q;
        poissonRate = getMaxRate(Q, stateCount);
        dtmcCache = new ArrayList<double[]>();
        dtmcCache.add(makeIdentityMatrix(stateCount));     // R^0
        dtmcCache.add(constructDtmcMatrix(Q, stateCount)); // R^1
        tmp = new double[stateCount];
    }

    /**
     * @return the rate of the subordinating Poisson process, i.e. max_i -Q_{ii}
     */
    public double getPoissonRate() {
        return poissonRate;
    }

    /**
     * Returns exp(x), remembering the most recent argument so that repeated calls with the
     * same value skip the call to {@link Math#exp(double)}.
     *
     * @param x exponent
     * @return exp(x)
     */
    private double getCachedExp(double x) {
        // The cached argument starts out as NaN, which never compares equal, so the first call always computes.
        if (x != cachedXForExp) {
            cachedXForExp = x;
            cachedExpValue = Math.exp(x);
        }
        return cachedExpValue;
    }

    /**
     * Compute the n-step discrete-time transition probabilities
     *
     * @param nSteps which step
     * @return a pointer to the cached matrix (not a copy)
     */
    public double[] getDtmcProbabilities(int nSteps) {
        final double[] dtmcOneStep = dtmcCache.get(1);
        // Extend the cache up to, and no further than, the requested power: the entry at
        // index 'step' is obtained from the entry at index 'step - 1'.
        for (int step = dtmcCache.size(); step <= nSteps; step++) {
            double[] lastDtmcMatrix = dtmcCache.get(step - 1);
            double[] nextDtmcMatrix = new double[stateCount * stateCount];
            MarkovJumpsCore.matrixMultiply(lastDtmcMatrix, dtmcOneStep, stateCount, nextDtmcMatrix);
            dtmcCache.add(nextDtmcMatrix);
        }
        return dtmcCache.get(nSteps);
    }

    /**
     * Find max_i -Q_{ii}
     *
     * @param Q          ctmc rate matrix
     * @param stateCount dim
     * @return max rate
     */
    private double getMaxRate(double[] Q, int stateCount) {
        double max = -Q[0];
        for (int i = 1; i < stateCount; i++) {
            final double nextRate = -Q[i * stateCount + i]; // negated diagonal entry
            if (nextRate > max) {
                max = nextRate;
            }
        }
        return max;
    }

    /**
     * R = I + 1/maxRate Q
     *
     * @param lambda     Q
     * @param stateCount dim
     * @return R
     */
    private double[] constructDtmcMatrix(double[] lambda, int stateCount) {
        final double[] R = new double[stateCount * stateCount];
        final double maxRate = getMaxRate(lambda, stateCount);
        int index = 0;
        for (int i = 0; i < stateCount; i++) {
            for (int j = 0; j < stateCount; j++) {
                R[index] = lambda[index] / maxRate;
                if (i == j) {
                    R[index] += 1; // identity contribution on the diagonal
                }
                index++;
            }
        }
        return R;
    }

    /**
     * Simulate transition times, uniformly distributed before sorting
     *
     * @param timeDuration         T
     * @param totalNumberOfChanges total number of changes
     * @return the transition times of the subordinated process, in ascending order
     */
    public double[] drawTransitionTimes(double timeDuration, int totalNumberOfChanges) {
        final double[] times = new double[totalNumberOfChanges];
        for (int i = 0; i < totalNumberOfChanges; i++) {
            // presumably MathUtils.nextDouble() is uniform on [0,1) -- confirm against dr.math.MathUtils
            times[i] = timeDuration * MathUtils.nextDouble();
        }
        if (times.length > 1) {
            Arrays.sort(times);
        }
        return times;
    }

    /**
     * Simulate the next transition in the subordinated process, equation in remark 7
     *
     * @param currentState         current state of the subordinated process
     * @param endingState          ending state of CTMC
     * @param totalNumberOfChanges number of subordinated changes
     * @param thisChangeNumber     this transition number
     * @return the next state of the subordinated process
     */
    public int drawNextChainState(int currentState, int endingState, int totalNumberOfChanges, int thisChangeNumber) {
        // The weights are written into the instance-level scratch buffer.
        computePdfNextChainState(currentState, endingState, totalNumberOfChanges, thisChangeNumber, tmp);
        return MathUtils.randomChoicePDF(tmp);
    }

    /**
     * Fills {@code pdf} with the unnormalized weights of the next state of the subordinated chain:
     * {@code pdf[i] = R[currentState][i] * R^(totalNumberOfChanges - thisChangeNumber)[i][endingState]}.
     *
     * @param currentState         current state of the subordinated process
     * @param endingState          ending state of CTMC
     * @param totalNumberOfChanges number of subordinated changes
     * @param thisChangeNumber     this transition number
     * @param pdf                  output array; its first {@code stateCount} entries are overwritten
     */
    public void computePdfNextChainState(int currentState, int endingState, int totalNumberOfChanges, int thisChangeNumber,
                                         double[] pdf) {
        final double[] R = getDtmcProbabilities(1);
        final double[] RnMinusI = getDtmcProbabilities(totalNumberOfChanges - thisChangeNumber);
        for (int i = 0; i < stateCount; i++) {
            pdf[i] = R[currentState * stateCount + i] * RnMinusI[i * stateCount + endingState];
            // / RnMinusIPlus1[currentState * stateCount + endingState] // No need to normalize
        }
    }

    /**
     * Checked exception thrown by {@link #drawNumberOfChanges(int, int, double, double, double)} when the
     * draw reaches the maximum number of terms without terminating.
     */
    public class Exception extends java.lang.Exception {
        // Nothing special
    }

    /**
     * Simulate the number of transitions in the subordinated process, equation (2.9)
     *
     * @param startingState   starting state of CTMC
     * @param endingState     ending state of CTMC
     * @param time            length of chain
     * @param ctmcProbability the CTMC finite-time transition probability
     * @return the number of transitions in the subordinated process
     * @throws dr.inference.markovjumps.SubordinatedProcess.Exception exception
     */
    public int drawNumberOfChanges(int startingState, int endingState, double time, double ctmcProbability) throws SubordinatedProcess.Exception {
        return drawNumberOfChanges(startingState, endingState, time, ctmcProbability, MathUtils.nextDouble());
    }

    /**
     * Simulate the number of transitions in the subordinated process by accumulating the terms
     * {@code exp(-rate * time) * (rate * time)^n / n! * R^n[startingState][endingState] / ctmcProbability}
     * for n = 0, 1, 2, ... until the running sum exceeds {@code cutoff}.
     *
     * @param startingState   starting state of CTMC
     * @param endingState     ending state of CTMC
     * @param time            length of chain
     * @param ctmcProbability the CTMC finite-time transition probability
     * @param cutoff          threshold that the cumulative sum must exceed for the draw to stop
     * @return the number of transitions in the subordinated process
     * @throws dr.inference.markovjumps.SubordinatedProcess.Exception if the term count reaches the internal
     *                                                                limit (1000) with the running sum still
     *                                                                not above the cutoff
     */
    public int drawNumberOfChanges(int startingState, int endingState, double time, double ctmcProbability,
                                   double cutoff) throws SubordinatedProcess.Exception {
        int drawnNumber = -1;
        double cdf = 0;
        final double effectiveRate = getPoissonRate() * time;
        final double preFactor = getCachedExp(-effectiveRate);
        double scale = 1.0; // holds effectiveRate^n / n! for the current n
        final int index = startingState * stateCount + endingState;
        double[] check;
        if (DEBUG) {
            check = new double[MAX_TRIES + 1];
        }
        while (cutoff >= cdf) {
            drawnNumber++;
            final double[] Rn = getDtmcProbabilities(drawnNumber);
            if (drawnNumber > 0) {
                scale *= effectiveRate;
            }
            if (drawnNumber > 1) {
                scale /= (double) drawnNumber;
            }
            cdf += preFactor * scale * Rn[index] / ctmcProbability;
            if (THROW_EXCEPTION) {
                if (drawnNumber == MAX_TRIES) {
                    throw new SubordinatedProcess.Exception();
                }
            }
            if (DEBUG) {
                check[drawnNumber] = cdf;
                if (drawnNumber == MAX_TRIES) {
                    printDiagnostics(startingState, endingState, time, cdf, cutoff, ctmcProbability,
                            drawnNumber, check);
                    throw new RuntimeException("Likely numerical instability in computing end-conditioned CTMC simulant.");
                }
            }
        }
        return drawnNumber;
    }

    /**
     * Writes the state of a non-terminating draw to standard error; used only when DEBUG is enabled.
     */
    private void printDiagnostics(int startingState, int endingState, double time, double cdf, double cutoff,
                                  double ctmcProbability, int drawnNumber, double[] check) {
        System.err.println("Start state = " + startingState);
        System.err.println("End state = " + endingState);
        System.err.println("Time = " + time);
        System.err.println("CDF = " + cdf);
        System.err.println("Cutoff = " + cutoff);
        System.err.println("CTMC prob = " + ctmcProbability);
        System.err.println("PoissonRate = " + getPoissonRate());
        final double[] distr = computePDFDirectly(startingState, endingState, time, ctmcProbability, drawnNumber);
        final double[] checkCDF = new double[distr.length];
        double total = 0;
        for (int i = 0; i < distr.length; i++) {
            total += distr[i];
            checkCDF[i] = total;
        }
        System.err.println("Direct compute = " + new Vector(distr));
        System.err.println("Via CDF = " + new Vector(checkCDF));
        System.err.println("Check distr = " + new Vector(check));
        System.err.println("Q = " + new Vector(Q));
        System.err.println("R = " + new Vector(getDtmcProbabilities(1)));
    }

    /**
     * Computes, on the log scale, the probability of each number of subordinated changes n = 0, ..., maxTerm - 1:
     * {@code exp(-rate * time) * (rate * time)^n / n! * R^n[startingState][endingState] / ctmcProbability}.
     *
     * @param startingState   starting state of CTMC
     * @param endingState     ending state of CTMC
     * @param time            length of chain
     * @param ctmcProbability the CTMC finite-time transition probability
     * @param maxTerm         number of terms to compute
     * @return array of length {@code maxTerm} whose n-th entry is the probability of n changes
     */
    public double[] computePDFDirectly(int startingState, int endingState, double time, double ctmcProbability,
                                       int maxTerm) {
        final double[] pdf = new double[maxTerm];
        final double logRateTime = Math.log(getPoissonRate()) + Math.log(time);
        final double logCtmcProbability = Math.log(ctmcProbability);
        for (int n = 0; n < maxTerm; n++) {
            final double[] Rn = getDtmcProbabilities(n);
            // The zeroth power is 1 whatever the base; evaluating 0 * logRateTime would give NaN
            // when logRateTime is -Infinity (time == 0 or rate == 0).
            final double logPower = (n == 0) ? 0.0 : n * logRateTime;
            pdf[n] = Math.exp(-getPoissonRate() * time + logPower - GammaFunction.lnGamma(n + 1) +
                    Math.log(Rn[startingState * stateCount + endingState]) - logCtmcProbability);
        }
        return pdf;
    }

    /**
     * @param stateCount dim
     * @return the flattened stateCount x stateCount identity matrix
     */
    private double[] makeIdentityMatrix(int stateCount) {
        final double[] I = new double[stateCount * stateCount];
        for (int i = 0; i < stateCount; i++) {
            I[i * stateCount + i] = 1.0;
        }
        return I;
    }

    // Cached powers of the one-step DTMC matrix; the entry at index n is R^n.
    private final List<double[]> dtmcCache;
    private final double poissonRate;
    private final int stateCount;
    // Scratch buffer reused by drawNextChainState.
    private final double[] tmp;
    // Single-entry memo used by getCachedExp.
    private double cachedXForExp = Double.NaN;
    private double cachedExpValue;
    private static final boolean DEBUG = false;
    private static final boolean THROW_EXCEPTION = true;
    // Upper bound on the number of terms accumulated by drawNumberOfChanges.
    private static final int MAX_TRIES = 1000;
    // The rate matrix passed to the constructor; only read back when printing diagnostics.
    private final double[] Q;
}