/*
* TwoStateSericolaSeriesMarkovReward.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.markovjumps;
import dr.evomodel.substmodel.DefaultEigenSystem;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.substmodel.EigenSystem;
import dr.math.Binomial;
import dr.math.GammaFunction;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
/**
* @author Marc Suchard
* @author Forrest Crawford
*/
public class TwoStateSericolaSeriesMarkovReward implements MarkovReward {
// Following Bladt, Meini, Neuts and Sericola (2002).
// Assuming each state has a distinct reward, i.e. \phi + 1 = stateCount,
// and states are sorted in increasing reward order
private static final boolean DEBUG = true;
private static final boolean DEBUG2 = false;
public TwoStateSericolaSeriesMarkovReward(double[] Q, double[] r, int dim) {
this(Q, r, dim, 1E-10);
}
public TwoStateSericolaSeriesMarkovReward(double[] Q, double[] r, int dim, double epsilon) {
this.Q = Q;
this.r = r;
this.maxTime = 0;
this.epsilon = epsilon;
this.dim = dim;
lambda = determineLambda();
phi = dim - 1;
if (DEBUG2) {
System.err.println("lambda = " + lambda);
}
P = initializeP(Q, lambda);
eigenSystem = new DefaultEigenSystem(dim);
}
private double[][] initializeW(int times, int dim) {
return new double[times][dim * dim];
}
private int getHfromX(double x, double time) {
// TODO assert x > h[0] * time;
int h = 1;
while (x >= r[h] * time) {
h++;
}
return h;
}
private void growC(double time, int extraN) {
int newN = getNfromC();
if (time > maxTime) {
newN = determineNumberOfSteps(time, lambda) + extraN;
maxTime = time;
}
// Grow C if necessary
if (newN > getNfromC()) {
if (DEBUG && newN > 200) {
System.err.println("Growing C to N = " + newN + " with " + maxTime);
}
if (newN > 500) {
System.err.println("Warning: > 500 recursion depth in SericolaSeriesMarkovReward");
}
initializeSpace(phi, newN);
computeChnk();
}
}
// START: internal structure of C
private double[][][][] internalC; // TODO Linearize for store/restore; TODO reduce to minimal storage
private void initializeSpace(int phi, int N) {
internalC = new double[phi + 1][N + 1][N + 1][dim * dim];
// indices [h][n][k][B_u][B_v]
}
private double[] C(int h, int n, int k) {
return internalC[h][n][k];
}
private int getNfromC() {
return (internalC == null) ? -1 : internalC[0].length - 1;
}
private int idx(int i, int j) {
return i * dim + j; // row-major
}
// END: internal structure of C, TODO Change to expandable list
private int[] getHfromX(double[] X, double time) {
int[] H = new int[X.length];
for (int i = 0; i < X.length; ++i) {
H[i] = getHfromX(X[0], time);
}
return H;
}
public double computePdf(double x, double time, int i, int j) {
growC(time, 1);
int uv = i * dim + j;
// double[][] W = initializeW(1, dim); // initialize with zeros
double w = 0.0;
final int N = getNfromC() - 1;
for (int n = 0; n <= N; ++n) {
w += accumulatePdf(x, n, time, uv);
}
// if (DEBUG2) {
// for (int i = 0; i < W.length; ++i) {
// System.err.println("W'[" + i + "]:\n" + new Matrix(squareMatrix(W[i])));
// }
// System.err.println("");
// }
return w;
}
public double[] computePdf(double x, double time) {
return computePdf(new double[]{x}, time)[0];
}
public double[][] computePdf(double[] X, double time) {
// int[] H = getHfromX(X, time);
growC(time, 1);
double[][] W = initializeW(X.length, dim); // initialize with zeros
final int N = getNfromC() - 1;
for (int n = 0; n <= N; ++n) {
accumulatePdf(W, X, n, time); // TODO This can be sped up when only a single entry is wanted
}
if (DEBUG2) {
for (int i = 0; i < W.length; ++i) {
System.err.println("W'[" + i + "]:\n" + new Matrix(squareMatrix(W[i])));
}
System.err.println("");
}
return W;
}
public double computeCdf(double x, double time, int i, int j) {
return computeCdf(x, time)[i * dim + j];
}
public double[] computeCdf(double x, double time) {
return computeCdf(new double[]{x}, time)[0];
}
public double[][] computeCdf(double[] X, double time) {
int[] H = getHfromX(X, time);
growC(time, 0);
double[][] W = initializeW(X.length, dim); // initialize with zeros
final int N = getNfromC();
for (int n = 0; n <= N; ++n) {
accumulateCdf(W, X, H, n, time);
}
if (DEBUG2) {
for (int i = 0; i < W.length; ++i) {
System.err.println("W[" + i + "]:\n" + new Matrix(squareMatrix(W[i])));
}
System.err.println("");
}
return W;
}
private double[] initializeP(double[] Q, double lambda) {
double[] P = new double[dim * dim];
for (int i = 0; i < dim; ++i) {
for (int j = 0; j < dim; ++j) {
double identity = (i == j) ? 1.0 : 0.0;
P[idx(i, j)] = identity + Q[idx(i, j)] / lambda;
}
}
return P;
}
private void accumulateCdf(double[][] W, double[] X, int[] H, int n, double time) {
final double premult = Math.exp(
-lambda * time + n * (Math.log(lambda) + Math.log(time)) - GammaFunction.lnGamma(n + 1.0)
);
// TODO Make factorial/choose static look-up tables
for (int t = 0; t < X.length; ++t) { // For each time point
double x = X[t];
int h = H[t];
double xh = (x - r[h - 1] * time) / ((r[h] - r[h - 1]) * time);
final int dim2 = dim * dim;
double[] inc = new double[dim2]; // W^{\epsilon}(x(i),t,n)
for (int k = 0; k <= n; k++) {
final double binomialCoef = Binomial.choose(n, k) * Math.pow(xh, k) * Math.pow(1.0 - xh, n - k);
for (int uv = 0; uv < dim2; ++uv) {
inc[uv] += binomialCoef * C(h, n, k)[uv];
}
}
for (int uv = 0; uv < dim2; ++uv) {
W[t][uv] += premult * inc[uv];
}
}
}
private double accumulatePdf(double x, int n, double time, int uv) {
double w = 0.0;
final double premult = Math.exp(
-lambda * time + n * (Math.log(lambda) + Math.log(time)) - GammaFunction.lnGamma(n + 1.0)
);
// TODO Make factorial/choose static look-up tables
// for (int t = 0; t < X.length; ++t) { // For each time point
int h = 1;
final double factor = lambda / (r[h] - r[h - 1]);
double xh = (x - r[h - 1] * time) / ((r[h] - r[h - 1]) * time);
// final int dim2 = dim * dim;
// double[] inc = new double[dim2]; // W^{\epsilon}(x(i),t,n)
double inc = 0.0;
for (int k = 0; k <= n; k++) {
final double binomialCoef = Binomial.choose(n, k) * Math.pow(xh, k) * Math.pow(1.0 - xh, n - k);
// for (int uv = 0; uv < dim2; ++uv) {
inc += binomialCoef * (C(h, n + 1, k + 1)[uv] - C(h, n + 1, k)[uv]);
// }
}
// for (int uv = 0; uv < dim2; ++uv) {
w += factor * premult * inc;
// }
// }
return w;
}
private void accumulatePdf(double[][] W, double[] X, int n, double time) {
final double premult = Math.exp(
-lambda * time + n * (Math.log(lambda) + Math.log(time)) - GammaFunction.lnGamma(n + 1.0)
);
// TODO Make factorial/choose static look-up tables
for (int t = 0; t < X.length; ++t) { // For each time point
double x = X[t];
int h = 1;
final double factor = lambda / (r[h] - r[h - 1]);
double xh = (x - r[h - 1] * time) / ((r[h] - r[h - 1]) * time);
final int dim2 = dim * dim;
double[] inc = new double[dim2]; // W^{\epsilon}(x(i),t,n)
for (int k = 0; k <= n; k++) {
final double binomialCoef = Binomial.choose(n, k) * Math.pow(xh, k) * Math.pow(1.0 - xh, n - k);
for (int uv = 0; uv < dim2; ++uv) {
inc[uv] += binomialCoef * (C(h, n + 1, k + 1)[uv] - C(h, n + 1, k)[uv]);
}
}
for (int uv = 0; uv < dim2; ++uv) {
W[t][uv] += factor * premult * inc[uv];
}
}
}
private double relationTwelve(int h, int n, int k, int u, int v) {
// TODO ratios are independent of u,v,w
double c = (r[u] - r[h]) / (r[u] - r[h - 1]) * C(h, n, k - 1)[idx(u, v)];
double d = 0;
for (int w = 0; w <= phi; ++w) {
d += P[idx(u, w)] * C(h, n - 1, k - 1)[idx(w, v)];
}
d *= (r[h] - r[h - 1]) / (r[u] - r[h - 1]);
return c + d;
}
private double relationThirteen(int h, int n, int k, int u, int v) {
// TODO ratios of are independent of u,v,w
double c = (r[h - 1] - r[u]) / (r[h] - r[u]) * C(h, n, k + 1)[idx(u, v)];
double d = 0;
for (int w = 0; w <= phi; ++w) {
d += P[idx(u, w)] * C(h, n - 1, k)[idx(w, v)];
}
d *= (r[h] - r[h - 1]) / (r[h] - r[u]);
return c + d;
}
private double[] product(double[] a, double[] b) {
double[] c = new double[dim * dim];
for (int i = 0; i < dim; ++i) {
for (int j = 0; j < dim; ++j) {
final int ij = idx(i, j);
for (int k = 0; k < dim; ++k) {
c[ij] += a[idx(i, k)] * b[idx(k, j)];
}
}
}
return c;
}
private void computeChnk() {
double[] Pn = new double[dim * dim];
for (int u = 0; u < dim; ++u) {
Pn[idx(u, u)] = 1.0;
}
for (int h = 1; h <= phi; ++h) {
// zero corner cases automatically
for (int u = 0; u <= h - 1; ++u) {
C(h, 0, 0)[idx(u, u)] = 1.0;
}
}
// accumulate(0);
final int N = getNfromC();
for (int n = 1; n <= N; ++n) {
// zero corner cases automatically
for (int h = 1; h <= phi; ++h) {
for (int k = 1; k <= n; ++k) {
for (int u = h; u <= phi; ++u) {
for (int v = 0; v <= phi; ++v) {
C(h, n, k)[idx(u, v)] = relationTwelve(h, n, k, u, v);
}
}
}
for (int u = h + 1; u <= phi; ++u) {
for (int v = 0; v <= phi; ++v) {
C(h + 1, n, 0)[idx(u, v)] = C(h, n, n)[idx(u, v)];
}
}
}
Pn = product(Pn, P);
for (int u = 0; u <= phi - 1; ++u) {
for (int v = 0; v <= phi; ++v) {
C(phi, n, n)[idx(u, v)] = Pn[idx(u, v)];
}
}
for (int h = phi; h >= 1; --h) {
for (int k = n - 1; k >= 0; --k) {
for (int u = 0; u <= h - 1; u++) {
for (int v = 0; v <= phi; ++v) {
C(h, n, k)[idx(u, v)] = relationThirteen(h, n, k, u, v);
}
}
for (int u = 0; u <= h - 2; ++u) {
for (int v = 0; v <= phi; ++v) {
C(h - 1, n, n)[idx(u, v)] = C(h, n, 0)[idx(u, v)];
}
}
}
}
// accumulate(n);
}
}
private double determineLambda() {
double lambda = Q[0]; // Q[idx(0,0)]
for (int i = 1; i < dim; ++i) {
int ii = idx(i, i);
if (Q[ii] < lambda) {
lambda = Q[ii];
}
}
return -lambda;
}
private double[][] squareMatrix(final double[] mat) {
double[][] rtn = new double[dim][dim];
for (int i = 0; i < dim; ++i) {
for (int j = 0; j < dim; ++j) {
rtn[i][j] = mat[idx(i, j)];
}
}
return rtn;
}
private int determineNumberOfSteps(double time, double lambda) {
// final double tolerance = (1.0 - epsilon) / Math.exp(-lambda * time);
// final double logTolerance = Math.log(1.0 - epsilon);
int i = -1;
// double sum = 0.0;
// int factorialI = 1;
//
// while (sum < tolerance) {
// i++;
// sum += Math.pow(lambda * time, i) / factorialI;
// factorialI *= (i + 1); // + 1 because used on next iterate
// }
//
// int firstN = i;
//
// LogTricks.logSum()
// i = -1;
final double tolerance2 =
// -epsilon;
(1.0 - epsilon);
double sum2 = 0.0;
// double sum2 = Double.NEGATIVE_INFINITY;
while (Math.abs(sum2 - tolerance2) > epsilon && sum2 < 1.0) {
// while (sum2 < tolerance2) {
i++;
double logDensity = -lambda * time + i * (Math.log(lambda) + Math.log(time)) - GammaFunction.lnGamma(i + 1);
sum2 += Math.exp(logDensity);
// sum2 = LogTricks.logSum(sum2, logDensity);
if (DEBUG2) {
System.err.println(sum2 + " " + tolerance2 + " " + Math.abs(sum2 - tolerance2) + " " + epsilon * 0.01);
// if (i > 500) System.exit(-1);
}
}
// System.err.println("First: " + firstN);
// System.err.println("Second:" + i);
// System.exit(-1);
return i;
}
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("Q: " + new Vector(Q) + "\n");
sb.append("r: " + new Vector(r) + "\n");
sb.append("lambda: " + lambda + "\n");
sb.append("N: " + getNfromC() + "\n");
sb.append("maxTime: " + maxTime + "\n");
sb.append("cprob at maxTime: " + new Vector(computeConditionalProbabilities(maxTime)) + "\n");
return sb.toString();
}
private EigenDecomposition getEigenDecomposition() {
if (eigenDecomposition == null) {
eigenDecomposition = eigenSystem.decomposeMatrix(squareMatrix(Q));
}
return eigenDecomposition;
}
private EigenDecomposition eigenDecomposition;
public double[] computeConditionalProbabilities(double distance) {
double[] matrix = new double[dim * dim];
eigenSystem.computeExponential(getEigenDecomposition(), distance, matrix);
return matrix;
}
public double computeConditionalProbability(double distance, int i, int j) {
return eigenSystem.computeExponential(getEigenDecomposition(), distance, i, j);
}
private final double[] Q;
private final double[] r;
private final double lambda;
private final double[] P;
private final int phi;
private final int dim;
private final double epsilon;
private final EigenSystem eigenSystem;
private double maxTime;
}