/*
* StateHistory.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.markovjumps;
import dr.evolution.datatype.DataType;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.List;
/**
* A class to represent the complete state history of a continuous-time Markov chain in the
* interval [0,T].
* <p/>
* This work is supported by NSF grant 0856099
*
* @author Marc A. Suchard
*/
public class StateHistory {

    /**
     * Creates an open (not yet finalized) history whose first entry is the starting state.
     *
     * @param startingTime  time of the first entry
     * @param startingState state index of the first entry
     * @param stateCount    number of states; used as the row stride of the flattened
     *                      {@code stateCount * stateCount} arrays handled by this class
     */
    protected StateHistory(double startingTime, int startingState, int stateCount) {
        stateList = new ArrayList<StateChange>();
        stateList.add(new StateChange(startingTime, startingState));
        this.stateCount = stateCount;
        finalized = false;
    }

    /**
     * Appends a jump to a history that has not been finalized yet.
     *
     * @param stateChange the jump to append
     * @throws IllegalAccessError if the history is already finalized
     */
    public void addChange(StateChange stateChange) {
        checkFinalized(false);
        stateList.add(stateChange);
    }

    /**
     * Appends the ending entry and finalizes the history; no further entries may be added.
     *
     * @param stateChange the ending entry
     * @throws IllegalAccessError if the history is already finalized
     */
    public void addEndingState(StateChange stateChange) {
        checkFinalized(false);
        stateList.add(stateChange);
        finalized = true;
    }

    /**
     * @return a freshly allocated array of length {@code stateCount * stateCount} in which
     *         entry {@code from * stateCount + to} is the number of jumps from {@code from} to {@code to}
     * @throws IllegalAccessError if the history is not finalized
     */
    public int[] getJumpCounts() {
        int[] counts = new int[stateCount * stateCount];
        accumulateSufficientStatistics(counts, null);
        return counts;
    }

    /**
     * @return a freshly allocated array of length {@code stateCount} holding the total time spent in each state
     * @throws IllegalAccessError if the history is not finalized
     */
    public double[] getWaitingTimes() {
        double[] times = new double[stateCount];
        accumulateSufficientStatistics(null, times);
        return times;
    }

    /**
     * Returns the sum over all jump types of (jump count) * (register weight).
     *
     * @param register weights indexed like {@link #getJumpCounts()}; must have at least
     *                 {@code stateCount * stateCount} entries
     * @return the register-weighted jump count
     */
    public double getTotalRegisteredCounts(double[] register) {
        int[] counts = getJumpCounts();
        return dotProduct(counts, register);
    }

    /**
     * Dot product over the first {@code a.length} entries of both arrays.
     */
    private double dotProduct(int[] a, double[] b) {
        double total = 0;
        final int length = a.length;
        for (int i = 0; i < length; i++) {
            total += a[i] * b[i];
        }
        return total;
    }

    /**
     * Returns the sum over all states of (waiting time in state) * (register weight).
     *
     * @param register per-state weights; must have at least {@code stateCount} entries
     * @return the register-weighted total waiting time
     */
    public double getTotalReward(double[] register) {
        double[] times = getWaitingTimes();
        double total = 0;
        for (int i = 0; i < times.length; i++) {
            total += times[i] * register[i]; // stateCount length vector
        }
        return total;
    }

    /**
     * Adds this history's jump counts and/or waiting times into the supplied arrays.
     * Either argument may be {@code null}, in which case that statistic is skipped.
     *
     * @param counts accumulator indexed by {@code from * stateCount + to}, or {@code null}
     * @param times  accumulator indexed by state, or {@code null}
     * @throws IllegalAccessError if the history is not finalized
     */
    public void accumulateSufficientStatistics(int[] counts, double[] times) {
        checkFinalized(true);

        int nJumps = getNumberOfJumps();

        StateChange initialState = stateList.get(0);
        int currentState = initialState.getState();
        double currentTime = initialState.getTime();

        // Jumps occupy indices 1..nJumps; index nJumps + 1 is the ending entry.
        for (int i = 1; i <= nJumps; i++) {
            StateChange nextStateChange = stateList.get(i);
            int nextState = nextStateChange.getState();
            double nextTime = nextStateChange.getTime();

            if (counts != null) {
                counts[currentState * stateCount + nextState]++;
            }
            if (times != null) {
                times[currentState] += (nextTime - currentTime);
            }

            currentState = nextState;
            currentTime = nextTime;
        }

        if (times != null) { // Add last waiting time
            StateChange finalState = stateList.get(nJumps + 1);
            times[currentState] += (finalState.getTime() - currentTime);
        }
    }

    /**
     * @return the number of entries between the starting and the ending entry
     * @throws IllegalAccessError if the history is not finalized
     */
    public int getNumberOfJumps() {
        checkFinalized(true);
        return stateList.size() - 2; // Discount starting and ending states
    }

    /**
     * Throws unless the finalized flag equals the expected value.
     *
     * @param isTrue the finalized status the caller requires
     * @throws IllegalAccessError reporting the actual status when it differs from {@code isTrue}
     */
    private void checkFinalized(boolean isTrue) {
        if (isTrue != finalized) {
            // Parenthesize the conditional so " finalized" is appended in both branches.
            throw new IllegalAccessError("StateHistory " + (finalized ? "is" : "is not") + " finalized");
        }
    }

    /**
     * @return the state of the first entry
     */
    public int getStartingState() {
        return stateList.get(0).getState();
    }

    /**
     * @return the state of the last entry
     * @throws IllegalAccessError if the history is not finalized
     */
    public int getEndingState() {
        checkFinalized(true);
        return stateList.get(stateList.size() - 1).getState();
    }

    /**
     * @return the time of the first entry
     */
    public double getStartingTime() {
        return stateList.get(0).getTime();
    }

    /**
     * @return the time of the last entry
     * @throws IllegalAccessError if the history is not finalized
     */
    public double getEndingTime() {
        checkFinalized(true);
        return stateList.get(stateList.size() - 1).getTime();
    }

    /**
     * Linearly maps all entry times so the history starts at {@code inStartTime} and ends at
     * {@code inEndTime}. The entries are modified in place.
     * <p/>
     * NOTE(review): if the current starting and ending times coincide the scale factor is not
     * finite; callers are presumably expected to avoid zero-length histories.
     *
     * @param inStartTime new time of the first entry
     * @param inEndTime   new time of the last entry
     * @throws IllegalAccessError if the history is not finalized
     */
    public void rescaleTimesOfEvents(double inStartTime, double inEndTime) {
        final double scale = (inEndTime - inStartTime) / (getEndingTime() - getStartingTime());

        StateChange currentStateChange = stateList.get(0);
        double oldCurrentTime = currentStateChange.getTime();
        currentStateChange.setTime(inStartTime);
        double newCurrentTime = inStartTime;

        // Rescale each inter-event interval and rebuild the times cumulatively.
        for (int i = 1; i < stateList.size(); ++i) {
            StateChange nextStateChange = stateList.get(i);
            double oldNextTime = nextStateChange.getTime();
            double oldTimeDiff = oldNextTime - oldCurrentTime;
            double newNextTime = oldTimeDiff * scale + newCurrentTime;
            nextStateChange.setTime(newNextTime);

            oldCurrentTime = oldNextTime;
            newCurrentTime = newNextTime;
        }
    }

    /**
     * @return the entries in order, comma-separated and wrapped in square brackets
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < stateList.size(); i++) {
            if (i > 0) {
                sb.append(",");
            }
            sb.append(stateList.get(i));
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * Small manual demonstration of {@link #rescaleTimesOfEvents(double, double)}; prints to stderr.
     */
    public static void main(String[] args) {
        System.err.println("Testing time rescaling:");

        StateHistory stateHistory = new StateHistory(1, 1, 4);
        StateChange stateChange;
        stateChange = new StateChange(2, 2);
        stateHistory.addChange(stateChange);
        stateChange = new StateChange(5, 2);
        stateHistory.addEndingState(stateChange);

        System.err.println("Initial history: " + stateHistory);
        stateHistory.rescaleTimesOfEvents(8.0, 0.0);
        System.err.println("Rescale history: " + stateHistory);
        stateHistory.rescaleTimesOfEvents(0.0, 4.0);
        System.err.println("Rescale history: " + stateHistory);
    }

    /**
     * Builds a new history containing only the jumps whose register entry equals 1.
     * Each retained jump is cloned and tagged with the state it departed from, because
     * consecutive entries of the filtered history no longer have to be adjacent in the original.
     * If this history has no jumps, {@code this} is returned unchanged.
     * <p/>
     * The ending entry is shared (not cloned) between this history and the returned one.
     *
     * @param register indicator array indexed by {@code from * stateCount + to}
     * @return the filtered history, or {@code this} when there are no jumps
     * @throws IllegalAccessError if the history is not finalized
     */
    public StateHistory filterChanges(double[] register) {
        if (getNumberOfJumps() == 0) {
            return this;
        }

        StateChange currentState = stateList.get(0);
        StateHistory newHistory = new StateHistory(currentState.getTime(), currentState.getState(), stateCount);

        for (int i = 1; i < stateList.size() - 1; ++i) {
            StateChange nextState = stateList.get(i);
            if (register[currentState.getState() * stateCount + nextState.getState()] == 1) {
                // Clone before tagging so the entry held by this history is left untouched.
                nextState = nextState.clone();
                nextState.setPreviousState(currentState.getState());
                newHistory.addChange(nextState);
            }
            currentState = nextState;
        }
        newHistory.addEndingState(stateList.get(stateList.size() - 1));

        // This function can produce inconsistent histories when not all changes are reported.
        // The flag belongs to the filtered copy: only its entries carry a previous state.
        newHistory.isFiltered = true;
        return newHistory;
    }

    /**
     * Computes the log-density of this path: for every jump, the log of the off-diagonal rate
     * plus (diagonal rate) * (holding time), and for the final interval (diagonal rate) *
     * (remaining time).
     * <p/>
     * NOTE(review): this presumes the diagonal entries hold the negative total exit rate of each
     * state, as {@code simulateUnconditionalOnEndingState} also does — confirm against callers.
     *
     * @param infinitesimalRates flattened rate matrix indexed by {@code from * stateCount + to}
     * @param stateCount         row stride of {@code infinitesimalRates} (shadows the field of the same name)
     * @return the log-likelihood of the recorded path
     * @throws IllegalAccessError if the history is not finalized
     */
    public double getLogLikelihood(final double[] infinitesimalRates, final int stateCount) {
        checkFinalized(true);

        // TODO This function needs testing
        double logLikelihood = 0.0;

        final int totalChanges = getNumberOfJumps();

        int currentState = stateList.get(0).getState();
        double currentTime = stateList.get(0).getTime();

        // Jumps occupy indices 1..totalChanges (inclusive); the ending entry follows them.
        for (int i = 1; i <= totalChanges; ++i) {
            int nextState = stateList.get(i).getState();
            double nextTime = stateList.get(i).getTime();

            // Exponential pdf and destination choice
            logLikelihood += Math.log(infinitesimalRates[currentState * stateCount + nextState])
                    + infinitesimalRates[currentState * stateCount + currentState] * (nextTime - currentTime);
            // terms involving Math.log(\lambda_{ii}) cancel

            currentState = nextState;
            currentTime = nextTime;
        }

        final int lastState = stateList.get(stateList.size() - 1).getState();
        final double lastTime = stateList.get(stateList.size() - 1).getTime();

        assert (lastState == currentState);
        assert (lastTime >= currentTime);

        // No event in last interval
        logLikelihood += infinitesimalRates[currentState * stateCount + currentState] * (lastTime - currentTime);

        return logLikelihood;
    }

    /**
     * Same as {@link #toStringChanges(int, DataType, boolean)} with {@code wrap == true}.
     */
    public String toStringChanges(int site, DataType dataType) {
        return toStringChanges(site, dataType, true);
    }

    /**
     * Formats every entry between the starting and ending entry whose state differs from its
     * source state as an event string (see {@link #addEventToStringBuilder}), comma-separated.
     * For a filtered history the source state is read from each entry's stored previous state;
     * otherwise it is the state of the preceding reported entry.
     *
     * @param site     site number passed through to each event string
     * @param dataType used to translate state indices into codes
     * @param wrap     whether to enclose the whole list in braces
     * @return the formatted list of changes
     */
    public String toStringChanges(int site, DataType dataType, boolean wrap) {
        StringBuilder sb = wrap ? new StringBuilder("{") : new StringBuilder();
        // site number gets put into each and every event string

        int currentState = stateList.get(0).getState();
        boolean firstChange = true;

        for (int i = 1; i < stateList.size() - 1; i++) { // TODO Code review: should this really be size() - 1?
            int nextState = stateList.get(i).getState();
            if (isFiltered) {
                currentState = stateList.get(i).getPreviousState();
            }
            if (nextState != currentState) {
                if (!firstChange) {
                    sb.append(",");
                }
                double time = stateList.get(i).getTime(); // + startTime;
                addEventToStringBuilder(sb, dataType.getCode(currentState), dataType.getCode(nextState), time, site);
                firstChange = false;
                currentState = nextState;
            }
        }

        if (wrap) {
            sb.append("}"); // Always returns an array of arrays
        }
        return sb.toString();
    }

    /**
     * Appends one event as {@code {site,time,source,dest}}; the site is omitted when it is not positive.
     *
     * @param sb     builder to append to
     * @param source code of the state being left
     * @param dest   code of the state being entered
     * @param time   time of the event
     * @param site   site number; only written when {@code site > 0}
     */
    public static void addEventToStringBuilder(StringBuilder sb, String source, String dest, double time, int site) {
        // AR changed this to match an attribute array:
        sb.append("{");
        if (site > 0) {
            sb.append(site).append(",");
        }
        sb.append(time).append(",").append(source).append(",").append(dest).append("}");
    }

    /**
     * Not supported by this class.
     *
     * @throws RuntimeException always
     */
    public static StateHistory simulateConditionalOnEndingState(double startingTime,
                                                                int startingState,
                                                                double endingTime,
                                                                int endingState,
                                                                double[] lambda,
                                                                int stateCount) {
        throw new RuntimeException("Impossible to simulate a conditioned CTMC in StateHistory");
    }

    /**
     * Simulates a path forward from the starting state until {@code endingTime}, drawing
     * exponential waiting times with rate {@code -lambda[i * stateCount + i]} and choosing the
     * destination in proportion to the off-diagonal entries of the current row.
     *
     * @param startingTime  time at which the path starts
     * @param startingState state in which the path starts
     * @param endingTime    time at which the path is truncated
     * @param lambda        flattened rate matrix indexed by {@code from * stateCount + to}
     * @param stateCount    number of states
     * @return a finalized history ending at {@code endingTime} in whatever state was reached
     */
    public static StateHistory simulateUnconditionalOnEndingState(double startingTime,
                                                                  int startingState,
                                                                  double endingTime,
                                                                  double[] lambda,
                                                                  int stateCount) {
        StateHistory history = new StateHistory(startingTime, startingState, stateCount);
        double[] multinomial = new double[stateCount];

        double currentTime = startingTime;
        int currentState = startingState;

        while (currentTime < endingTime) {
            double currentRate = -lambda[currentState * stateCount + currentState];
            double waitingTime = MathUtils.nextExponential(currentRate);

            currentTime += waitingTime;
            if (currentTime < endingTime) { // Simulate a jump
                // Copy the current row and zero the diagonal so a self-transition is never drawn.
                System.arraycopy(lambda, currentState * stateCount, multinomial, 0, stateCount);
                multinomial[currentState] = 0;
                currentState = MathUtils.randomChoicePDF(multinomial); // Does not need to be normalized

                history.addChange(new StateChange(currentTime, currentState));
            }
        }

        history.addEndingState(new StateChange(endingTime, currentState));
        return history;
    }

    // Number of states; row stride of the flattened count and register arrays.
    private final int stateCount;
    // Starting entry, then the jumps in order, then (once finalized) the ending entry.
    private final List<StateChange> stateList;
    // True once addEndingState has been called.
    private boolean finalized;
    // True only for histories produced by filterChanges, whose entries carry a previous state.
    private boolean isFiltered = false;
}