/*
 * MultiPartitionDataLikelihoodDelegate.java
 *
 * Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.treedatalikelihood;

/**
 * MultiPartitionDataLikelihoodDelegate
 *
 * A DataLikelihoodDelegate that uses BEAGLE 3 to allow for parallelization across multiple data partitions.
 *
 * @author Andrew Rambaut
 * @author Marc Suchard
 * @author Guy Baele
 * @version $Id$
 */

import beagle.*;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.treelikelihood.PartialsRescalingScheme;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

public class MultiPartitionDataLikelihoodDelegate extends AbstractModel implements DataLikelihoodDelegate, Citable {

    private static final boolean RESCALING_OFF = false; // a debugging switch

    private static final boolean DEBUG = false;

    public static boolean IS_MULTI_PARTITION_COMPATIBLE() {
        return BeagleInfo.getVersionNumbers()[0] >= 3;
    }

    // This property is a comma-delimited list of resource numbers (0 == CPU) to
    // allocate each BEAGLE instance to. If the list is shorter than the number of
    // instances, it wraps around.
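    //
    // Illustrative example (not taken from this file) of how these properties are
    // typically supplied on the java command line; the resource numbers shown are
    // hypothetical and depend on the local BEAGLE installation:
    //
    //   java -Dbeagle.resource.order=1,0 \
    //        -Dbeagle.scaling=dynamic \
    //        -Dbeagle.rescale=200 \
    //        dr.app.beast.BeastMain analysis.xml
    //
    // With two BEAGLE instances, the order "1,0" assigns instance 0 to resource 1
    // and instance 1 to resource 0; with three instances the list wraps back to
    // resource 1 for the third instance.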
private static final String RESOURCE_ORDER_PROPERTY = "beagle.resource.order"; private static final String PREFERRED_FLAGS_PROPERTY = "beagle.preferred.flags"; private static final String REQUIRED_FLAGS_PROPERTY = "beagle.required.flags"; private static final String SCALING_PROPERTY = "beagle.scaling"; private static final String RESCALE_FREQUENCY_PROPERTY = "beagle.rescale"; private static final String DELAY_SCALING_PROPERTY = "beagle.delay.scaling"; private static final String EXTRA_BUFFER_COUNT_PROPERTY = "beagle.extra.buffer.count"; private static final String FORCE_VECTORIZATION = "beagle.force.vectorization"; // Which scheme to use if choice not specified (or 'default' is selected): private static final PartialsRescalingScheme DEFAULT_RESCALING_SCHEME = PartialsRescalingScheme.DYNAMIC; private static int instanceCount = 0; private static List<Integer> resourceOrder = null; private static List<Integer> preferredOrder = null; private static List<Integer> requiredOrder = null; private static List<String> scalingOrder = null; private static List<Integer> extraBufferOrder = null; // Default frequency for complete recomputation of scaling factors under the 'dynamic' scheme private static final int RESCALE_FREQUENCY = 100; private static final int RESCALE_TIMES = 1; /** * Construct an instance using a list of PatternLists, one for each partition. The * partitions will share a tree but can have different branchModels and siteRateModels * The latter should either have a size of 1 (in which case they are shared across partitions) * or equal to patternLists.size() where each partition has a different model. * * @param tree Used for configuration - shouldn't be watched for changes * @param branchModels Specifies a list of branch models for each partition * @param patternLists List of patternLists comprising each partition * @param siteRateModels A list of siteRateModels for each partition * @param useAmbiguities Whether to respect state ambiguities in data */ public MultiPartitionDataLikelihoodDelegate(Tree tree, List<PatternList> patternLists, List<BranchModel> branchModels, List<SiteRateModel> siteRateModels, boolean useAmbiguities, PartialsRescalingScheme rescalingScheme, boolean delayRescalingUntilUnderflow) { super("MultiPartitionDataLikelihoodDelegate"); final Logger logger = Logger.getLogger("dr.evomodel"); logger.info("\nUsing Multi-Partition Data Likelihood Delegate with BEAGLE 3 extensions"); setId(patternLists.get(0).getId()); this.patternLists = patternLists; this.dataType = patternLists.get(0).getDataType(); stateCount = dataType.getStateCount(); partitionCount = patternLists.size(); patternCounts = new int[partitionCount]; int total = 0; int k = 0; for (PatternList patternList : patternLists) { assert(patternList.getDataType().equals(this.dataType)); patternCounts[k] = patternList.getPatternCount(); total += patternCounts[k]; k++; } totalPatternCount = total; useScaleFactors = new boolean[partitionCount]; recomputeScaleFactors = new boolean[partitionCount]; everUnderflowed = new boolean[partitionCount]; flip = new boolean[partitionCount]; for (int i = 0; i < partitionCount; i++) { flip[i] = true; } updatePartition = new boolean[partitionCount]; partitionWasUpdated = new boolean[partitionCount]; updateAllPartitions = true; cachedLogLikelihoodsByPartition = new double[partitionCount]; storedCachedLogLikelihoodsByPartition = new double[partitionCount]; // Branch models determine the substitution models per branch. 
There can be either // one per partition or one shared across all partitions assert(branchModels.size() == 1 || (branchModels.size() == patternLists.size())); this.branchModels.addAll(branchModels); for (BranchModel branchModel : this.branchModels) { addModel(branchModel); } // SiteRateModels determine the rates per category (for site-heterogeneity models). // There can be either one per partition or one shared across all partitions assert(siteRateModels.size() == 1 || (siteRateModels.size() == patternLists.size())); this.siteRateModels.addAll(siteRateModels); this.categoryCount = this.siteRateModels.get(0).getCategoryCount(); for (SiteRateModel siteRateModel : this.siteRateModels) { assert(siteRateModel.getCategoryCount() == categoryCount); addModel(siteRateModel); } nodeCount = tree.getNodeCount(); tipCount = tree.getExternalNodeCount(); internalNodeCount = nodeCount - tipCount; branchUpdateIndices = new int[nodeCount]; branchLengths = new double[nodeCount]; //changed initialization to account for multiple partitions scaleBufferIndices = new int[partitionCount][internalNodeCount]; storedScaleBufferIndices = new int[partitionCount][internalNodeCount]; operations = new int[internalNodeCount * Beagle.PARTITION_OPERATION_TUPLE_SIZE * partitionCount]; rescalingCount = new int[partitionCount]; rescalingCountInner = new int[partitionCount]; firstRescaleAttempt = true; try { int compactPartialsCount = tipCount; if (useAmbiguities) { // if we are using ambiguities then we don't use tip partials compactPartialsCount = 0; } partialBufferHelper = new BufferIndexHelper[partitionCount]; scaleBufferHelper = new BufferIndexHelper[partitionCount]; categoryRateBufferHelper = new BufferIndexHelper[partitionCount]; for (int i = 0; i < partitionCount; i++) { // one partials buffer for each tip and two for each internal node (for store restore) partialBufferHelper[i] = new BufferIndexHelper(nodeCount, tipCount); // one scaling buffer for each internal node plus an extra for the accumulation, then doubled for store/restore scaleBufferHelper[i] = new BufferIndexHelper(getScaleBufferCount(), 0); categoryRateBufferHelper[i] = new BufferIndexHelper(1, 0); } int eigenBufferCount = 0; int matrixBufferCount = 0; // create a substitutionModelDelegate for each branchModel int partitionNumber = 0; for (BranchModel branchModel : this.branchModels) { HomogenousSubstitutionModelDelegate substitutionModelDelegate = new HomogenousSubstitutionModelDelegate(tree, branchModel, partitionNumber); evolutionaryProcessDelegates.add(substitutionModelDelegate); eigenBufferCount += substitutionModelDelegate.getEigenBufferCount(); matrixBufferCount += substitutionModelDelegate.getMatrixBufferCount(); partitionNumber ++; } // Attempt to get the resource order from the System Property if (resourceOrder == null) { resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY); } if (preferredOrder == null) { preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY); } if (requiredOrder == null) { requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY); } if (scalingOrder == null) { scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY); } if (extraBufferOrder == null) { extraBufferOrder = parseSystemPropertyIntegerArray(EXTRA_BUFFER_COUNT_PROPERTY); } // first set the rescaling scheme to use from the parser this.rescalingScheme = rescalingScheme; this.delayRescalingUntilUnderflow = delayRescalingUntilUnderflow; int[] resourceList = null; long preferenceFlags = 0; long requirementFlags = 
0; if (scalingOrder.size() > 0) { this.rescalingScheme = PartialsRescalingScheme.parseFromString( scalingOrder.get(instanceCount % scalingOrder.size())); } if (resourceOrder.size() > 0) { // added the zero on the end so that a CPU is selected if requested resource fails resourceList = new int[]{resourceOrder.get(instanceCount % resourceOrder.size()), 0}; if (resourceList[0] > 0) { preferenceFlags |= BeagleFlag.PROCESSOR_GPU.getMask(); // Add preference weight against CPU } } if (preferredOrder.size() > 0) { preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size()); } if (requiredOrder.size() > 0) { requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size()); } // Define default behaviour here if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) { //if GPU: the default is dynamic scaling in BEAST if (resourceList != null && resourceList[0] > 1) { this.rescalingScheme = DEFAULT_RESCALING_SCHEME; } else { // if CPU: just run as fast as possible // this.rescalingScheme = PartialsRescalingScheme.NONE; // Dynamic should run as fast as none until first underflow this.rescalingScheme = DEFAULT_RESCALING_SCHEME; } } // to keep behaviour of the delayed scheme (always + delay)... if (this.rescalingScheme == PartialsRescalingScheme.DELAYED) { this.delayRescalingUntilUnderflow = true; this.rescalingScheme = PartialsRescalingScheme.ALWAYS; } if (this.rescalingScheme == PartialsRescalingScheme.AUTO) { // auto scaling not supported for multi-partition BEAGLE3 instances preferenceFlags |= BeagleFlag.SCALING_DYNAMIC.getMask(); } else { // preferenceFlags |= BeagleFlag.SCALING_MANUAL.getMask(); } String r = System.getProperty(RESCALE_FREQUENCY_PROPERTY); if (r != null) { rescalingFrequency = Integer.parseInt(r); if (rescalingFrequency < 1) { rescalingFrequency = RESCALE_FREQUENCY; } } String d = System.getProperty(DELAY_SCALING_PROPERTY); if (d != null) { this.delayRescalingUntilUnderflow = Boolean.parseBoolean(d); } // I don't think this performance stuff should be here. Perhaps have an intelligent automatic // load balancer further up the chain. // if (preferenceFlags == 0 && resourceList == null) { // else determine dataset characteristics // if (stateCount == 4 && patternList.getPatternCount() < 10000) // TODO determine good cut-off // preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask(); // } boolean forceVectorization = false; String vectorizationString = System.getProperty(FORCE_VECTORIZATION); if (vectorizationString != null) { forceVectorization = true; } if (BeagleFlag.VECTOR_SSE.isSet(preferenceFlags) && (stateCount != 4) && !forceVectorization ) { // @todo SSE doesn't seem to work for larger state spaces so for now we override the // SSE option. preferenceFlags &= ~BeagleFlag.VECTOR_SSE.getMask(); preferenceFlags |= BeagleFlag.VECTOR_NONE.getMask(); if (stateCount > 4 && this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { this.rescalingScheme = PartialsRescalingScheme.DELAYED; } } if (!BeagleFlag.PRECISION_SINGLE.isSet(preferenceFlags)) { // if single precision not explicitly set then prefer double preferenceFlags |= BeagleFlag.PRECISION_DOUBLE.getMask(); } if (evolutionaryProcessDelegates.get(0).canReturnComplexDiagonalization()) { requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask(); } //TODO: check getBufferCount() calls with Daniel //TODO: should we multiple getBufferCount() by the number of partitions? 
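            // Rough sanity check of the sizes passed to loadBeagleInstance below, for a
            // hypothetical data set of 4 taxa (nodeCount = 7, tipCount = 4,
            // internalNodeCount = 3) and two nucleotide partitions sharing one branch
            // model (the figures are illustrative, not taken from this file):
            //
            //   partial buffers : tipCount + 2 * internalNodeCount = 10
            //                     (tips single-buffered, internal nodes double-buffered
            //                      for store/restore)
            //   scale buffers   : 2 * (internalNodeCount + 1) = 8 per partition helper
            //   patterns        : totalPatternCount = patternCounts[0] + patternCounts[1]
            //
            // All partitions share the single BEAGLE instance; the pattern buffer is
            // divided between partitions by the patternPartitions array set further
            // below (e.g. pattern counts {3, 2} give patternPartitions = {0, 0, 0, 1, 1}).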
beagle = BeagleFactory.loadBeagleInstance( tipCount, partialBufferHelper[0].getBufferCount(), compactPartialsCount, stateCount, totalPatternCount, eigenBufferCount, matrixBufferCount, categoryCount, scaleBufferHelper[0].getBufferCount(), // Always allocate; they may become necessary resourceList, preferenceFlags, requirementFlags ); InstanceDetails instanceDetails = beagle.getDetails(); ResourceDetails resourceDetails = null; if (instanceDetails != null) { resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber()); if (resourceDetails != null) { StringBuilder sb = new StringBuilder(" Using BEAGLE resource "); sb.append(resourceDetails.getNumber()).append(": "); sb.append(resourceDetails.getName()).append("\n"); if (resourceDetails.getDescription() != null) { String[] description = resourceDetails.getDescription().split("\\|"); for (String desc : description) { if (desc.trim().length() > 0) { sb.append(" ").append(desc.trim()).append("\n"); } } } sb.append(" with instance flags: ").append(instanceDetails.toString()); logger.info(sb.toString()); } else { logger.info(" Error retrieving BEAGLE resource for instance: " + instanceDetails.toString()); } } else { logger.info(" No external BEAGLE resources available, or resource list/requirements not met, using Java implementation"); } patternPartitions = new int[totalPatternCount]; patternWeights = new double[totalPatternCount]; int j = 0; k = 0; for (PatternList patternList : patternLists) { double[] pw = patternList.getPatternWeights(); for (int i = 0; i < patternList.getPatternCount(); i++) { patternPartitions[k] = j; patternWeights[k] = pw[i]; k++; } j++; } logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); String patternCountString = "" + patternLists.get(0).getPatternCount(); for (int i = 1; i < patternLists.size(); i++) { patternCountString += ", " + patternLists.get(i).getPatternCount(); } logger.info(" With " + patternLists.size() + " partitions comprising " + patternCountString + " unique site patterns"); // @todo - should check that each patternList spans the same set of taxa for (int i = 0; i < tipCount; i++) { String id = tree.getTaxonId(i); if (useAmbiguities) { setPartials(beagle, patternLists, id, i); } else { setStates(beagle, patternLists, id, i); } } beagle.setPatternWeights(patternWeights); beagle.setPatternPartitions(partitionCount, patternPartitions); String rescaleMessage = " Using rescaling scheme : " + this.rescalingScheme.getText(); if (this.rescalingScheme == PartialsRescalingScheme.AUTO) { // auto scaling in BEAGLE3 is not supported this.rescalingScheme = PartialsRescalingScheme.DYNAMIC; rescaleMessage = " Auto rescaling not supported in BEAGLE v3, using : " + this.rescalingScheme.getText(); } boolean parenthesis = false; if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { rescaleMessage += " (rescaling every " + rescalingFrequency + " evaluations"; parenthesis = true; } if (this.delayRescalingUntilUnderflow) { rescaleMessage += (parenthesis ? ", " : "(") + "delay rescaling until first overflow"; parenthesis = true; } rescaleMessage += (parenthesis ? ")" : ""); logger.info(rescaleMessage); if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { for (int i = 0; i < partitionCount; i++) { everUnderflowed[i] = false; // If false, BEAST does not rescale until first under-/over-flow. 
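                // Summary of the dynamic scheme as initialised here (the full logic lives
                // in calculateLikelihood): a partition runs without any scaling until it
                // first under- or over-flows; from then on it keeps accumulated scale
                // factors and recomputes them roughly every rescalingFrequency evaluations.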
} } updateSubstitutionModels = new boolean[branchModels.size()]; updateSubstitutionModels(); updateSiteRateModels = new boolean[siteRateModels.size()]; updateSiteRateModels(); } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } } @Override public String getReport() { return null; } @Override public TreeTraversal.TraversalType getOptimalTraversalType() { return TreeTraversal.TraversalType.REVERSE_LEVEL_ORDER; } public List<PatternList> getPatternLists() { return this.patternLists; } @Override public int getTraitCount() { return 1; } @Override public int getTraitDim() { return totalPatternCount; } private void updateSubstitutionModels(boolean... state) { for (int i = 0; i < updateSubstitutionModels.length; i++) { updateSubstitutionModels[i] = (state.length < 1 || state[0]); } } private void updateSubstitutionModel(BranchModel branchModel) { for (int i = 0; i < branchModels.size(); i++) { if (branchModels.get(i) == branchModel) { updateSubstitutionModels[i] = true; } } } private void updateSiteRateModels(boolean... state) { for (int i = 0; i < updateSiteRateModels.length; i++) { updateSiteRateModels[i] = (state.length < 1 || state[0]); } } private void updateSiteRateModel(SiteRateModel siteRateModel) { for (int i = 0; i < siteRateModels.size(); i++) { if (siteRateModels.get(i) == siteRateModel) { updateSiteRateModels[i] = true; } } } private static List<Integer> parseSystemPropertyIntegerArray(String propertyName) { List<Integer> order = new ArrayList<Integer>(); String r = System.getProperty(propertyName); if (r != null) { String[] parts = r.split(","); for (String part : parts) { try { int n = Integer.parseInt(part.trim()); order.add(n); } catch (NumberFormatException nfe) { System.err.println("Invalid entry '" + part + "' in " + propertyName); } } } return order; } private static List<String> parseSystemPropertyStringArray(String propertyName) { List<String> order = new ArrayList<String>(); String r = System.getProperty(propertyName); if (r != null) { String[] parts = r.split(","); for (String part : parts) { try { String s = part.trim(); order.add(s); } catch (NumberFormatException nfe) { System.err.println("Invalid entry '" + part + "' in " + propertyName); } } } return order; } private int getScaleBufferCount() { return internalNodeCount + 1; } /** * Sets the partials from a sequence in an alignment. 
     *
     * @param beagle beagle
     * @param patternLists patternLists
     * @param taxonId taxonId
     * @param nodeIndex nodeIndex
     */
    private final void setPartials(Beagle beagle, List<PatternList> patternLists, String taxonId, int nodeIndex) throws TaxonList.MissingTaxonException {

        double[] partials = new double[totalPatternCount * stateCount * categoryCount];

        int v = 0;
        for (PatternList patternList : patternLists) {
            int sequenceIndex = patternList.getTaxonIndex(taxonId);

            if (sequenceIndex == -1) {
                throw new TaxonList.MissingTaxonException("Taxon, " + taxonId + ", not found in patternList, " + patternList.getId());
            }

            boolean[] stateSet;

            for (int i = 0; i < patternList.getPatternCount(); i++) {
                int state = patternList.getPatternState(sequenceIndex, i);
                stateSet = dataType.getStateSet(state);

                for (int j = 0; j < stateCount; j++) {
                    if (stateSet[j]) {
                        partials[v] = 1.0;
                    } else {
                        partials[v] = 0.0;
                    }
                    v++;
                }
            }
        }

        // if there is more than one category then replicate the partials for each
        int n = totalPatternCount * stateCount;
        int k = n;
        for (int i = 1; i < categoryCount; i++) {
            System.arraycopy(partials, 0, partials, k, n);
            k += n;
        }

        beagle.setPartials(nodeIndex, partials);
    }

    /**
     * Sets the tip states from a sequence in an alignment.
     *
     * @param beagle beagle
     * @param patternLists patternLists
     * @param taxonId taxonId
     * @param nodeIndex nodeIndex
     */
    private final void setStates(Beagle beagle, List<PatternList> patternLists, String taxonId, int nodeIndex) throws TaxonList.MissingTaxonException {

        int[] states = new int[totalPatternCount];

        int v = 0;
        for (PatternList patternList : patternLists) {
            int sequenceIndex = patternList.getTaxonIndex(taxonId);

            if (sequenceIndex == -1) {
                throw new TaxonList.MissingTaxonException("Taxon, " + taxonId + ", not found in patternList, " + patternList.getId());
            }

            for (int i = 0; i < patternList.getPatternCount(); i++) {
                states[v] = patternList.getPatternState(sequenceIndex, i);
                v++;
            }
        }

        beagle.setTipStates(nodeIndex, states);
    }

    /**
     * Calculate the log likelihood of the current state.
     *
     * @return the log likelihood.
*/ @Override public double calculateLikelihood(List<BranchOperation> branchOperations, List<NodeOperation> nodeOperations, int rootNodeNumber) throws LikelihoodException { boolean throwLikelihoodRescalingException = false; if (!initialEvaluation) { for (int i = 0; i < partitionCount; i++) { if (!this.delayRescalingUntilUnderflow || everUnderflowed[i]) { if (this.rescalingScheme == PartialsRescalingScheme.ALWAYS || this.rescalingScheme == PartialsRescalingScheme.DELAYED) { useScaleFactors[i] = true; recomputeScaleFactors[i] = true; } else if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { useScaleFactors[i] = true; if (DEBUG) { System.out.println("rescalingCount["+i+"] = " + rescalingCount[i]); } if (rescalingCount[i] > rescalingFrequency) { if (DEBUG) { System.out.println("rescalingCount > rescalingFrequency"); } rescalingCount[i] = 0; rescalingCountInner[i] = 0; } if (DEBUG) { System.out.println("rescalingCountInner = " + rescalingCountInner[i]); } if (rescalingCountInner[i] < RESCALE_TIMES) { if (DEBUG) { System.out.println("rescalingCountInner < RESCALE_TIMES"); } recomputeScaleFactors[i] = true; updatePartition[i] = true; rescalingCountInner[i]++; throwLikelihoodRescalingException = true; } } } } if (throwLikelihoodRescalingException) { throw new LikelihoodRescalingException(); } } if (RESCALING_OFF) { // a debugging switch for (int i = 0; i < partitionCount; i++) { useScaleFactors[i] = false; recomputeScaleFactors[i] = false; } } int k = 0; for (EvolutionaryProcessDelegate evolutionaryProcessDelegate : evolutionaryProcessDelegates) { if (updateSubstitutionModels[k]) { // TODO: More efficient to update only the substitution model that changed, instead of all // TODO: flip currently assumes 1 substitution model per partition evolutionaryProcessDelegate.updateSubstitutionModels(beagle, flip[k]); updatePartition[k] = true; if (DEBUG) { System.out.println("updateSubstitutionModels, updatePartition["+k+"] = " + updatePartition[k]); } updateAllPartitions = false; // we are currently assuming a no-category model... } k++; } k = 0; for (SiteRateModel siteRateModel : siteRateModels) { if (updateSiteRateModels[k]) { double[] categoryRates = siteRateModel.getCategoryRates(); if (categoryRates == null) { // If this returns null then there was a numerical error calculating the category rates // (probably a very small alpha) so reject the move. 
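                    // Returning -Inf gives the proposed state zero posterior support, so the
                    // MCMC move that produced the problematic rates (typically a very small
                    // gamma shape parameter) is rejected rather than aborting the chain.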
return Double.NEGATIVE_INFINITY; } if (flip[k]) { categoryRateBufferHelper[k].flipOffset(0); } beagle.setCategoryRatesWithIndex(categoryRateBufferHelper[k].getOffsetIndex(0), categoryRates); updatePartition[k] = true; if (DEBUG) { System.out.println("updateSiteRateModels, updatePartition["+k+"] = " + updatePartition[k]); } updateAllPartitions = false; } k++; } int branchUpdateCount = 0; for (BranchOperation op : branchOperations) { branchUpdateIndices[branchUpdateCount] = op.getBranchNumber(); branchLengths[branchUpdateCount] = op.getBranchLength(); branchUpdateCount++; } if (branchUpdateCount > 0) { // TODO below only applies to homogenous substitution models int [] eigenDecompositionIndices = new int [branchUpdateCount * partitionCount]; int [] categoryRateIndices = new int [branchUpdateCount * partitionCount]; int [] probabilityIndices = new int [branchUpdateCount * partitionCount]; double[] edgeLengths = new double[branchUpdateCount * partitionCount]; int op = 0; int partition = 0; for (EvolutionaryProcessDelegate evolutionaryProcessDelegate : evolutionaryProcessDelegates) { if (updatePartition[partition] || updateAllPartitions) { if (flip[partition]) { evolutionaryProcessDelegate.flipTransitionMatrices(branchUpdateIndices, branchUpdateCount); } for (int i = 0; i < branchUpdateCount; i++) { eigenDecompositionIndices[op] = evolutionaryProcessDelegate.getEigenIndex(0); categoryRateIndices[op] = categoryRateBufferHelper[partition].getOffsetIndex(0); probabilityIndices[op] = evolutionaryProcessDelegate.getMatrixIndex(branchUpdateIndices[i]); edgeLengths[op] = branchLengths[i]; op++; } } partition++; } beagle.updateTransitionMatricesWithMultipleModels( eigenDecompositionIndices, categoryRateIndices, probabilityIndices, null, // firstDerivativeIndices null, // secondDerivativeIndices edgeLengths, op); } for (int i = 0; i < partitionCount; i++) { if (updatePartition[i] || updateAllPartitions) { if (DEBUG) { System.out.println("updatePartition["+i+"] = " + updatePartition[i] + ", updateAllPartitions = " + updateAllPartitions); } if (flip[i]) { // Flip all the buffers to be written to first... 
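                // Double buffering: each internal node's partials alternate between two
                // buffer banks, so restoreState() can revert a rejected proposal to the
                // previous buffers without copying partials. Flipping is suppressed for a
                // partition that is being re-evaluated after an underflow, so that the next
                // evaluation overwrites the bad buffers in place.
                //
                // Each entry written into 'operations' below is a tuple of
                // Beagle.PARTITION_OPERATION_TUPLE_SIZE ints:
                //   { destination partials, write scale buffer, read scale buffer,
                //     child 1 partials, child 1 matrix, child 2 partials, child 2 matrix,
                //     partition index, cumulative scale buffer (Beagle.NONE at this point) }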
for (NodeOperation op : nodeOperations) { partialBufferHelper[i].flipOffset(op.getNodeNumber()); } } } } int operationCount = 0; k = 0; for (NodeOperation op : nodeOperations) { int nodeNum = op.getNodeNumber(); int[] writeScale = new int[partitionCount]; int[] readScale = new int[partitionCount]; for (int i = 0; i < partitionCount; i++) { if (updatePartition[i] || updateAllPartitions) { if (useScaleFactors[i]) { // get the index of this scaling buffer int n = nodeNum - tipCount; if (recomputeScaleFactors[i]) { // flip the indicator: can take either n or (internalNodeCount + 1) - n scaleBufferHelper[i].flipOffset(n); // store the index scaleBufferIndices[i][n] = scaleBufferHelper[i].getOffsetIndex(n); writeScale[i] = scaleBufferIndices[i][n]; // Write new scaleFactor readScale[i] = Beagle.NONE; } else { writeScale[i] = Beagle.NONE; readScale[i] = scaleBufferIndices[i][n]; // Read existing scaleFactor } } else { writeScale[i] = Beagle.NONE; // Not using scaleFactors readScale[i] = Beagle.NONE; } } } //Example 1: 1 partition with 1 evolutionary model & -beagle_instances 3 //partition 0 -> model 0 //partition 1 -> model 0 //partition 2 -> model 0 //Example 2: 3 partitions with 3 evolutionary models & -beagle_instances 2 //partitions 0 & 1 -> model 0 //partitions 2 & 3 -> model 1 //partitions 4 & 5 -> model 2 int mapPartition = partitionCount / evolutionaryProcessDelegates.size(); for (int i = 0; i < partitionCount; i++) { if (updatePartition[i] || updateAllPartitions) { EvolutionaryProcessDelegate evolutionaryProcessDelegate = evolutionaryProcessDelegates.get(i / (mapPartition)); /*if (evolutionaryProcessDelegates.size() == partitionCount) { evolutionaryProcessDelegate = evolutionaryProcessDelegates.get(i); } else { evolutionaryProcessDelegate = evolutionaryProcessDelegates.get(0); }*/ operations[k] = partialBufferHelper[i].getOffsetIndex(nodeNum); operations[k + 1] = writeScale[i]; operations[k + 2] = readScale[i]; operations[k + 3] = partialBufferHelper[i].getOffsetIndex(op.getLeftChild()); // source node 1 operations[k + 4] = evolutionaryProcessDelegate.getMatrixIndex(op.getLeftChild()); // source matrix 1 operations[k + 5] = partialBufferHelper[i].getOffsetIndex(op.getRightChild()); // source node 2 operations[k + 6] = evolutionaryProcessDelegate.getMatrixIndex(op.getRightChild()); // source matrix 2 operations[k + 7] = i; //TODO: we don't know the cumulateScaleBufferIndex here yet (see below) operations[k + 8] = Beagle.NONE; if (DEBUG) { if (k == 0 || (k == Beagle.PARTITION_OPERATION_TUPLE_SIZE)) { System.out.println("write = " + writeScale[i] + "; read = " + readScale[i] + "; parent = " + operations[k] + ", k = " + k + ", i = " + i); } } k += Beagle.PARTITION_OPERATION_TUPLE_SIZE; operationCount++; } } } beagle.updatePartialsByPartition(operations, operationCount); //double[] rootPartials = new double[totalPatternCount * stateCount]; //beagle.getPartials(rootIndex, 0, rootPartials); int[] cumulativeScaleIndices = new int[partitionCount]; for (int i = 0; i < partitionCount; i++) { cumulativeScaleIndices[i] = Beagle.NONE; if (useScaleFactors[i]) { if (recomputeScaleFactors[i] && (updatePartition[i] || updateAllPartitions)) { scaleBufferHelper[i].flipOffset(internalNodeCount); cumulativeScaleIndices[i] = scaleBufferHelper[i].getOffsetIndex(internalNodeCount); //TODO: check with Daniel if calling these methods using an iteration can be done more efficiently beagle.resetScaleFactorsByPartition(cumulativeScaleIndices[i], i); beagle.accumulateScaleFactorsByPartition(scaleBufferIndices[i], 
internalNodeCount, cumulativeScaleIndices[i], i); } else { cumulativeScaleIndices[i] = scaleBufferHelper[i].getOffsetIndex(internalNodeCount); } } } // double[] scaleFactors = new double[totalPatternCount]; // beagle.getLogScaleFactors(cumulateScaleBufferIndex, scaleFactors); // these could be set only when they change but store/restore would need to be considered for (int i = 0; i < siteRateModels.size(); i++) { double[] categoryWeights = this.siteRateModels.get(i).getCategoryProportions(); beagle.setCategoryWeights(i, categoryWeights); // This should probably explicitly be the state frequencies for the root node... double[] frequencies = evolutionaryProcessDelegates.get(i).getRootStateFrequencies(); beagle.setStateFrequencies(i, frequencies); } if (DEBUG) { for (int i = 0; i < partitionCount; i++) { System.out.println("useScaleFactors=" + useScaleFactors[i] + " recomputeScaleFactors=" + recomputeScaleFactors[i] + " (partition: " + i + ")"); } } /*System.out.println("partitionCount = " + partitionCount); for (int i = 0; i < partitionCount; i++) { System.out.println("partitionIndices[" + i + "] = " + partitionIndices[i]); }*/ int[] partitionIndices = new int[partitionCount]; int[] rootIndices = new int[partitionCount]; int[] categoryWeightsIndices = new int[partitionCount]; int[] stateFrequenciesIndices = new int[partitionCount]; double[] sumLogLikelihoods = new double[1]; double[] sumLogLikelihoodsByPartition = new double[partitionCount]; // create a list of partitions that have been updated int updatedPartitionCount = 0; for (int i = 0; i < partitionCount; i++) { if (updatePartition[i] || updateAllPartitions) { partitionIndices[updatedPartitionCount] = i; rootIndices [updatedPartitionCount] = partialBufferHelper[i].getOffsetIndex(rootNodeNumber); categoryWeightsIndices [updatedPartitionCount] = i % siteRateModels.size(); stateFrequenciesIndices[updatedPartitionCount] = i % siteRateModels.size(); cumulativeScaleIndices [updatedPartitionCount] = cumulativeScaleIndices[i]; updatedPartitionCount++; } } //TODO: check these arguments with Daniel //TODO: partitionIndices needs to be set according to which partitions need updating? beagle.calculateRootLogLikelihoodsByPartition( rootIndices, categoryWeightsIndices, stateFrequenciesIndices, cumulativeScaleIndices, partitionIndices, updatedPartitionCount, 1, sumLogLikelihoodsByPartition, sumLogLikelihoods); /*System.out.println(); for (int i = 0; i < partitionCount; i++) { System.out.println("partition " + i + " lnL = " + sumLogLikelihoodsByPartition[i]); }*/ // write updated partition likelihoods to the cached array for (int i = 0; i < updatedPartitionCount; i++) { cachedLogLikelihoodsByPartition[partitionIndices[i]] = sumLogLikelihoodsByPartition[i]; // clear the global flags updatePartition[partitionIndices[i]] = false; recomputeScaleFactors[partitionIndices[i]] = false; partitionWasUpdated[partitionIndices[i]] = true; } double tmpLogL = sumLogLikelihoods[0]; if (DEBUG) { for (int i = 0; i < partitionCount; i++) { System.out.println("partition " + i + ": " + cachedLogLikelihoodsByPartition[i] + (partitionWasUpdated[i] ? " [updated]" : "")); } } // If these are needed... 
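        // (If enabled, getSiteLogLikelihoods would return one value per pattern across
        //  all partitions, concatenated in the same order as patternWeights and
        //  patternPartitions were set up in the constructor.)
        //
        // Note on caching: only the partitions flagged for update were passed to
        // calculateRootLogLikelihoodsByPartition above; the other entries of
        // cachedLogLikelihoodsByPartition keep their previous values, and the total
        // log likelihood returned at the end of this method is the sum over all partitions.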
// if (patternLogLikelihoods == null) { // patternLogLikelihoods = new double[totalPatternCount]; // } // beagle.getSiteLogLikelihoods(patternLogLikelihoods); updateSubstitutionModels(false); updateSiteRateModels(false); updateAllPartitions = true; if (Double.isNaN(tmpLogL) || Double.isInfinite(tmpLogL)) { // one or more of the updated partitions has underflowed if (DEBUG) { System.out.println("Double.isNaN(logL) || Double.isInfinite(logL)"); } for (int i = 0; i < updatedPartitionCount; i++) { if (Double.isNaN(sumLogLikelihoodsByPartition[i]) || Double.isInfinite(sumLogLikelihoodsByPartition[i])) { everUnderflowed[partitionIndices[i]] = true; } } if (firstRescaleAttempt) { if (delayRescalingUntilUnderflow || rescalingScheme == PartialsRescalingScheme.DELAYED) { // show a message but only every 1000 rescales if (rescalingMessageCount % 1000 == 0) { if (rescalingMessageCount > 0) { Logger.getLogger("dr.evomodel").info("Underflow calculating likelihood (" + rescalingMessageCount + " messages not shown)."); } else { Logger.getLogger("dr.evomodel").info("Underflow calculating likelihood. Attempting a rescaling... (" + getId() + ")"); } } rescalingMessageCount += 1; } for (int i = 0; i < updatedPartitionCount; i++) { if (delayRescalingUntilUnderflow || rescalingScheme == PartialsRescalingScheme.DELAYED) { if (Double.isNaN(sumLogLikelihoodsByPartition[i]) || Double.isInfinite(sumLogLikelihoodsByPartition[i])) { useScaleFactors[partitionIndices[i]] = true; recomputeScaleFactors[partitionIndices[i]] = true; updatePartition[partitionIndices[i]] = true; // turn off double buffer flipping so the next call overwrites the // underflowed buffers. Flip will be turned on again in storeState for // next step flip[partitionIndices[i]] = false; updateAllPartitions = false; if (DEBUG) { System.out.println("Double.isNaN(logL) || Double.isInfinite(logL) (partition index: " + partitionIndices[i] + ")"); } } } } firstRescaleAttempt = false; throw new LikelihoodUnderflowException(); } return Double.NEGATIVE_INFINITY; } else { for (int i = 0; i < updatedPartitionCount; i++) { if (partitionWasUpdated[partitionIndices[i]]) { if (!this.delayRescalingUntilUnderflow || everUnderflowed[partitionIndices[i]]) { if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { if (!initialEvaluation) { rescalingCount[partitionIndices[i]]++; } } } partitionWasUpdated[partitionIndices[i]] = false; } //TODO: probably better to only switch back those booleans that were actually altered recomputeScaleFactors[partitionIndices[i]] = false; flip[partitionIndices[i]] = true; } firstRescaleAttempt = true; initialEvaluation = false; } //******************************************************************** double logL = 0.0; for (double l : cachedLogLikelihoodsByPartition) { logL += l; } return logL; } /*public void getPartials(int number, double[] partials) { int cumulativeBufferIndex = Beagle.NONE; // No need to rescale partials beagle.getPartials(partialBufferHelper.getOffsetIndex(number), cumulativeBufferIndex, partials); }*/ /*private void setPartials(int number, double[] partials) { beagle.setPartials(partialBufferHelper.getOffsetIndex(number), partials); }*/ @Override public void makeDirty() { updateSiteRateModels(); updateSubstitutionModels(); } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { if (model instanceof SiteRateModel) { updateSiteRateModel((SiteRateModel)model); } else if (model instanceof BranchModel) { updateSubstitutionModel((BranchModel)model); } // Tell TreeDataLikelihood to 
update all nodes fireModelChanged(); } @Override protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { } /** * Stores the additional state other than model components */ @Override public void storeState() { for (int i = 0; i < partitionCount; i++) { partialBufferHelper[i].storeState(); categoryRateBufferHelper[i].storeState(); } for (EvolutionaryProcessDelegate evolutionaryProcessDelegate : evolutionaryProcessDelegates) { evolutionaryProcessDelegate.storeState(); } for (int i = 0; i < partitionCount; i++) { if (useScaleFactors[i] ) { // Only store when actually used scaleBufferHelper[i].storeState(); System.arraycopy(scaleBufferIndices[i], 0, storedScaleBufferIndices[i], 0, scaleBufferIndices[i].length); //storedRescalingCount = rescalingCount; } // turn on double buffering flipping (may have been turned off to enable a rescale) flip[i] = true; } System.arraycopy(cachedLogLikelihoodsByPartition, 0, storedCachedLogLikelihoodsByPartition, 0, cachedLogLikelihoodsByPartition.length); } /** * Restore the additional stored state */ @Override public void restoreState() { for (int i = 0; i < partitionCount; i++) { partialBufferHelper[i].restoreState(); categoryRateBufferHelper[i].restoreState(); } for (EvolutionaryProcessDelegate evolutionaryProcessDelegate : evolutionaryProcessDelegates) { evolutionaryProcessDelegate.restoreState(); } for (int i = 0; i < partitionCount; i++) { if (useScaleFactors[i]) { scaleBufferHelper[i].restoreState(); int[] tmp = storedScaleBufferIndices[i]; storedScaleBufferIndices[i] = scaleBufferIndices[i]; scaleBufferIndices[i] = tmp; //rescalingCount = storedRescalingCount; } } double[] tmp = cachedLogLikelihoodsByPartition; cachedLogLikelihoodsByPartition = storedCachedLogLikelihoodsByPartition; storedCachedLogLikelihoodsByPartition = tmp; } @Override public void setCallback(TreeDataLikelihood treeDataLikelihood) { // Callback not necessary } @Override protected void acceptState() { } // ************************************************************** // INSTANCE CITABLE // ************************************************************** @Override public Citation.Category getCategory() { return Citation.Category.FRAMEWORK; } @Override public String getDescription() { return "Using BEAGLE likelihood calculation library"; } @Override public List<Citation> getCitations() { return Collections.singletonList(CommonCitations.AYRES_2012_BEAGLE); } // ************************************************************** // INSTANCE VARIABLES // ************************************************************** private int nodeCount; private int tipCount; private int internalNodeCount; private int[] branchUpdateIndices; private double[] branchLengths; //provide per partition buffer indices private int[][] scaleBufferIndices; private int[][] storedScaleBufferIndices; private int[] operations; //allow flipping per partition private boolean[] flip; private BufferIndexHelper[] partialBufferHelper; private BufferIndexHelper[] scaleBufferHelper; private BufferIndexHelper[] categoryRateBufferHelper; private PartialsRescalingScheme rescalingScheme; private int rescalingFrequency = RESCALE_FREQUENCY; private boolean delayRescalingUntilUnderflow = true; //allow per partition rescaling private boolean[] useScaleFactors; private boolean[] recomputeScaleFactors; //keep track of underflow on a per partition basis private boolean[] everUnderflowed; private int[] rescalingCount; private int[] rescalingCountInner; private boolean firstRescaleAttempt; private int 
rescalingMessageCount = 0; // keep track of which partitions need to be updated private boolean updatePartition[]; private boolean updateAllPartitions; private boolean partitionWasUpdated[]; /** * the patternLists */ private List<PatternList> patternLists; private final DataType dataType; private final int partitionCount; /** * the pattern weights across all patterns */ private final double[] patternWeights; /** * The partition for each pattern */ private final int[] patternPartitions; /** * the number of patterns for each partition */ private final int[] patternCounts; /** * total number of patterns across all partitions */ private final int totalPatternCount; /** * the number of states in the data */ private final int stateCount; /** * the branch-site model for these sites */ private final List<BranchModel> branchModels = new ArrayList<BranchModel>(); /** * A delegate to handle substitution models on branches */ private final List<EvolutionaryProcessDelegate> evolutionaryProcessDelegates = new ArrayList<EvolutionaryProcessDelegate>(); /** * the site model for these sites */ private final List<SiteRateModel> siteRateModels = new ArrayList<SiteRateModel>(); /** * the number of rate categories */ private final int categoryCount; /** * the BEAGLE library instance */ private final Beagle beagle; /** * Cached log likelihood for each partition */ private double[] cachedLogLikelihoodsByPartition; private double[] storedCachedLogLikelihoodsByPartition; /** * Flag to specify that the substitution model has changed */ private final boolean[] updateSubstitutionModels; /** * Flag to specify that the site model has changed */ private final boolean[] updateSiteRateModels; /** * Flag to take into account the first likelihood evaluation when initiating the MCMC chain */ private boolean initialEvaluation = true; }
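// Illustrative construction of the delegate (a sketch, not taken from this file;
// the TreeDataLikelihood constructor arguments may differ between BEAST versions):
//
//   MultiPartitionDataLikelihoodDelegate delegate =
//           new MultiPartitionDataLikelihoodDelegate(
//                   treeModel,
//                   Arrays.asList(patterns1, patterns2),      // two partitions
//                   Collections.singletonList(branchModel),   // one shared branch model
//                   Arrays.asList(siteModel1, siteModel2),    // per-partition site models
//                   false,                                    // useAmbiguities
//                   PartialsRescalingScheme.DYNAMIC,
//                   true);                                    // delayRescalingUntilUnderflow
//
//   TreeDataLikelihood treeDataLikelihood =
//           new TreeDataLikelihood(delegate, treeModel, branchRateModel);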