/* * MarkovModulatedFrequencyModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.substmodel; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.impl.DenseDoubleMatrix2D; import cern.colt.matrix.linalg.Algebra; import cern.colt.matrix.linalg.LUDecomposition; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import java.util.List; /** * @author Marc A. Suchard * @author Vladimir Minin * <p/> * A class for implementing a kronecker sum of CTMC models in BEAST using BEAGLE * This work is supported by NSF grant 0856099 * <p/> * O'Brien JD, Minin VN and Suchard MA (2009) Learning to count: robust estimates for labeled distances between * molecular sequences. Molecular Biology and Evolution, 26, 801-814 */ public class MarkovModulatedFrequencyModel extends FrequencyModel { public MarkovModulatedFrequencyModel(String name, List<FrequencyModel> freqModels, Parameter switchingRates) { super(name); this.freqModels = freqModels; int freqCount = 0; stateCount = freqModels.get(0).getFrequencyCount(); numBaseModel = freqModels.size(); for (int i = 0; i < numBaseModel; i++) { int size = freqModels.get(i).getFrequencyCount(); if (stateCount != size) { throw new RuntimeException("MarkovModulatedFrequencyModel requires all frequencies model to have the same dimension"); } addModel(freqModels.get(i)); freqCount += size; } totalFreqCount = freqCount; this.switchingRates = switchingRates; addVariable(switchingRates); // if (switchingRates.getDimension() > 2 || numBaseModel > 2) { // throw new RuntimeException("MarkovModulatedFrequencyModel not yet implemented for more than 2 hidden classes"); // } baseStationaryDistribution = new double[numBaseModel]; storedBaseStationaryDistribution = new double[numBaseModel]; stationaryDistributionKnown = false; d = new DenseDoubleMatrix2D(numBaseModel, numBaseModel); d.set(0, 0, 1.0); } public void setFrequency(int i, double value) { throw new RuntimeException("Not implemented"); } public double getFrequency(int index) { int whichModel = index / stateCount; int whichState = index % stateCount; double relativeFreq = freqModels.get(whichModel).getFrequency(whichState) / numBaseModel; // Scale by stationary distribution over hidden classes if (numBaseModel > 1) { if (!stationaryDistributionKnown) { computeStationaryDistribution(baseStationaryDistribution); stationaryDistributionKnown = true; } // relativeFreq *= baseStationaryDistribution[whichModel]; // Try: no adjustment, appears to cause store/restore issue } return relativeFreq; } private void computeStationaryDistribution(double[] statDistr) { // // Uses an eigendecomposition and matrix inverse // DoubleMatrix2D mat = new DenseDoubleMatrix2D(numBaseModel, numBaseModel); // int index = 0; // for (int i = 0; i < numBaseModel; ++i) { // for (int j = i + 1; j < numBaseModel; ++j) { // mat.set(i, j, switchingRates.getParameterValue(index)); // index++; // } // } // for (int j = 0; j < numBaseModel; ++j) { // for (int i = j + 1; i < numBaseModel; ++i) { // mat.set(i, j, switchingRates.getParameterValue(index)); // index++; // } // } // for (int i = 0; i < numBaseModel; ++i) { // double rowTotal = 0.0; // for (int j = 0; j < numBaseModel; ++j) { // if (i != j) { // rowTotal += mat.get(i,j); // } // } // mat.set(i,i, -rowTotal); // } // // EigenvalueDecomposition ed = new EigenvalueDecomposition(mat); // DoubleMatrix2D eigenVectors = ed.getV(); // DoubleMatrix2D b = alg.mult(eigenVectors, alg.mult(d, alg.inverse(eigenVectors))); // // for (int i = 0; i < numBaseModel; ++i) { // statDistr[i] = b.get(0,i); // } // System.err.println(new Vector(statDistr)); if (allRatesAreZero(switchingRates)) { return; } // Uses an LU decomposition to solve Q^t \pi = 0 and \sum \pi_i = 1 DoubleMatrix2D mat2 = new DenseDoubleMatrix2D(numBaseModel + 1, numBaseModel); int index2 = 0; for (int i = 0; i < numBaseModel; ++i) { for (int j = i + 1; j < numBaseModel; ++j) { mat2.set(j, i, switchingRates.getParameterValue(index2)); // Transposed index2++; } } for (int j = 0; j < numBaseModel; ++j) { for (int i = j + 1; i < numBaseModel; ++i) { mat2.set(j, i, switchingRates.getParameterValue(index2)); // Transposed index2++; } } for (int i = 0; i < numBaseModel; ++i) { double rowTotal = 0.0; for (int j = 0; j < numBaseModel; ++j) { if (i != j) { rowTotal += mat2.get(j, i); // Transposed } } mat2.set(i, i, -rowTotal); } // Add row for sum-to-one constraint for (int i = 0; i < numBaseModel; ++i) { mat2.set(numBaseModel, i, 1.0); } LUDecomposition decomp = new LUDecomposition(mat2); DoubleMatrix2D x = new DenseDoubleMatrix2D(numBaseModel + 1, 1); x.set(numBaseModel, 0, 1.0); DoubleMatrix2D y = decomp.solve(x); for (int i = 0; i < numBaseModel; ++i) { statDistr[i] = y.get(i, 0); } //System.err.println(new Vector(statDistr)); } private static boolean allRatesAreZero(Parameter rates) { for (int i = 0; i < rates.getDimension(); ++i) { if (rates.getParameterValue(i) != 0.0) { return false; } } return true; } protected void storeState() { System.arraycopy(baseStationaryDistribution, 0, storedBaseStationaryDistribution, 0, numBaseModel); storedStationaryDistributionKnown = stationaryDistributionKnown; } protected void restoreState() { double[] tmp = baseStationaryDistribution; baseStationaryDistribution = storedBaseStationaryDistribution; storedBaseStationaryDistribution = tmp; stationaryDistributionKnown = storedStationaryDistributionKnown; } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { if (variable == switchingRates) { stationaryDistributionKnown = false; } } protected void handleModelChangedEvent(Model model, Object object, int index) { // System.err.println("MMFM.hMCE : " + model.getId() + " : " + model.getClass().getCanonicalName()); fireModelChanged(); } public int getFrequencyCount() { return totalFreqCount; } public Parameter getFrequencyParameter() { throw new RuntimeException("Not implemented"); } private List<FrequencyModel> freqModels; private final int numBaseModel; private final int totalFreqCount; private final int stateCount; private final Parameter switchingRates; private double[] baseStationaryDistribution; private double[] storedBaseStationaryDistribution; private boolean stationaryDistributionKnown; private boolean storedStationaryDistributionKnown; private final Algebra alg = new Algebra(); private final DoubleMatrix2D d; }