/*
 * SubstitutionModelDelegate.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.treelikelihood;

import beagle.Beagle;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;
import dr.util.Timer;

import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * Maps the substitution models of a {@link BranchModel} onto BEAGLE eigen-decomposition and
 * transition-matrix buffers.
 * <p>
 * A branch whose mapping has a single model gets its transition matrix computed directly into
 * the branch's own matrix buffer. A branch whose mapping spans several models gets one matrix
 * per segment, computed into temporary buffers drawn from a pool of "extra" buffers. Those
 * segment matrices are then combined pairwise with {@code Beagle.convolveTransitionMatrices}
 * (first * second -> result) until only the branch's own buffer holds the final matrix.
 * <p>
 * Matrix buffer index layout, with B = {@code matrixBufferHelper.getBufferCount()}:
 * <ul>
 * <li>indices below B: presumably the branch buffers handed out by {@code matrixBufferHelper};</li>
 * <li>{@code [B, B + extraBufferCount)}: the convolution pool;</li>
 * <li>{@code B + extraBufferCount}: a single reserve buffer.</li>
 * </ul>
 * <p>
 * NOTE(review): the pool and the timing accumulators are mutable and unsynchronized. This class
 * presumably runs on a single thread — confirm with callers.
 *
 * @author Andrew Rambaut
 * @author Filip Bielejec
 * @author Marc A. Suchard
 * @version $Id$
 */
public final class SubstitutionModelDelegate implements EvolutionaryProcessDelegate, Serializable {

    /** When true, prints buffer-management diagnostics to System.out. */
    private static final boolean DEBUG = false;

    /**
     * Debug switch. When true:
     * <ul>
     * <li>getAvailableBufferCount() always reports 0, so pending work is flushed before every
     * convolved branch;</li>
     * <li>convolveAndRelease() refuses batches of more than one operation.</li>
     * </ul>
     */
    private static final boolean RUN_IN_SERIES = false;

    /**
     * When true, adds each Timer.toSeconds() reading to updateTime (matrix computation) and
     * convolveTime (convolution).
     */
    public static final boolean MEASURE_RUN_TIME = true;

    /** Accumulated time spent in computeTransitionMatrices (only when MEASURE_RUN_TIME). */
    public double updateTime;

    /** Accumulated time spent in convolveMatrices (only when MEASURE_RUN_TIME). */
    public double convolveTime;

    /** Pool size used when convolution is required and the requested size is not positive. */
    private static final int BUFFER_POOL_SIZE_DEFAULT = 100;

    private final Tree tree;
    private final List<SubstitutionModel> substitutionModelList;
    private final BranchModel branchModel;

    // number of substitution models; model i uses eigen index i
    private final int eigenCount;
    private final int nodeCount;

    // size of the convolution buffer pool (0 when the branch model needs no convolution)
    private final int extraBufferCount;
    // index of the single reserve matrix buffer, placed directly after the pool
    private final int reserveBufferIndex;

    private final BufferIndexHelper eigenBufferHelper;
    private BufferIndexHelper matrixBufferHelper;

    // free pool buffer indices, used as a LIFO stack via push/pop
    private Deque<Integer> availableBuffers = new ArrayDeque<Integer>();

    /**
     * Creates a delegate using the default convolution pool size (BUFFER_POOL_SIZE_DEFAULT).
     *
     * @param tree        tree whose nodes are mapped by the branch model
     * @param branchModel source of the substitution models and the per-branch mappings
     */
    public SubstitutionModelDelegate(Tree tree, BranchModel branchModel) {
        this(tree, branchModel, BUFFER_POOL_SIZE_DEFAULT);
    }

    /**
     * Creates a delegate for the given tree and branch model.
     *
     * @param tree           tree whose node count sizes the branch matrix buffers and whose
     *                       nodes are passed to the branch model mapping
     * @param branchModel    supplies the substitution models (one eigen index each), the
     *                       per-branch mappings, and whether matrix convolution is required
     * @param bufferPoolSize number of extra matrix buffers for convolution. Values {@code <= 0}
     *                       fall back to BUFFER_POOL_SIZE_DEFAULT. When the branch model does not
     *                       require convolution, the pool size is 0.
     * @throws RuntimeException if convolution is required and the pool would be smaller than
     *                          the number of substitution models
     */
    public SubstitutionModelDelegate(Tree tree, BranchModel branchModel, int bufferPoolSize) {

        if (MEASURE_RUN_TIME) {
            updateTime = 0;
            convolveTime = 0;
        }

        this.tree = tree;

        this.substitutionModelList = branchModel.getSubstitutionModels();

        this.branchModel = branchModel;

        eigenCount = substitutionModelList.size();
        nodeCount = tree.getNodeCount();

        // Presumably two eigen buffers per decomposition (double-buffered for store/restore)
        // — confirm against BufferIndexHelper.
        eigenBufferHelper = new BufferIndexHelper(eigenCount, 0);

        // Presumably two matrices per node (double-buffered for store/restore).
        // NOTE(review): an older comment said "less the root", but nodeCount includes the root
        // — confirm against BufferIndexHelper.
        matrixBufferHelper = new BufferIndexHelper(nodeCount, 0);

        this.extraBufferCount = branchModel.requiresMatrixConvolution() ?
                (bufferPoolSize > 0 ? bufferPoolSize : BUFFER_POOL_SIZE_DEFAULT) : 0;

        // A convolved branch takes one pool buffer per segment of its mapping. Requiring at least
        // eigenCount buffers presumably covers a mapping that spans every model once.
        if (branchModel.requiresMatrixConvolution() && this.extraBufferCount < eigenCount) {
            throw new RuntimeException("SubstitutionModelDelegate requires at least " + eigenCount + " extra buffers to convolve matrices");
        }

        // Pool indices start directly after the branch matrix buffers. Because they are pushed
        // in ascending order, the highest index is popped first.
        for (int i = 0; i < extraBufferCount; i++) {
            pushAvailableBuffer(i + matrixBufferHelper.getBufferCount());
        }

        // One extra buffer, sitting just past the pool, is kept as a reserve. convolveMatrices
        // uses it when the pool is empty and no batched operation can be flushed to free buffers.
        reserveBufferIndex = matrixBufferHelper.getBufferCount() + extraBufferCount;

        if (DEBUG) {
            System.out.println("Creating reserve buffer with index: " + reserveBufferIndex);
        }

    }// END: Constructor

    /**
     * @return true if at least one substitution model reports that it can return a complex
     *         diagonalization
     */
    public boolean canReturnComplexDiagonalization() {
        for (SubstitutionModel model : substitutionModelList) {
            if (model.canReturnComplexDiagonalization()) {
                return true;
            }
        }
        return false;
    }

    /** @return the number of eigen buffers reported by the eigen buffer helper */
    public int getEigenBufferCount() {
        return eigenBufferHelper.getBufferCount();
    }

    /**
     * @return total matrix buffers the BEAGLE instance must allocate: branch buffers, plus the
     *         convolution pool, plus the reserve buffer
     */
    public int getMatrixBufferCount() {
        // plus one for the reserve buffer
        return matrixBufferHelper.getBufferCount() + extraBufferCount + 1;
    }

    /** @return the number of substitution models supplied by the branch model */
    public int getSubstitutionModelCount() {
        return substitutionModelList.size();
    }

    /**
     * @param index position in the branch model's substitution model list (its eigen index)
     * @return the substitution model at that position
     */
    public SubstitutionModel getSubstitutionModel(int index) {
        return substitutionModelList.get(index);
    }

    /**
     * For every substitution model: flips its eigen buffer offset, then uploads the model's
     * current eigen decomposition (eigenvectors, inverse eigenvectors, eigenvalues) into the
     * newly selected eigen buffer.
     *
     * @param beagle BEAGLE instance to receive the decompositions
     */
    public void updateSubstitutionModels(Beagle beagle) {
        for (int i = 0; i < eigenCount; i++) {
            eigenBufferHelper.flipOffset(i);
            EigenDecomposition ed = substitutionModelList.get(i).getEigenDecomposition();
            beagle.setEigenDecomposition(
                    eigenBufferHelper.getOffsetIndex(i),
                    ed.getEigenVectors(),
                    ed.getInverseEigenVectors(),
                    ed.getEigenValues());
        }
    }

    /**
     * Computes transition matrices for the given branches. For branches whose mapping spans
     * more than one substitution model, the per-segment matrices are convolved into the branch
     * buffer.
     *
     * @param beagle        BEAGLE instance that receives the matrix updates and convolutions
     * @param branchIndices node numbers of the branches to update (passed to tree.getNode).
     *                      Only the first updateCount entries are read.
     * @param edgeLength    edge length for each entry of branchIndices. For a convolved branch it
     *                      is split across segments in proportion to the mapping weights.
     *                      Presumably already rate-scaled — confirm with callers.
     * @param updateCount   number of branches to process
     * @throws RuntimeException if one branch needs more pool buffers than are free even after
     *                          pending work has been flushed, or (from convolveMatrices) if a
     *                          mapping has an empty order
     */
    public void updateTransitionMatrices(Beagle beagle, int[] branchIndices, double[] edgeLength, int updateCount) {

        // Pending matrix computations, grouped by eigen index so that each model issues a single
        // beagle.updateTransitionMatrices call.
        // NOTE(review): inner arrays are sized updateCount. If a mapping repeated a model index
        // for the same branch, counts[k] could exceed updateCount — confirm BranchModel never
        // does this.
        int[][] probabilityIndices = new int[eigenCount][updateCount];
        double[][] edgeLengths = new double[eigenCount][updateCount];

        int[] counts = new int[eigenCount];

        // One deque per convolved branch: segment buffers in mapping order, destination buffer last.
        List<Deque<Integer>> convolutionList = new ArrayList<Deque<Integer>>();

        for (int i = 0; i < updateCount; i++) {

            BranchModel.Mapping mapping = branchModel.getBranchModelMapping(tree.getNode(branchIndices[i]));
            int[] order = mapping.getOrder();
            double[] weights = mapping.getWeights();

            if (order.length == 1) {
                // single model: compute straight into the branch's own matrix buffer
                int k = order[0];
                probabilityIndices[k][counts[k]] = matrixBufferHelper.getOffsetIndex(branchIndices[i]);
                edgeLengths[k][counts[k]] = edgeLength[i];
                counts[k]++;
            } else {
                // weights are normalised by their sum when splitting the edge length
                // NOTE(review): a zero weight sum would produce NaN edge lengths.
                double sum = 0.0;
                for (double w : weights) {
                    sum += w;
                }

                if (getAvailableBufferCount() < order.length) {
                    // Too few buffers available: process what we have and continue.
                    // convolveMatrices returns every used pool buffer and empties convolutionList.
                    computeTransitionMatrices(beagle, probabilityIndices, edgeLengths, counts);
                    convolveMatrices(beagle, convolutionList);

                    // reset the counts
                    for (int k = 0; k < eigenCount; k++) {
                        counts[k] = 0;
                    }
                }

                Deque<Integer> bufferIndices = new ArrayDeque<Integer>();
                for (int j = 0; j < order.length; j++) {

                    int buffer = popAvailableBuffer();

                    if (buffer < 0) {
                        // No buffers available even after flushing: this branch alone needs more
                        // than the pool holds.
                        // NOTE(review): the message says "computing current list", but this path
                        // aborts with an exception.
                        throw new RuntimeException("Ran out of buffers for transition matrices - computing current list.");
                    }

                    int k = order[j];
                    probabilityIndices[k][counts[k]] = buffer;
                    edgeLengths[k][counts[k]] = weights[j] * edgeLength[i] / sum;
                    counts[k]++;

                    bufferIndices.add(buffer);
                }
                // destination goes at the tail, so it is consumed by the final convolution
                bufferIndices.add(matrixBufferHelper.getOffsetIndex(branchIndices[i]));

                convolutionList.add(bufferIndices);
            }// END: if convolution needed

        }// END: i loop

        computeTransitionMatrices(beagle, probabilityIndices, edgeLengths, counts);
        convolveMatrices(beagle, convolutionList);

    }// END: updateTransitionMatrices

    /**
     * Issues one beagle.updateTransitionMatrices call for each eigen index that has pending
     * entries. No first or second derivative matrices are requested.
     *
     * @param beagle             BEAGLE instance to update
     * @param probabilityIndices per-model destination matrix buffer indices
     * @param edgeLengths        per-model edge lengths, parallel to probabilityIndices
     * @param counts             number of valid entries for each model
     */
    private void computeTransitionMatrices(Beagle beagle, int[][] probabilityIndices, double[][] edgeLengths, int[] counts) {

        Timer timer;
        if (MEASURE_RUN_TIME) {
            timer = new Timer();
            timer.start();
        }

        if (DEBUG) {
            System.out.print("Computing matrices:");
        }

        for (int i = 0; i < eigenCount; i++) {
            if (DEBUG) {
                for (int j = 0; j < counts[i]; j++) {
                    System.out.print(" " + probabilityIndices[i][j] + " (" + edgeLengths[i][j] + ")");
                }
            }
            if (counts[i] > 0) {
                beagle.updateTransitionMatrices(eigenBufferHelper.getOffsetIndex(i),
                        probabilityIndices[i],
                        null, // firstDerivativeIndices
                        null, // secondDerivativeIndices
                        edgeLengths[i],
                        counts[i]);
            }
        }

        if (DEBUG) {
            System.out.println();
        }

        if (MEASURE_RUN_TIME) {
            timer.stop();
            double timeInSeconds = timer.toSeconds();
            updateTime += timeInSeconds;
        }

    }//END: computeTransitionMatrices

    /**
     * Reduces every deque in convolutionList until it is empty. The list itself is emptied as
     * a result.
     * <p>
     * Each pass of the outer loop takes the two head buffers of every deque and batches one
     * convolution per deque:
     * <ul>
     * <li>If the deque holds exactly three entries, the result is written into the third entry
     * (the branch's destination buffer) and the deque becomes empty.</li>
     * <li>Otherwise the result goes into a fresh pool buffer, which is pushed back onto the
     * head of the deque for the next pass.</li>
     * </ul>
     * If the pool runs dry in mid-pass, the operations batched so far are executed first to
     * free their operands. If nothing is batched yet, the reserve buffer is used instead.
     *
     * @param beagle          BEAGLE instance that performs the convolutions
     * @param convolutionList pending convolutions, as built by updateTransitionMatrices
     * @throws RuntimeException if a deque holds fewer than three entries
     */
    private void convolveMatrices(Beagle beagle, List<Deque<Integer>> convolutionList) {

        Timer timer;
        if (MEASURE_RUN_TIME) {
            timer = new Timer();
            timer.start();
        }

        while (convolutionList.size() > 0) {

            // at most one operation per deque per pass
            int[] firstConvolutionBuffers = new int[nodeCount];
            int[] secondConvolutionBuffers = new int[nodeCount];
            int[] resultConvolutionBuffers = new int[nodeCount];

            int operationsCount = 0;

            List<Deque<Integer>> empty = new ArrayList<Deque<Integer>>();

            for (Deque<Integer> convolve : convolutionList) {

                if (convolve.size() > 3) {

                    firstConvolutionBuffers[operationsCount] = convolve.pop();
                    secondConvolutionBuffers[operationsCount] = convolve.pop();

                    int buffer;

                    boolean done;

                    do {

                        done = true;
                        buffer = popAvailableBuffer();

                        if (buffer < 0) { // no buffers available

                            // We have run out of buffers: process what we have and continue.
                            if (DEBUG) {
                                System.out.println("Ran out of buffers for convolving - computing current list.");
                                System.out.print("Convolving " + operationsCount + " matrices:");
                                for (int i = 0; i < operationsCount; i++) {
                                    System.out.print(" " + firstConvolutionBuffers[i] + "*" + secondConvolutionBuffers[i] + "->" + resultConvolutionBuffers[i]);
                                }
                                System.out.println();
                            }

                            if (operationsCount > 0) {
                                // Execute the operations batched so far (indices below
                                // operationsCount). This returns their pool operands.
                                convolveAndRelease(beagle, firstConvolutionBuffers, secondConvolutionBuffers, resultConvolutionBuffers, operationsCount);

                                // copy the uncompleted operation back down to the beginning of the operations list
                                firstConvolutionBuffers[0] = firstConvolutionBuffers[operationsCount];
                                secondConvolutionBuffers[0] = secondConvolutionBuffers[operationsCount];

                                // reset the operation count
                                operationsCount = 0;
                                done = false;
                                // there should be enough spare buffers to get a resultConvolutionBuffer for this operation now
                            } else {
                                // Only one partially set-up operation, so there are none to free
                                // up. In this case we use the reserve buffer.
                                // NOTE(review): if firstConvolutionBuffers[0] is itself the
                                // reserve (left on this deque by an earlier pass), the result
                                // overwrites one of its own operands — confirm BEAGLE allows this.
                                resultConvolutionBuffers[operationsCount] = getReserveBuffer();
                                convolveAndRelease(beagle, firstConvolutionBuffers, secondConvolutionBuffers, resultConvolutionBuffers, 1);
                                convolve.push(getReserveBuffer());
                                done = true; // break out of the do loop
                            }
                        }
                    } while (!done);

                    if (buffer >= 0) {
                        // If buffer is still negative, the loop above has already used the
                        // reserve buffer to complete this convolution.
                        resultConvolutionBuffers[operationsCount] = buffer;
                        convolve.push(buffer);
                        operationsCount++;
                    }

                } else if (convolve.size() == 3) {

                    // last step: two operands plus the branch's destination buffer
                    firstConvolutionBuffers[operationsCount] = convolve.pop();
                    secondConvolutionBuffers[operationsCount] = convolve.pop();
                    resultConvolutionBuffers[operationsCount] = convolve.pop();
                    operationsCount++;

                } else {
                    throw new RuntimeException("Unexpected convolve list size");
                }

                if (convolve.size() == 0) {
                    // removed after the pass, so the list is not modified while iterating
                    empty.add(convolve);
                }

            }

            if (DEBUG) {
                System.out.print("Convolving " + operationsCount + " matrices:");
                for (int i = 0; i < operationsCount; i++) {
                    System.out.print(" " + firstConvolutionBuffers[i] + "*" + secondConvolutionBuffers[i] + "->" + resultConvolutionBuffers[i]);
                }
                System.out.println();
            }

            convolveAndRelease(beagle, firstConvolutionBuffers, secondConvolutionBuffers, resultConvolutionBuffers, operationsCount);

            convolutionList.removeAll(empty);
        }

        if (MEASURE_RUN_TIME) {
            timer.stop();
            double timeInSeconds = timer.toSeconds();
            convolveTime += timeInSeconds;
        }

    }// END: convolveMatrices

    /**
     * Runs the first operationsCount convolutions (first[i] * second[i] -> result[i]) in a
     * single BEAGLE call. Afterwards, every operand that is a pool buffer is returned to the
     * pool: an operand qualifies if its index is at or above the branch buffer range and it is
     * not the reserve. Result buffers are not released.
     *
     * @param beagle                   BEAGLE instance that performs the convolutions
     * @param firstConvolutionBuffers  left operands
     * @param secondConvolutionBuffers right operands
     * @param resultConvolutionBuffers destination buffers
     * @param operationsCount          number of leading entries to process
     * @throws RuntimeException when RUN_IN_SERIES is set and more than one operation is batched
     */
    private void convolveAndRelease(Beagle beagle, int[] firstConvolutionBuffers, int[] secondConvolutionBuffers, int[] resultConvolutionBuffers, int operationsCount) {

        if (RUN_IN_SERIES) {
            if (operationsCount > 1) {
                throw new RuntimeException("Unable to convolve matrices in series");
            }
        }

        beagle.convolveTransitionMatrices(firstConvolutionBuffers, // A
                secondConvolutionBuffers, // B
                resultConvolutionBuffers, // C
                operationsCount // count
        );

        for (int i = 0; i < operationsCount; i++) {
            if (firstConvolutionBuffers[i] >= matrixBufferHelper.getBufferCount() && firstConvolutionBuffers[i] != reserveBufferIndex) {
                pushAvailableBuffer(firstConvolutionBuffers[i]);
            }
            if (secondConvolutionBuffers[i] >= matrixBufferHelper.getBufferCount() && secondConvolutionBuffers[i] != reserveBufferIndex) {
                pushAvailableBuffer(secondConvolutionBuffers[i]);
            }
        }

    }//END: convolveAndRelease

    /**
     * @return number of free pool buffers. Always 0 when RUN_IN_SERIES is set, which forces a
     *         flush before each convolved branch.
     */
    private int getAvailableBufferCount() {
        if (RUN_IN_SERIES) {
            return 0;
        } else {
            return availableBuffers.size();
        }
    }

    /** @return the most recently freed pool buffer index, or -1 if the pool is empty */
    private int popAvailableBuffer() {
        if (availableBuffers.isEmpty()) {
            return -1;
        }
        return availableBuffers.pop();
    }

    /**
     * The reserve buffer is one extra buffer, outside the pool, used as a convolution result
     * when the pool is empty and nothing can be flushed to free pool buffers.
     *
     * @return the index of the reserve buffer
     */
    private int getReserveBuffer() {
        return reserveBufferIndex;
    }

    /** Returns a buffer index to the top of the free-pool stack. */
    private void pushAvailableBuffer(int index) {
        availableBuffers.push(index);
    }

    /** @return the frequencies of the branch model's root frequency model */
    public double[] getRootStateFrequencies() {
        return branchModel.getRootFrequencyModel().getFrequencies();
    }// END: getRootStateFrequencies

    /**
     * Flips the matrix buffer offset for a branch (delegates to BufferIndexHelper.flipOffset).
     * Presumably called before recomputing, so the previous matrix survives for restoreState —
     * confirm.
     *
     * @param branchIndex node number of the branch
     */
    public void flipMatrixBuffer(int branchIndex) {
        matrixBufferHelper.flipOffset(branchIndex);
    }

    /**
     * @param branchIndex node number of the branch
     * @return the branch's current matrix buffer index
     */
    public int getMatrixIndex(int branchIndex) {
        return matrixBufferHelper.getOffsetIndex(branchIndex);
    }

    /**
     * Stores the eigen and matrix buffer offsets. The convolution pool is not part of the
     * stored state.
     */
    public void storeState() {
        eigenBufferHelper.storeState();
        matrixBufferHelper.storeState();
    }

    /** Restores the eigen and matrix buffer offsets saved by storeState(). */
    public void restoreState() {
        eigenBufferHelper.restoreState();
        matrixBufferHelper.restoreState();
    }

}// END: class