/*
* MarkovModulatedSubstitutionModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.substmodel;
import dr.evolution.datatype.DataType;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
/**
 * A Markov-modulated substitution model.
 *
 * <p>The model combines {@code numBaseModel} base substitution models, one per hidden "modulating"
 * class, into a single chain over {@code numBaseModel * baseStateCount} states. State
 * {@code m * baseStateCount + i} is base state {@code i} in class {@code m}.
 *
 * <p>Rates are assembled as follows:
 * <ul>
 *   <li>Each diagonal block holds the infinitesimal matrix of base model {@code m}, multiplied by
 *       {@link #getModelRateScalar(int)}.</li>
 *   <li>Switching between classes {@code g -> h} happens only between the same base state, at a
 *       rate taken from {@code switchingRates}.</li>
 *   <li>If {@code switchingRates} has dimension {@code 2 * (numBaseModel - 1)}, switching is
 *       birth-death: only adjacent classes can switch. Otherwise every ordered pair of distinct
 *       classes has its own rate.</li>
 * </ul>
 *
 * <p>Rate-matrix normalization is disabled for this model.
 *
 * @author Marc A. Suchard
 */
public class MarkovModulatedSubstitutionModel extends ComplexSubstitutionModel implements Citable, Loggable {

    private static final boolean IGNORE_RATES = false;
    private static final boolean DEBUG = false;
    private static final boolean NEW_STORE_RESTORE = true;

    /** Private copy of the base models, in class order. */
    private final List<SubstitutionModel> baseModels;
    private final int numBaseModel;
    /** Number of states in each base model. */
    private final int baseStateCount;
    private final Parameter switchingRates;
    /** Scratch buffer for a flattened base-model infinitesimal matrix. */
    private final double[] baseMatrix;
    /** Optional per-class rate scalar (dimension 1 or numBaseModel); may be null. */
    private final Parameter rateScalar;
    /** True when switching is only between adjacent classes. */
    private final boolean birthDeathModel;
    private final boolean geometricRates;
    /** Optional discretized rate model; when set it overrides {@link #rateScalar}. May be null. */
    private final SiteRateModel gammaRateModel;

    private EigenDecomposition storedEigenDecomposition;
    private boolean storedUpdateMatrix;

    /**
     * Creates a model with no rate scalar, no geometric rates and no discretized rate model.
     *
     * @see #MarkovModulatedSubstitutionModel(String, List, Parameter, DataType, EigenSystem, Parameter, boolean, SiteRateModel)
     */
    public MarkovModulatedSubstitutionModel(String name,
                                            List<SubstitutionModel> baseModels,
                                            Parameter switchingRates,
                                            DataType dataType,
                                            EigenSystem eigenSystem) {
        this(name, baseModels, switchingRates, dataType, eigenSystem, null, false, null);
    }

    /**
     * Creates a Markov-modulated substitution model.
     *
     * @param name           model name, passed to the superclass
     * @param baseModels     one substitution model per modulating class; must not be empty. The list
     *                       is copied, so later changes to it have no effect.
     * @param switchingRates between-class switching rates. Dimension must be
     *                       {@code 2 * (n - 1)} (birth-death) or {@code n * (n - 1)} (all pairs).
     * @param dataType       data type of the combined chain. Its state count must equal the sum of
     *                       the base models' state counts.
     * @param eigenSystem    currently ignored; the superclass is always constructed with
     *                       {@code null} here
     * @param rateScalar     optional per-class rate scalar of dimension 1 or {@code n}; may be null
     * @param geometricRates if true, the switching rate into class {@code h} is multiplied by
     *                       {@code getModelRateScalar(n - h - 1) / sum(switchingRates)}
     * @param gammaRateModel optional discretized rate model whose category count must divide
     *                       {@code n}; may be null
     * @throws RuntimeException if there are no base models, or if any dimension or state count is
     *                          inconsistent
     */
    public MarkovModulatedSubstitutionModel(String name,
                                            List<SubstitutionModel> baseModels,
                                            Parameter switchingRates,
                                            DataType dataType,
                                            EigenSystem eigenSystem,
                                            Parameter rateScalar,
                                            boolean geometricRates,
                                            SiteRateModel gammaRateModel) {
        // NOTE: eigenSystem is intentionally not forwarded (kept from the original implementation).
        super(name, dataType, null, null);

        numBaseModel = baseModels.size();
        if (numBaseModel == 0) {
            throw new RuntimeException("May not construct MarkovModulatedSubstitutionModel with 0 base models");
        }
        // Defensive copy: numBaseModel must stay in sync with the list contents.
        this.baseModels = new ArrayList<SubstitutionModel>(baseModels);

        this.switchingRates = switchingRates;
        addVariable(switchingRates);
        final int switchingDimension = switchingRates.getDimension();
        if (switchingDimension != 2 * (numBaseModel - 1)
                && switchingDimension != numBaseModel * (numBaseModel - 1)) {
            throw new RuntimeException("Wrong switching rate dimensions");
        }

        baseStateCount = this.baseModels.get(0).getFrequencyModel().getFrequencyCount();
        baseMatrix = new double[baseStateCount * baseStateCount];

        final List<FrequencyModel> freqModels = new ArrayList<FrequencyModel>(numBaseModel);
        int stateSizes = 0;
        for (SubstitutionModel baseModel : this.baseModels) {
            addModel(baseModel);
            final FrequencyModel baseFreqModel = baseModel.getFrequencyModel();
            freqModels.add(baseFreqModel);
            addModel(baseFreqModel);
            stateSizes += baseModel.getDataType().getStateCount();
        }

        // This constructor also checks that all models have the same base stateCount
        freqModel = new MarkovModulatedFrequencyModel("mm", freqModels, switchingRates);
        addModel(freqModel);

        if (stateCount != stateSizes) {
            throw new RuntimeException("Incompatible state counts in " + getModelName() + ". Models add up to " + stateSizes + ".");
        }

        // Birth-death switching when exactly 2 * (n - 1) rates are given; otherwise rates cover all
        // ordered pairs of distinct classes. A single class is treated as birth-death.
        birthDeathModel = numBaseModel <= 1 || switchingDimension == 2 * (numBaseModel - 1);
        this.geometricRates = geometricRates;

        if (gammaRateModel != null) {
            addModel(gammaRateModel);
            // Class m uses category m % categoryCount, so the category count must divide numBaseModel.
            if (numBaseModel % gammaRateModel.getCategoryCount() != 0) {
                throw new RuntimeException("Wrong discretized gamma dimension");
            }
        }
        this.gammaRateModel = gammaRateModel;

        if (rateScalar != null) {
            addVariable(rateScalar);
            if (rateScalar.getDimension() != 1 && rateScalar.getDimension() != numBaseModel) {
                throw new RuntimeException("Wrong rate scalar dimensions");
            }
        }
        this.rateScalar = rateScalar;

        setDoNormalization(false);
        updateMatrix = true;

        Logger.getLogger("dr.app.beagle").info("\tConstructing a Markov-modulated Markov chain substitution model with " + stateCount + " states; please cite:\n"
                + Citable.Utils.getCitationString(this));
    }

    /** @return the number of base models (modulating classes) */
    public int getNumBaseModel() {
        return numBaseModel;
    }

    /**
     * Returns the rate multiplier applied to base model {@code model}.
     *
     * <p>If a discretized rate model is present, this is the rate of category
     * {@code model % categoryCount}. Otherwise it is the value of {@code rateScalar}: either the
     * shared value (dimension 1) or the per-model entry. With neither present, it returns 1.0.
     *
     * @param model index of the base model
     * @return the rate multiplier for that model
     */
    public double getModelRateScalar(int model) {
        if (gammaRateModel != null) {
            final int category = model % gammaRateModel.getCategoryCount();
            final double categoryRate = gammaRateModel.getRateForCategory(category);
            if (DEBUG) {
                System.err.println("M" + category + " = " + categoryRate);
            }
            return categoryRate;
        }
        if (rateScalar == null) {
            return 1.0;
        }
        final int index = (rateScalar.getDimension() == 1) ? 0 : model;
        return rateScalar.getParameterValue(index);
    }

    /**
     * Saves a snapshot of the current eigendecomposition (or its absence) and of the
     * {@code updateMatrix} flag.
     */
    protected void storeState() {
        if (DEBUG) {
            System.err.println("MMSM.sS");
        }
        if (NEW_STORE_RESTORE) {
            // Store null explicitly when there is no decomposition. Otherwise restoreState() could
            // resurrect a decomposition left over from an earlier store.
            storedEigenDecomposition = (eigenDecomposition == null) ? null : eigenDecomposition.copy();
            storedUpdateMatrix = updateMatrix;
        } else {
            super.storeState();
        }
    }

    /**
     * Restores the eigendecomposition and {@code updateMatrix} flag saved by {@link #storeState()}.
     * The current and stored decompositions are swapped rather than copied.
     */
    protected void restoreState() {
        if (DEBUG) {
            System.err.println("MMSM.rS");
        }
        if (NEW_STORE_RESTORE) {
            final EigenDecomposition previous = storedEigenDecomposition;
            storedEigenDecomposition = eigenDecomposition;
            eigenDecomposition = previous;
            updateMatrix = storedUpdateMatrix;
        } else {
            super.restoreState();
        }
    }

    /**
     * Fills {@code matrix} with the Markov-modulated rate matrix.
     *
     * <p>The matrix is first zeroed. Each diagonal block {@code m} is then set to base model
     * {@code m}'s infinitesimal matrix times {@link #getModelRateScalar(int)}.
     *
     * <p>Next, unless there is only one class, switching rates are written for each valid
     * {@code (from g, to h)} class pair, visited in row-major order. Each rate is written between
     * the same base state in classes g and h. Diagonal entries are not adjusted for the switching
     * rates; presumably the superclass makes rows valid afterwards (TODO confirm).
     *
     * @param rates  unused
     * @param pi     unused
     * @param matrix output matrix, assumed to be {@code stateCount x stateCount} (TODO confirm)
     */
    protected void setupQMatrix(double[] rates, double[] pi, double[][] matrix) {
        for (double[] row : matrix) {
            Arrays.fill(row, 0.0);
        }

        // Within-class blocks.
        for (int m = 0; m < numBaseModel; ++m) {
            final int offset = m * baseStateCount;
            baseModels.get(m).getInfinitesimalMatrix(baseMatrix);
            if (DEBUG) {
                System.err.println("m " + m + " : " + new Vector(baseMatrix));
            }
            final double modelScalar = getModelRateScalar(m);
            for (int i = 0; i < baseStateCount; i++) {
                final int rowStart = i * baseStateCount;
                for (int j = 0; j < baseStateCount; j++) {
                    matrix[offset + i][offset + j] = modelScalar * baseMatrix[rowStart + j];
                }
            }
        }

        if (IGNORE_RATES || numBaseModel <= 1) {
            return;
        }

        // Between-class switching rates.
        final double[] swRates = switchingRates.getParameterValues();
        double totalRate = 0.0;
        for (double rate : swRates) {
            totalRate += rate;
        }

        int sw = 0;
        for (int g = 0; g < numBaseModel; ++g) {
            for (int h = 0; h < numBaseModel; ++h) { // from g -> h
                final boolean valid = birthDeathModel ? Math.abs(g - h) == 1 : g != h;
                if (!valid) {
                    continue;
                }
                double rate = swRates[sw++];
                // Skip the rescaling when totalRate is 0; dividing would put NaN (0/0) into Q.
                // TODO (inherited): why not also divide by numBaseModel here?
                if (geometricRates && totalRate != 0.0) {
                    rate *= getModelRateScalar(numBaseModel - h - 1) / totalRate;
                }
                for (int i = 0; i < baseStateCount; ++i) {
                    matrix[g * baseStateCount + i][h * baseStateCount + i] = rate;
                }
            }
        }
    }

    /**
     * Returns the superclass eigendecomposition. When DEBUG is on, it also prints Q, the
     * eigenvalues and the transition probabilities at time 1.0.
     *
     * <p>TODO (inherited): remove this debugging override.
     */
    public EigenDecomposition getEigenDecomposition() {
        if (DEBUG) {
            System.err.println("MMSM.getED");
        }
        final EigenDecomposition ed = super.getEigenDecomposition();
        if (DEBUG) {
            final double[][] q = getQCopy();
            System.err.println(new Matrix(q));
            System.err.println("");
            System.err.println(new Vector(ed.getEigenValues()));
            System.err.println("");
            final double[] tp = new double[q.length * q.length];
            getTransitionProbabilities(1.0, tp, ed);
            System.err.println(new Vector(tp));
        }
        return ed;
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.SUBSTITUTION_MODELS;
    }

    @Override
    public String getDescription() {
        return "Markov modulated substitution model";
    }

    /** @return the single citation for this model */
    public List<Citation> getCitations() {
        return Collections.singletonList(CommonCitations.SUCHARD_2012);
    }

    /**
     * No-op. Recomputation after any sub-model change is triggered in
     * {@link #handleModelChangedEvent}.
     */
    @Override
    protected void frequenciesChanged() {
        // Do nothing
    }

    /** Marks the rate matrix for lazy recomputation. */
    @Override
    protected void ratesChanged() {
        updateMatrix = true;
    }

    /** No-op; relative rates are not used because {@link #setupQMatrix} builds Q directly. */
    @Override
    protected void setupRelativeRates(double[] rates) {
        // Do nothing
    }

    /**
     * Handles a change in any registered sub-model: marks the rate matrix for recomputation and
     * propagates the change.
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (DEBUG) {
            System.err.println("MMSM.hMCE");
        }
        updateMatrix = true;
        fireModelChanged();
    }

    /**
     * When {@code switchingRates} or {@code rateScalar} changes, marks the rate matrix for
     * recomputation and propagates the change. Other variables are ignored here; they are handled
     * by the individual base models.
     */
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        if (variable == switchingRates || variable == rateScalar) {
            updateMatrix = true;
            fireModelChanged(); // TODO Determine if necessary for ExposeRateCategoriesWrapper
        }
    }

    /**
     * Returns the superclass log columns followed by one {@code rateScalar.i} column per base
     * model. Each of these reports {@link #getModelRateScalar(int)}.
     */
    public LogColumn[] getColumns() {
        final LogColumn[] parentColumns = super.getColumns();
        final List<LogColumn> columns = new ArrayList<LogColumn>(parentColumns.length + numBaseModel);
        columns.addAll(Arrays.asList(parentColumns));
        for (int i = 0; i < numBaseModel; ++i) {
            columns.add(new RateColumn("rateScalar." + i, i));
        }
        return columns.toArray(new LogColumn[columns.size()]);
    }

    /** Log column reporting the current rate multiplier of one base model. */
    private class RateColumn extends NumberColumn {
        private final int index;

        public RateColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        /**
         * Returns the current value as a double.
         */
        @Override
        public double getDoubleValue() {
            return getModelRateScalar(index);
        }
    }
}