/* * ProductChainSubstitutionModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.substmodel; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evolution.datatype.DataType; import dr.evolution.datatype.GeneralDataType; import dr.inference.model.Model; import dr.math.KroneckerOperation; import dr.util.Citable; import dr.util.Citation; import dr.util.CommonCitations; import java.util.*; import java.util.logging.Logger; //import dr.math.matrixAlgebra.Vector; /** * @author Marc A. Suchard * @author Vladimir Minin * <p/> * A class for implementing a kronecker sum of CTMC models in BEAST using BEAGLE * This work is supported by NSF grant 0856099 * <p/> * O'Brien JD, Minin VN and Suchard MA (2009) Learning to count: robust estimates for labeled distances between * molecular sequences. Molecular Biology and Evolution, 26, 801-814 */ public class ProductChainSubstitutionModel extends BaseSubstitutionModel implements Citable { public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels) { this(name, baseModels, null); } public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels, List<SiteRateModel> rateModels) { this(name, baseModels, rateModels, false); } public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels, List<SiteRateModel> rateModels, boolean forceAverageModel) { super(name); this.baseModels = baseModels; this.rateModels = rateModels; this.forceAverageModel = forceAverageModel; numBaseModel = baseModels.size(); if (numBaseModel == 0) { throw new RuntimeException("May not construct ProductChainSubstitutionModel with 0 base models"); } if (rateModels != null) { for(SiteRateModel rateModel : rateModels) { if (rateModel.getCategoryCount() > 1) { throw new RuntimeException("ProductChainSubstitutionModels with multiple categories not yet implemented"); } } } List<FrequencyModel> freqModels = new ArrayList<FrequencyModel>(); stateSizes = new int[numBaseModel]; stateCount = 1; for (int i = 0; i < numBaseModel; i++) { freqModels.add(baseModels.get(i).getFrequencyModel()); DataType dataType = baseModels.get(i).getDataType(); stateSizes[i] = dataType.getStateCount(); stateCount *= dataType.getStateCount(); addModel(baseModels.get(i)); addModel(rateModels.get(i)); } pcFreqModel = new ProductChainFrequencyModel("pc",freqModels); addModel(pcFreqModel); String[] codeStrings = getCharacterStrings(); dataType = new GeneralDataType(codeStrings); updateMatrix = true; } @Override public Citation.Category getCategory() { return Citation.Category.SUBSTITUTION_MODELS; } @Override public String getDescription() { return "Product chain substitution model"; } public List<Citation> getCitations() { return Collections.singletonList(CommonCitations.OBRIEN_2009_LEARNING); } public EigenDecomposition getEigenDecomposition() { synchronized (this) { if (updateMatrix) { computeKroneckerSumsAndProducts(); } } return eigenDecomposition; } private String[] getCharacterStrings() { String[] strings = null; for (int i = numBaseModel - 1; i >= 0; i--) { strings = recursivelyAppendCharacterStates(baseModels.get(i).getDataType(), strings); } return strings; } protected void handleModelChangedEvent(Model model, Object object, int index) { super.handleModelChangedEvent(model, object, index); // Propogate change to higher models fireModelChanged(model); averageModel = null; } private String[] recursivelyAppendCharacterStates(DataType dataType, String[] inSubStates) { String[] subStates = inSubStates; if (subStates == null) { subStates = new String[]{""}; } final int previousStateCount = subStates.length; final int inStateCount = dataType.getStateCount(); String[] states = new String[previousStateCount * inStateCount]; for (int i = 0; i < inStateCount; i++) { String code = dataType.getCode(i); for (int j = 0; j < previousStateCount; j++) { states[i * previousStateCount + j] = code + subStates[j]; } } return states; } public void getInfinitesimalMatrix(double[] out) { getEigenDecomposition(); // Updates rate matrix if necessary System.arraycopy(rateMatrix, 0, out, 0, stateCount * stateCount); } double getRateForModel(int index) { if (!forceAverageModel) { return rateModels.get(index).getRateForCategory(0); } else { double total = 0.0; for (int i = 0; i < rateModels.size(); ++i) { total += rateModels.get(i).getRateForCategory(0); } return total / rateModels.size(); } } protected double[] scaleForProductChain(double[] in, int model) { if (rateModels == null) { return in; } final double scalar = getRateForModel(model); if (scalar == 1.0) { return in; } final int len = in.length; double[] out = new double[len]; for (int i = 0; i < len; i++) { out[i] = scalar * in[i]; } return out; } private SubstitutionProcess computeAverageModel() { return new SubstitutionProcess() { @Override public void getTransitionProbabilities(double distance, double[] matrix) { throw new RuntimeException("Should not be called"); } @Override public EigenDecomposition getEigenDecomposition() { if (eigenDecomposition == null) { // System.err.println("Statecount = " + stateSizes[0]); // System.exit(-1); double[][] mat = new double[stateSizes[0]][stateSizes[0]]; double[] vec = new double[stateSizes[0] * stateSizes[0]]; getInfinitesimalMatrix(vec); for (int i = 0; i < stateSizes[0]; ++i) { System.arraycopy(vec, i * stateSizes[0], mat[i], 0, stateSizes[0]); } eigenDecomposition = getDefaultEigenSystem(stateSizes[0]).decomposeMatrix(mat); } return eigenDecomposition; } @Override public FrequencyModel getFrequencyModel() { throw new RuntimeException("Should not be called"); } @Override public void getInfinitesimalMatrix(double[] matrix) { if (averageMatrix == null) { final int dim = matrix.length; averageMatrix = new double[dim]; double[][] allMatrices = new double[baseModels.size()][dim]; for (int i = 0; i < baseModels.size(); ++i) { baseModels.get(i).getInfinitesimalMatrix(allMatrices[i]); } for (int i = 0; i < dim; ++i) { double total = 0.0; for (int j = 0; j < baseModels.size(); ++j) { total += allMatrices[j][i]; } total /= baseModels.size(); averageMatrix[i] = total; } } // System.err.println("averageMatrix.length " + averageMatrix.length); // System.err.println("matrix.length " + matrix.length); System.arraycopy(averageMatrix, 0, matrix, 0, averageMatrix.length); } @Override public DataType getDataType() { throw new RuntimeException("Should not be called"); } @Override public boolean canReturnComplexDiagonalization() { throw new RuntimeException("Should not be called"); } private double[] averageMatrix = null; private EigenDecomposition eigenDecomposition = null; }; } // private SubstitutionProcess getAverageModel() { // if (!forceAverageModel) { // throw new RuntimeException("Error getting averaged model with non-averaged product chain"); // } // if (averageModel == null) { // averageModel = computeAverageModel(); // } // return averageModel; // } private SubstitutionProcess getBaseModel(int index) { if (!forceAverageModel) { return baseModels.get(index); } else { if (averageModel == null) { averageModel = computeAverageModel(); } return averageModel; } } private void computeKroneckerSumsAndProducts() { // if (forceAverageModel) { // if (averageModel == null) { // averageModel = computeAverageModel(); // } // eigenDecomposition = averageModel.getEigenDecomposition(); // if (rateMatrix == null) { // rateMatrix = new double[stateCount * stateCount]; // } // averageModel.getInfinitesimalMatrix(rateMatrix); // updateMatrix = false; // // System.err.println("lambda = " + new Vector(rateMatrix)); // // return; // } int currentStateSize = stateSizes[0]; double[] currentRate = new double[currentStateSize * currentStateSize]; getBaseModel(0).getInfinitesimalMatrix(currentRate); currentRate = scaleForProductChain(currentRate, 0); EigenDecomposition currentED = getBaseModel(0).getEigenDecomposition(); double[] currentEval = scaleForProductChain(currentED.getEigenValues(), 0); double[] currentEvec = currentED.getEigenVectors(); double[] currentIevcT = transpose(currentED.getInverseEigenVectors(), currentStateSize); for (int i = 1; i < numBaseModel; i++) { SubstitutionProcess nextModel = getBaseModel(i); int nextStateSize = stateSizes[i]; double[] nextRate = new double[nextStateSize * nextStateSize]; nextModel.getInfinitesimalMatrix(nextRate); nextRate = scaleForProductChain(nextRate, i); currentRate = KroneckerOperation.sum(currentRate, currentStateSize, nextRate, nextStateSize); EigenDecomposition nextED = nextModel.getEigenDecomposition(); double[] nextEval = scaleForProductChain(nextED.getEigenValues(), i); double[] nextEvec = nextED.getEigenVectors(); double[] nextIevcT = transpose(nextED.getInverseEigenVectors(), nextStateSize); currentEval = KroneckerOperation.sum(currentEval, nextEval); currentEvec = KroneckerOperation.product( currentEvec, currentStateSize, currentStateSize, nextEvec, nextStateSize, nextStateSize); currentIevcT = KroneckerOperation.product( currentIevcT, currentStateSize, currentStateSize, nextIevcT, nextStateSize, nextStateSize); currentStateSize *= nextStateSize; } rateMatrix = currentRate; eigenDecomposition = new EigenDecomposition( currentEvec, transpose(currentIevcT, currentStateSize), currentEval); updateMatrix = false; } // private static void printSquareMatrix(double[] A, int dim) { // double[] row = new double[dim]; // for (int i = 0; i < dim; i++) { // System.arraycopy(A, i * dim, row, 0, dim); // System.err.println(new Vector(row)); // } // } // transposes a square matrix private static double[] transpose(double[] mat, int dim) { double[] out = new double[dim * dim]; for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { out[j * dim + i] = mat[i * dim + j]; } } return out; } public FrequencyModel getFrequencyModel() { return pcFreqModel; } protected void frequenciesChanged() { // Do nothing } protected void ratesChanged() { // Do nothing } protected void setupRelativeRates(double[] rates) { // Do nothing } protected final int numBaseModel; protected final List<SubstitutionModel> baseModels; protected final List<SiteRateModel> rateModels; protected final int[] stateSizes; protected final ProductChainFrequencyModel pcFreqModel; protected double[] rateMatrix = null; private final boolean forceAverageModel; private SubstitutionProcess averageModel = null; }