/*
* BeagleTreeLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.treelikelihood;
import beagle.*;
import dr.evolution.datatype.HiddenDataType;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchmodel.EpochBranchModel;
import dr.evomodel.branchmodel.HomogeneousBranchModel;
import dr.evomodel.substmodel.MarkovModulatedSubstitutionModel;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;
import dr.evomodelxml.treelikelihood.BeagleTreeLikelihoodParser;
import dr.evomodel.siteratemodel.GammaSiteRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.nucleotide.HKY;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.app.beagle.tools.BeagleSequenceSimulator;
import dr.app.beagle.tools.Partition;
import dr.evolution.alignment.Alignment;
import dr.evolution.alignment.AscertainedSitePatterns;
import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.UncertainSiteList;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.ThreadAwareLikelihood;
import dr.math.MathUtils;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.*;
import java.util.logging.Logger;
/**
 * BeagleTreeLikelihood - implements a likelihood function for sequences on a tree.
*
* @author Andrew Rambaut
* @author Alexei Drummond
* @author Marc Suchard
* @version $Id$
*/
@SuppressWarnings("serial")
@Deprecated // Switching to TreeDataLikelihood
public class BeagleTreeLikelihood extends AbstractSinglePartitionTreeLikelihood implements ThreadAwareLikelihood, Citable {
    // This property is a comma-delimited list of resource numbers (0 == CPU) to
    // allocate each BEAGLE instance to. If the list is shorter than the number of
    // instances, it wraps around.
private static final String RESOURCE_ORDER_PROPERTY = "beagle.resource.order";
private static final String PREFERRED_FLAGS_PROPERTY = "beagle.preferred.flags";
private static final String REQUIRED_FLAGS_PROPERTY = "beagle.required.flags";
private static final String SCALING_PROPERTY = "beagle.scaling";
private static final String RESCALE_FREQUENCY_PROPERTY = "beagle.rescale";
private static final String DELAY_SCALING_PROPERTY = "beagle.delay.scaling";
private static final String EXTRA_BUFFER_COUNT_PROPERTY = "beagle.extra.buffer.count";
private static final String FORCE_VECTORIZATION = "beagle.force.vectorization";
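    // All of the above are read via System.getProperty(), so they can be supplied
    // on the JVM command line, e.g. -Dbeagle.resource.order=1,0 -Dbeagle.scaling=dynamic
    // (the scaling value must parse as a PartialsRescalingScheme name).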
// Which scheme to use if choice not specified (or 'default' is selected):
private static final PartialsRescalingScheme DEFAULT_RESCALING_SCHEME = PartialsRescalingScheme.DYNAMIC;
private static int instanceCount = 0;
private static List<Integer> resourceOrder = null;
private static List<Integer> preferredOrder = null;
private static List<Integer> requiredOrder = null;
private static List<String> scalingOrder = null;
private static List<Integer> extraBufferOrder = null;
// Default frequency for complete recomputation of scaling factors under the 'dynamic' scheme
private static final int RESCALE_FREQUENCY = 100;
private static final int RESCALE_TIMES = 1;
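    // Under the DYNAMIC scheme, calculateLogLikelihood() recomputes scale factors
    // for RESCALE_TIMES consecutive evaluations each time the evaluation count
    // exceeds rescalingFrequency; otherwise existing scale factors are reused.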
private static final boolean RESCALING_OFF = false; // a debugging switch
private static final boolean DEBUG = false;
public BeagleTreeLikelihood(PatternList patternList,
TreeModel treeModel,
BranchModel branchModel,
SiteRateModel siteRateModel,
BranchRateModel branchRateModel,
TipStatesModel tipStatesModel,
boolean useAmbiguities,
PartialsRescalingScheme rescalingScheme,
boolean delayRescalingUntilUnderflow) {
this(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, useAmbiguities, rescalingScheme, delayRescalingUntilUnderflow, null);
}
public BeagleTreeLikelihood(PatternList patternList,
TreeModel treeModel,
BranchModel branchModel,
SiteRateModel siteRateModel,
BranchRateModel branchRateModel,
TipStatesModel tipStatesModel,
boolean useAmbiguities,
PartialsRescalingScheme rescalingScheme,
boolean delayRescalingUntilUnderflow,
Map<Set<String>, Parameter> partialsRestrictions) {
super(BeagleTreeLikelihoodParser.TREE_LIKELIHOOD, patternList, treeModel);
try {
final Logger logger = Logger.getLogger("dr.evomodel");
logger.info("\nUsing BEAGLE TreeLikelihood");
this.siteRateModel = siteRateModel;
addModel(this.siteRateModel);
this.branchModel = branchModel;
addModel(this.branchModel);
if (branchRateModel != null) {
this.branchRateModel = branchRateModel;
logger.info(" Branch rate model used: " + branchRateModel.getModelName());
} else {
this.branchRateModel = new DefaultBranchRateModel();
}
addModel(this.branchRateModel);
if (patternList instanceof UncertainSiteList ||
patternList.getDataType() instanceof HiddenDataType) {
useAmbiguities = true;
}
this.tipStatesModel = tipStatesModel;
this.categoryCount = this.siteRateModel.getCategoryCount();
this.tipCount = treeModel.getExternalNodeCount();
internalNodeCount = nodeCount - tipCount;
int compactPartialsCount = tipCount;
if (useAmbiguities) {
// if we are using ambiguities then we don't use tip partials
compactPartialsCount = 0;
}
            // one partials buffer for each tip and two for each internal node (for store/restore)
partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount);
// one scaling buffer for each internal node plus an extra for the accumulation, then doubled for store/restore
scaleBufferHelper = new BufferIndexHelper(getScaleBufferCount(), 0);
// Attempt to get the resource order from the System Property
if (resourceOrder == null) {
resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY);
}
if (preferredOrder == null) {
preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY);
}
if (requiredOrder == null) {
requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY);
}
if (scalingOrder == null) {
scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY);
}
if (extraBufferOrder == null) {
extraBufferOrder = parseSystemPropertyIntegerArray(EXTRA_BUFFER_COUNT_PROPERTY);
}
int extraBufferCount = -1; // default
if (extraBufferOrder.size() > 0) {
extraBufferCount = extraBufferOrder.get(instanceCount % extraBufferOrder.size());
}
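            // A count of -1 (the default above) defers the choice of extra
            // transition-matrix buffers to the SubstitutionModelDelegate.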
substitutionModelDelegate = new SubstitutionModelDelegate(treeModel, branchModel, extraBufferCount);
// first set the rescaling scheme to use from the parser
this.rescalingScheme = rescalingScheme;
this.delayRescalingUntilUnderflow = delayRescalingUntilUnderflow;
int[] resourceList = null;
long preferenceFlags = 0;
long requirementFlags = 0;
if (scalingOrder.size() > 0) {
this.rescalingScheme = PartialsRescalingScheme.parseFromString(
scalingOrder.get(instanceCount % scalingOrder.size()));
}
if (resourceOrder.size() > 0) {
// added the zero on the end so that a CPU is selected if requested resource fails
resourceList = new int[]{resourceOrder.get(instanceCount % resourceOrder.size()), 0};
if (resourceList[0] > 0) {
preferenceFlags |= BeagleFlag.PROCESSOR_GPU.getMask(); // Add preference weight against CPU
}
}
if (preferredOrder.size() > 0) {
preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size());
}
if (requiredOrder.size() > 0) {
requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size());
}
// Define default behaviour here
if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) {
//if GPU: the default is dynamic scaling in BEAST
                if (resourceList != null && resourceList[0] > 0) {
this.rescalingScheme = DEFAULT_RESCALING_SCHEME;
} else { // if CPU: just run as fast as possible
// this.rescalingScheme = PartialsRescalingScheme.NONE;
// Dynamic should run as fast as none until first underflow
this.rescalingScheme = DEFAULT_RESCALING_SCHEME;
}
}
// to keep behaviour of the delayed scheme (always + delay)...
if (this.rescalingScheme == PartialsRescalingScheme.DELAYED) {
this.delayRescalingUntilUnderflow = true;
this.rescalingScheme = PartialsRescalingScheme.ALWAYS;
}
if (this.rescalingScheme == PartialsRescalingScheme.AUTO) {
preferenceFlags |= BeagleFlag.SCALING_AUTO.getMask();
useAutoScaling = true;
} else {
// preferenceFlags |= BeagleFlag.SCALING_MANUAL.getMask();
}
            String r = System.getProperty(RESCALE_FREQUENCY_PROPERTY);
            if (r != null) {
                try {
                    rescalingFrequency = Integer.parseInt(r);
                } catch (NumberFormatException nfe) {
                    rescalingFrequency = RESCALE_FREQUENCY;
                }
                if (rescalingFrequency < 1) {
                    rescalingFrequency = RESCALE_FREQUENCY;
                }
            }
String d = System.getProperty(DELAY_SCALING_PROPERTY);
if (d != null) {
this.delayRescalingUntilUnderflow = Boolean.parseBoolean(d);
}
// TODO Remove once issue #854 is fixed
for (int s = 0; s < substitutionModelDelegate.getSubstitutionModelCount(); ++s) {
if (substitutionModelDelegate.getSubstitutionModel(s) instanceof MarkovModulatedSubstitutionModel) {
this.rescalingScheme = PartialsRescalingScheme.ALWAYS;
this.delayRescalingUntilUnderflow = false;
}
}
            if (preferenceFlags == 0 && resourceList == null) { // no explicit preferences or resources, so choose based on dataset characteristics
if (stateCount == 4 && patternList.getPatternCount() < 10000) // TODO determine good cut-off
preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask();
}
boolean forceVectorization = false;
String vectorizationString = System.getProperty(FORCE_VECTORIZATION);
if (vectorizationString != null) {
forceVectorization = true;
}
if (BeagleFlag.VECTOR_SSE.isSet(preferenceFlags) && (stateCount != 4)
&& !forceVectorization
) {
// @todo SSE doesn't seem to work for larger state spaces so for now we override the
// SSE option.
preferenceFlags &= ~BeagleFlag.VECTOR_SSE.getMask();
preferenceFlags |= BeagleFlag.VECTOR_NONE.getMask();
if (stateCount > 4 && this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
this.rescalingScheme = PartialsRescalingScheme.DELAYED;
}
}
if (!BeagleFlag.PRECISION_SINGLE.isSet(preferenceFlags)) {
// if single precision not explicitly set then prefer double
preferenceFlags |= BeagleFlag.PRECISION_DOUBLE.getMask();
}
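            // Models that can yield complex eigenvalues (e.g. non-reversible
            // substitution models) need BEAGLE's complex eigen-decomposition support.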
if (substitutionModelDelegate.canReturnComplexDiagonalization()) {
requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();
}
// Check for matching state counts
int stateCount2 = branchModel.getRootFrequencyModel().getFrequencyCount();
if (stateCount != stateCount2) {
throw new RuntimeException("Pattern state count (" + stateCount
+ ") does not match substitution model state count (" + stateCount2 + ")");
}
instanceCount++;
beagle = BeagleFactory.loadBeagleInstance(
tipCount,
partialBufferHelper.getBufferCount(),
compactPartialsCount,
stateCount,
patternCount,
substitutionModelDelegate.getEigenBufferCount(),
substitutionModelDelegate.getMatrixBufferCount(),
categoryCount,
scaleBufferHelper.getBufferCount(), // Always allocate; they may become necessary
resourceList,
preferenceFlags,
requirementFlags
);
InstanceDetails instanceDetails = beagle.getDetails();
ResourceDetails resourceDetails = null;
if (instanceDetails != null) {
resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
if (resourceDetails != null) {
StringBuilder sb = new StringBuilder(" Using BEAGLE resource ");
sb.append(resourceDetails.getNumber()).append(": ");
sb.append(resourceDetails.getName()).append("\n");
if (resourceDetails.getDescription() != null) {
String[] description = resourceDetails.getDescription().split("\\|");
for (String desc : description) {
if (desc.trim().length() > 0) {
sb.append(" ").append(desc.trim()).append("\n");
}
}
}
sb.append(" with instance flags: ").append(instanceDetails.toString());
logger.info(sb.toString());
} else {
logger.info(" Error retrieving BEAGLE resource for instance: " + instanceDetails.toString());
}
} else {
logger.info(" No external BEAGLE resources available, or resource list/requirements not met, using Java implementation");
}
logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
logger.info(" With " + patternList.getPatternCount() + " unique site patterns.");
if (tipStatesModel != null) {
tipStatesModel.setTree(treeModel);
if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
tipPartials = new double[patternCount * stateCount];
} else {
tipStates = new int[patternCount];
}
addModel(tipStatesModel);
}
for (int i = 0; i < tipCount; i++) {
// Find the id of tip i in the patternList
String id = treeModel.getTaxonId(i);
int index = patternList.getTaxonIndex(id);
if (index == -1) {
throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() +
", is not found in patternList, " + patternList.getId());
} else {
if (tipStatesModel != null) {
// using a tipPartials model.
// First set the observed states:
tipStatesModel.setStates(patternList, index, i, id);
if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
// Then set the tip partials as determined by the model:
setPartials(beagle, tipStatesModel, i);
} else {
// or the tip states:
tipStatesModel.getTipStates(i, tipStates);
beagle.setTipStates(i, tipStates);
}
} else {
if (useAmbiguities) {
setPartials(beagle, patternList, index, i);
} else {
setStates(beagle, patternList, index, i);
}
}
}
}
if (patternList instanceof AscertainedSitePatterns) {
ascertainedSitePatterns = true;
}
this.partialsRestrictions = partialsRestrictions;
// hasRestrictedPartials = (partialsRestrictions != null);
if (hasRestrictedPartials) {
numRestrictedPartials = partialsRestrictions.size();
updateRestrictedNodePartials = true;
partialsMap = new Parameter[treeModel.getNodeCount()];
partials = new double[stateCount * patternCount * categoryCount];
} else {
numRestrictedPartials = 0;
updateRestrictedNodePartials = false;
}
beagle.setPatternWeights(patternWeights);
String rescaleMessage = " Using rescaling scheme : " + this.rescalingScheme.getText();
if (this.rescalingScheme == PartialsRescalingScheme.AUTO &&
resourceDetails != null &&
(resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) {
// If auto scaling in BEAGLE is not supported then do it here
this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
rescaleMessage = " Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText();
}
boolean parenthesis = false;
if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
rescaleMessage += " (rescaling every " + rescalingFrequency + " evaluations";
parenthesis = true;
}
if (this.delayRescalingUntilUnderflow) {
                rescaleMessage += (parenthesis ? ", " : " (") + "delay rescaling until first underflow";
parenthesis = true;
}
rescaleMessage += (parenthesis ? ")" : "");
logger.info(rescaleMessage);
if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
everUnderflowed = false; // If false, BEAST does not rescale until first under-/over-flow.
}
updateSubstitutionModel = true;
updateSiteModel = true;
} catch (TaxonList.MissingTaxonException mte) {
            throw new RuntimeException(mte);
}
this.useAmbiguities = useAmbiguities;
hasInitialized = true;
}
private static List<Integer> parseSystemPropertyIntegerArray(String propertyName) {
List<Integer> order = new ArrayList<Integer>();
String r = System.getProperty(propertyName);
if (r != null) {
String[] parts = r.split(",");
for (String part : parts) {
try {
int n = Integer.parseInt(part.trim());
order.add(n);
} catch (NumberFormatException nfe) {
System.err.println("Invalid entry '" + part + "' in " + propertyName);
}
}
}
return order;
}
    private static List<String> parseSystemPropertyStringArray(String propertyName) {
        List<String> order = new ArrayList<String>();
        String r = System.getProperty(propertyName);
        if (r != null) {
            String[] parts = r.split(",");
            for (String part : parts) {
                order.add(part.trim());
            }
        }
        return order;
    }
public TipStatesModel getTipStatesModel() {
return tipStatesModel;
}
public PatternList getPatternsList() {
return patternList;
}
public TreeModel getTreeModel() {
return treeModel;
}
public BranchModel getBranchModel() {
return branchModel;
}
public SiteRateModel getSiteRateModel() {
return siteRateModel;
}
public BranchRateModel getBranchRateModel() {
return branchRateModel;
}
public PartialsRescalingScheme getRescalingScheme() {
return rescalingScheme;
}
public boolean isDelayRescalingUntilUnderflow() {
return delayRescalingUntilUnderflow;
}
public Map<Set<String>, Parameter> getPartialsRestrictions() {
return partialsRestrictions;
}
public boolean useAmbiguities() {
return useAmbiguities;
}
protected int getScaleBufferCount() {
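        // one scale buffer per internal node plus one for the accumulated scale factors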
return internalNodeCount + 1;
}
/**
* Sets the partials from a sequence in an alignment.
*
* @param beagle beagle
* @param patternList patternList
* @param sequenceIndex sequenceIndex
* @param nodeIndex nodeIndex
*/
protected final void setPartials(Beagle beagle,
PatternList patternList,
int sequenceIndex,
int nodeIndex) {
double[] partials = new double[patternCount * stateCount * categoryCount];
boolean[] stateSet;
int v = 0;
for (int i = 0; i < patternCount; i++) {
if (patternList instanceof UncertainSiteList) {
((UncertainSiteList) patternList).fillPartials(sequenceIndex, i, partials, v);
v += stateCount;
// TODO Add this functionality to SimpleSiteList to avoid if statement here
} else {
int state = patternList.getPatternState(sequenceIndex, i);
stateSet = dataType.getStateSet(state);
for (int j = 0; j < stateCount; j++) {
if (stateSet[j]) {
partials[v] = 1.0;
} else {
partials[v] = 0.0;
}
v++;
}
}
}
// if there is more than one category then replicate the partials for each
int n = patternCount * stateCount;
int k = n;
for (int i = 1; i < categoryCount; i++) {
System.arraycopy(partials, 0, partials, k, n);
k += n;
}
beagle.setPartials(nodeIndex, partials);
}
/**
     * Sets the tip partials as supplied by the given tip states model.
*/
protected final void setPartials(Beagle beagle,
TipStatesModel tipStatesModel,
int nodeIndex) {
double[] partials = new double[patternCount * stateCount * categoryCount];
tipStatesModel.getTipPartials(nodeIndex, partials);
// if there is more than one category then replicate the partials for each
int n = patternCount * stateCount;
int k = n;
for (int i = 1; i < categoryCount; i++) {
System.arraycopy(partials, 0, partials, k, n);
k += n;
}
beagle.setPartials(nodeIndex, partials);
}
public int getPatternCount() {
return patternCount;
}
/**
     * Sets the tip states from a sequence in an alignment.
*
* @param beagle beagle
* @param patternList patternList
* @param sequenceIndex sequenceIndex
* @param nodeIndex nodeIndex
*/
protected final void setStates(Beagle beagle,
PatternList patternList,
int sequenceIndex,
int nodeIndex) {
        int[] states = new int[patternCount];
        for (int i = 0; i < patternCount; i++) {
            states[i] = patternList.getPatternState(sequenceIndex, i);
        }
beagle.setTipStates(nodeIndex, states);
}
// public void setStates(int tipIndex, int[] states) {
// System.err.println("BTL:setStates");
// beagle.setTipStates(tipIndex, states);
// makeDirty();
// }
//
// public void getStates(int tipIndex, int[] states) {
// System.err.println("BTL:getStates");
// beagle.getTipStates(tipIndex, states);
// }
public final void setPatternWeights(double[] patternWeights) {
this.patternWeights = patternWeights;
beagle.setPatternWeights(patternWeights);
}
// **************************************************************
// ModelListener IMPLEMENTATION
// **************************************************************
/**
* Handles model changed events from the submodels.
*/
protected void handleModelChangedEvent(Model model, Object object, int index) {
fireModelChanged();
if (model == treeModel) {
if (object instanceof TreeModel.TreeChangedEvent) {
if (((TreeModel.TreeChangedEvent) object).isNodeChanged()) {
                    // If a node event occurs, the node and its two child nodes
                    // are flagged for updating (this will result in everything
                    // above being updated as well). Node events occur when a node
                    // is added to a branch, removed from a branch, or its height or
                    // rate changes.
updateNodeAndChildren(((TreeModel.TreeChangedEvent) object).getNode());
updateRestrictedNodePartials = true;
} else if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) {
// Full tree events result in a complete updating of the tree likelihood
// This event type is now used for EmpiricalTreeDistributions.
// System.err.println("Full tree update event - these events currently aren't used\n" +
// "so either this is in error or a new feature is using them so remove this message.");
updateAllNodes();
updateRestrictedNodePartials = true;
} else {
// Other event types are ignored (probably trait changes).
//System.err.println("Another tree event has occured (possibly a trait change).");
}
}
} else if (model == branchRateModel) {
if (index == -1) {
updateAllNodes();
} else {
updateNode(treeModel.getNode(index));
}
} else if (model == branchModel) {
if (index == -1) {
updateSubstitutionModel = true;
updateAllNodes();
} else {
updateNode(treeModel.getNode(index));
}
} else if (model == siteRateModel) {
updateSiteModel = true;
updateAllNodes();
} else if (model == tipStatesModel) {
if (object instanceof Taxon) {
for (int i = 0; i < treeModel.getNodeCount(); i++)
if (treeModel.getNodeTaxon(treeModel.getNode(i)) != null && treeModel.getNodeTaxon(treeModel.getNode(i)).getId().equalsIgnoreCase(((Taxon) object).getId()))
updateNode(treeModel.getNode(i));
} else if (object instanceof Parameter) {
// ignore...
} else {
updateAllNodes();
}
} else {
throw new RuntimeException("Unknown componentChangedEvent");
}
super.handleModelChangedEvent(model, object, index);
}
@Override
public void makeDirty() {
super.makeDirty();
updateSiteModel = true;
updateSubstitutionModel = true;
updateRestrictedNodePartials = true;
}
// **************************************************************
// Model IMPLEMENTATION
// **************************************************************
/**
* Stores the additional state other than model components
*/
protected void storeState() {
partialBufferHelper.storeState();
substitutionModelDelegate.storeState();
if (useScaleFactors || useAutoScaling) { // Only store when actually used
scaleBufferHelper.storeState();
System.arraycopy(scaleBufferIndices, 0, storedScaleBufferIndices, 0, scaleBufferIndices.length);
// storedRescalingCount = rescalingCount;
}
super.storeState();
}
/**
* Restore the additional stored state
*/
protected void restoreState() {
updateSiteModel = true; // this is required to upload the categoryRates to BEAGLE after the restore
partialBufferHelper.restoreState();
substitutionModelDelegate.restoreState();
if (useScaleFactors || useAutoScaling) {
scaleBufferHelper.restoreState();
int[] tmp = storedScaleBufferIndices;
storedScaleBufferIndices = scaleBufferIndices;
scaleBufferIndices = tmp;
// rescalingCount = storedRescalingCount;
}
updateRestrictedNodePartials = true;
super.restoreState();
}
// **************************************************************
// Likelihood IMPLEMENTATION
// **************************************************************
/**
* Calculate the log likelihood of the current state.
*
* @return the log likelihood.
*/
protected double calculateLogLikelihood() {
if (patternLogLikelihoods == null) {
patternLogLikelihoods = new double[patternCount];
}
if (branchUpdateIndices == null) {
branchUpdateIndices = new int[nodeCount];
branchLengths = new double[nodeCount];
scaleBufferIndices = new int[internalNodeCount];
storedScaleBufferIndices = new int[internalNodeCount];
}
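        // Each entry in an operations list is a tuple of Beagle.OPERATION_TUPLE_SIZE
        // ints: destination partials index, scale buffer to write, scale buffer to
        // read, then the partials and matrix indices for each of the two children
        // (filled in by traverse()).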
if (operations == null) {
operations = new int[numRestrictedPartials + 1][internalNodeCount * Beagle.OPERATION_TUPLE_SIZE];
operationCount = new int[numRestrictedPartials + 1];
}
recomputeScaleFactors = false;
if (!this.delayRescalingUntilUnderflow || everUnderflowed) {
if (this.rescalingScheme == PartialsRescalingScheme.ALWAYS || this.rescalingScheme == PartialsRescalingScheme.DELAYED) {
useScaleFactors = true;
recomputeScaleFactors = true;
} else if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
useScaleFactors = true;
if (DEBUG) {
System.out.println("rescalingCount = " + rescalingCount);
System.out.println("rescalingCountInner = " + rescalingCountInner);
}
if (rescalingCount > rescalingFrequency) {
if (DEBUG) {
System.out.println("rescalingCount > rescalingFrequency");
}
rescalingCount = 0;
rescalingCountInner = 0;
}
if (rescalingCountInner < RESCALE_TIMES) {
recomputeScaleFactors = true;
updateAllNodes();
rescalingCountInner++;
}
rescalingCount++;
}
}
if (RESCALING_OFF) { // a debugging switch
useScaleFactors = false;
recomputeScaleFactors = false;
}
if (tipStatesModel != null) {
int tipCount = treeModel.getExternalNodeCount();
for (int index = 0; index < tipCount; index++) {
if (updateNode[index]) {
if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
tipStatesModel.getTipPartials(index, tipPartials);
beagle.setTipPartials(index, tipPartials);
} else {
tipStatesModel.getTipStates(index, tipStates);
beagle.setTipStates(index, tipStates);
}
}
}
}
branchUpdateCount = 0;
operationListCount = 0;
if (hasRestrictedPartials) {
for (int i = 0; i <= numRestrictedPartials; i++) {
operationCount[i] = 0;
}
} else {
operationCount[0] = 0;
}
final NodeRef root = treeModel.getRoot();
traverse(treeModel, root, null, true);
if (DEBUG) {
System.out.println("operationCount = " + operationCount[operationListCount]);
}
if (updateSubstitutionModel) { // TODO More efficient to update only the substitution model that changed, instead of all
substitutionModelDelegate.updateSubstitutionModels(beagle);
// we are currently assuming a no-category model...
}
if (updateSiteModel) {
double[] categoryRates = this.siteRateModel.getCategoryRates();
if (categoryRates == null) {
// If this returns null then there was a numerical error calculating the category rates
// (probably a very small alpha) so reject the move.
return Double.NEGATIVE_INFINITY;
}
beagle.setCategoryRates(categoryRates);
}
if (branchUpdateCount > 0) {
substitutionModelDelegate.updateTransitionMatrices(
beagle,
branchUpdateIndices,
branchLengths,
branchUpdateCount);
}
if (COUNT_TOTAL_OPERATIONS) {
totalMatrixUpdateCount += branchUpdateCount;
for (int i = 0; i <= numRestrictedPartials; i++) {
totalOperationCount += operationCount[i];
}
}
double logL;
boolean done;
boolean firstRescaleAttempt = true;
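        // Compute the likelihood; if it under/overflows, the loop below retries
        // once with rescaling turned on before giving up.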
do {
if (hasRestrictedPartials) {
for (int i = 0; i <= numRestrictedPartials; i++) {
beagle.updatePartials(operations[i], operationCount[i], Beagle.NONE);
if (i < numRestrictedPartials) {
// restrictNodePartials(restrictedIndices[i]);
}
}
} else {
beagle.updatePartials(operations[0], operationCount[0], Beagle.NONE);
}
int rootIndex = partialBufferHelper.getOffsetIndex(root.getNumber());
double[] categoryWeights = this.siteRateModel.getCategoryProportions();
// This should probably explicitly be the state frequencies for the root node...
double[] frequencies = substitutionModelDelegate.getRootStateFrequencies();
int cumulateScaleBufferIndex = Beagle.NONE;
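            // When recomputing, flip to a fresh cumulative scale buffer, reset it,
            // and accumulate all per-node scale factors into it; otherwise reuse
            // the cumulative buffer from a previous evaluation.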
if (useScaleFactors) {
if (recomputeScaleFactors) {
scaleBufferHelper.flipOffset(internalNodeCount);
cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount);
beagle.resetScaleFactors(cumulateScaleBufferIndex);
beagle.accumulateScaleFactors(scaleBufferIndices, internalNodeCount, cumulateScaleBufferIndex);
} else {
cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount);
}
} else if (useAutoScaling) {
beagle.accumulateScaleFactors(scaleBufferIndices, internalNodeCount, Beagle.NONE);
}
// these could be set only when they change but store/restore would need to be considered
beagle.setCategoryWeights(0, categoryWeights);
beagle.setStateFrequencies(0, frequencies);
double[] sumLogLikelihoods = new double[1];
if (DEBUG) {
System.out.println("useScaleFactors=" + useScaleFactors + " recomputeScaleFactors=" + recomputeScaleFactors);
}
beagle.calculateRootLogLikelihoods(new int[]{rootIndex}, new int[]{0}, new int[]{0},
new int[]{cumulateScaleBufferIndex}, 1, sumLogLikelihoods);
logL = sumLogLikelihoods[0];
if (DEBUG) {
System.out.println(logL);
// if (logL > -90000) {
// System.exit(0);
// }
}
            beagle.getSiteLogLikelihoods(patternLogLikelihoods);
            if (ascertainedSitePatterns) {
                // Need to correct for the ascertainment scheme
                logL = getAscertainmentCorrectedLogLikelihood((AscertainedSitePatterns) patternList,
                        patternLogLikelihoods, patternWeights);
            }
if (Double.isNaN(logL) || Double.isInfinite(logL)) {
if (DEBUG) {
System.out.println("Double.isNaN(logL) || Double.isInfinite(logL)");
}
everUnderflowed = true;
logL = Double.NEGATIVE_INFINITY;
if (firstRescaleAttempt && (delayRescalingUntilUnderflow || rescalingScheme == PartialsRescalingScheme.DELAYED)) {
// we have had a potential under/over flow so attempt a rescaling
if (rescalingScheme == PartialsRescalingScheme.DYNAMIC || (rescalingCount == 0)) {
// show a message but only every 1000 rescales
if (rescalingMessageCount % 1000 == 0) {
if (rescalingMessageCount > 0) {
Logger.getLogger("dr.evomodel").info("Underflow calculating likelihood (" + rescalingMessageCount + " messages not shown).");
} else {
Logger.getLogger("dr.evomodel").info("Underflow calculating likelihood. Attempting a rescaling...");
}
}
rescalingMessageCount += 1;
}
useScaleFactors = true;
recomputeScaleFactors = true;
branchUpdateCount = 0;
updateAllNodes();
if (hasRestrictedPartials) {
for (int i = 0; i <= numRestrictedPartials; i++) {
operationCount[i] = 0;
}
} else {
operationCount[0] = 0;
}
// traverse again but without flipping partials indices as we
// just want to overwrite the last attempt. We will flip the
// scale buffer indices though as we are recomputing them.
traverse(treeModel, root, null, false);
done = false; // Run through do-while loop again
firstRescaleAttempt = false; // Only try to rescale once
} else {
// we have already tried a rescale, not rescaling or always rescaling
// so just return the likelihood...
done = true;
}
} else {
done = true; // No under-/over-flow, then done
}
} while (!done);
// If these are needed...
//beagle.getSiteLogLikelihoods(patternLogLikelihoods);
//********************************************************************
        // after traversal, all nodes and patterns have been updated --
        // so change flags to reflect this.
for (int i = 0; i < nodeCount; i++) {
updateNode[i] = false;
}
updateSubstitutionModel = false;
updateSiteModel = false;
//********************************************************************
return logL;
}
public void getPartials(int number, double[] partials) {
int cumulativeBufferIndex = Beagle.NONE;
/* No need to rescale partials */
beagle.getPartials(partialBufferHelper.getOffsetIndex(number), cumulativeBufferIndex, partials);
}
public boolean arePartialsRescaled() {
return useScaleFactors;
}
protected void setPartials(int number, double[] partials) {
beagle.setPartials(partialBufferHelper.getOffsetIndex(number), partials);
}
private void restrictNodePartials(int nodeIndex) {
Parameter restrictionParameter = partialsMap[nodeIndex];
if (restrictionParameter == null) {
return;
}
getPartials(nodeIndex, partials);
double[] restriction = restrictionParameter.getParameterValues();
final int partialsLengthPerCategory = stateCount * patternCount;
if (restriction.length == partialsLengthPerCategory) {
for (int i = 0; i < categoryCount; i++) {
componentwiseMultiply(partials, partialsLengthPerCategory * i, restriction, 0, partialsLengthPerCategory);
}
} else {
componentwiseMultiply(partials, 0, restriction, 0, partialsLengthPerCategory * categoryCount);
}
setPartials(nodeIndex, partials);
}
private void componentwiseMultiply(double[] a, final int offsetA, double[] b, final int offsetB, final int length) {
for (int i = 0; i < length; i++) {
a[offsetA + i] *= b[offsetB + i];
}
}
private void computeNodeToRestrictionMap() {
Arrays.fill(partialsMap, null);
for (Set<String> taxonNames : partialsRestrictions.keySet()) {
NodeRef node = TreeUtils.getCommonAncestorNode(treeModel, taxonNames);
partialsMap[node.getNumber()] = partialsRestrictions.get(taxonNames);
}
}
private double getAscertainmentCorrectedLogLikelihood(AscertainedSitePatterns patternList,
double[] patternLogLikelihoods,
double[] patternWeights) {
double logL = 0.0;
double ascertainmentCorrection = patternList.getAscertainmentCorrection(patternLogLikelihoods);
for (int i = 0; i < patternCount; i++) {
logL += (patternLogLikelihoods[i] - ascertainmentCorrection) * patternWeights[i];
}
return logL;
}
/**
* Traverse the tree calculating partial likelihoods.
*
* @param tree tree
* @param node node
* @param operatorNumber operatorNumber
* @param flip flip
* @return boolean
*/
private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber, boolean flip) {
boolean update = false;
int nodeNum = node.getNumber();
NodeRef parent = tree.getParent(node);
if (operatorNumber != null) {
operatorNumber[0] = -1;
}
// First update the transition probability matrix(ices) for this branch
if (parent != null && updateNode[nodeNum]) {
final double branchRate;
synchronized (branchRateModel) {
branchRate = branchRateModel.getBranchRate(tree, node);
}
final double parentHeight = tree.getNodeHeight(parent);
final double nodeHeight = tree.getNodeHeight(node);
// Get the operational time of the branch
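            // (clock rate * time in height units = branch length in expected substitutions per site)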
final double branchLength = branchRate * (parentHeight - nodeHeight);
if (branchLength < 0.0) {
throw new RuntimeException("Negative branch length: " + branchLength + " (parent: " + parent +
"; height: " + parentHeight + " - child: " + node + "height: " + nodeHeight + ")");
}
if (flip) {
substitutionModelDelegate.flipMatrixBuffer(nodeNum);
}
branchUpdateIndices[branchUpdateCount] = nodeNum;
branchLengths[branchUpdateCount] = branchLength;
branchUpdateCount++;
update = true;
}
// If the node is internal, update the partial likelihoods.
if (!tree.isExternal(node)) {
// Traverse down the two child nodes
NodeRef child1 = tree.getChild(node, 0);
final int[] op1 = {-1};
final boolean update1 = traverse(tree, child1, op1, flip);
NodeRef child2 = tree.getChild(node, 1);
final int[] op2 = {-1};
final boolean update2 = traverse(tree, child2, op2, flip);
// If either child node was updated then update this node too
if (update1 || update2) {
int x = operationCount[operationListCount] * Beagle.OPERATION_TUPLE_SIZE;
if (flip) {
// first flip the partialBufferHelper
partialBufferHelper.flipOffset(nodeNum);
}
final int[] operations = this.operations[operationListCount];
operations[x] = partialBufferHelper.getOffsetIndex(nodeNum);
if (useScaleFactors) {
// get the index of this scaling buffer
int n = nodeNum - tipCount;
if (recomputeScaleFactors) {
// flip the indicator: can take either n or (internalNodeCount + 1) - n
scaleBufferHelper.flipOffset(n);
// store the index
scaleBufferIndices[n] = scaleBufferHelper.getOffsetIndex(n);
operations[x + 1] = scaleBufferIndices[n]; // Write new scaleFactor
operations[x + 2] = Beagle.NONE;
} else {
operations[x + 1] = Beagle.NONE;
operations[x + 2] = scaleBufferIndices[n]; // Read existing scaleFactor
}
} else {
if (useAutoScaling) {
scaleBufferIndices[nodeNum - tipCount] = partialBufferHelper.getOffsetIndex(nodeNum);
}
operations[x + 1] = Beagle.NONE; // Not using scaleFactors
operations[x + 2] = Beagle.NONE;
}
operations[x + 3] = partialBufferHelper.getOffsetIndex(child1.getNumber()); // source node 1
operations[x + 4] = substitutionModelDelegate.getMatrixIndex(child1.getNumber()); // source matrix 1
operations[x + 5] = partialBufferHelper.getOffsetIndex(child2.getNumber()); // source node 2
operations[x + 6] = substitutionModelDelegate.getMatrixIndex(child2.getNumber()); // source matrix 2
operationCount[operationListCount]++;
update = true;
if (hasRestrictedPartials) {
// Test if this set of partials should be restricted
if (updateRestrictedNodePartials) {
// Recompute map
computeNodeToRestrictionMap();
updateRestrictedNodePartials = false;
}
if (partialsMap[nodeNum] != null) {
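                    // The restriction (see restrictNodePartials) would be applied here;
                    // currently a no-op because hasRestrictedPartials is hard-coded false.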
}
}
}
}
return update;
}
// **************************************************************
// INSTANCE VARIABLES
// **************************************************************
private int[] branchUpdateIndices;
private double[] branchLengths;
private int branchUpdateCount;
private int[] scaleBufferIndices;
private int[] storedScaleBufferIndices;
private int[][] operations;
private int operationListCount;
private int[] operationCount;
// private final boolean hasRestrictedPartials;
private static final boolean hasRestrictedPartials = false;
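    // Restricted ("clamped") internal-node partials: partialsRestrictions maps a
    // set of taxon names to a restriction Parameter applied at the taxa's common
    // ancestor (see computeNodeToRestrictionMap); currently disabled.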
private final int numRestrictedPartials;
private final Map<Set<String>, Parameter> partialsRestrictions;
private Parameter[] partialsMap;
private double[] partials;
private boolean updateRestrictedNodePartials;
// private int[] restrictedIndices;
protected BufferIndexHelper partialBufferHelper;
protected BufferIndexHelper scaleBufferHelper;
protected final int tipCount;
protected final int internalNodeCount;
private PartialsRescalingScheme rescalingScheme;
private int rescalingFrequency = RESCALE_FREQUENCY;
private boolean delayRescalingUntilUnderflow = true;
protected boolean useScaleFactors = false;
private boolean useAutoScaling = false;
private boolean recomputeScaleFactors = false;
private boolean everUnderflowed = false;
private int rescalingCount = 0;
private int rescalingCountInner = 0;
// private int storedRescalingCount;
private int rescalingMessageCount = 0;
/**
* the branch-site model for these sites
*/
protected final BranchModel branchModel;
/**
* A delegate to handle substitution models on branches
*/
protected final SubstitutionModelDelegate substitutionModelDelegate;
/**
* the site model for these sites
*/
protected final SiteRateModel siteRateModel;
/**
* the branch rate model
*/
protected final BranchRateModel branchRateModel;
/**
* the tip partials model
*/
private final TipStatesModel tipStatesModel;
/**
* the pattern likelihoods
*/
protected double[] patternLogLikelihoods = null;
/**
* the number of rate categories
*/
protected int categoryCount;
/**
* an array used to transfer tip partials
*/
protected double[] tipPartials;
/**
* an array used to transfer tip states
*/
protected int[] tipStates;
/**
* the BEAGLE library instance
*/
protected Beagle beagle;
/**
* Flag to specify that the substitution model has changed
*/
protected boolean updateSubstitutionModel;
/**
* Flag to specify that the site model has changed
*/
protected boolean updateSiteModel;
// /***
// * Flag to specify if LikelihoodCore supports dynamic rescaling
// */
// private boolean dynamicRescaling = false;
/**
     * Flag to specify if site patterns are ascertained
*/
private boolean ascertainedSitePatterns = false;
/**
* Flag to specify if ambiguity codes are in use
*/
protected final boolean useAmbiguities;
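    // Simple self-test: simulates an alignment on a fixed tree and reports the
    // log likelihood under a homogeneous and an epoch branch model.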
public static void main(String[] args) {
try {
MathUtils.setSeed(666);
System.out.println("Test case 1: simulateOnePartition");
int sequenceLength = 1000;
ArrayList<Partition> partitionsList = new ArrayList<Partition>();
// create tree
NewickImporter importer = new NewickImporter(
"(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);");
Tree tree = importer.importTree(null);
TreeModel treeModel = new TreeModel(tree);
// create Frequency Model
Parameter freqs = new Parameter.Default(new double[]{0.25, 0.25,
0.25, 0.25});
FrequencyModel freqModel = new FrequencyModel(Nucleotides.INSTANCE,
freqs);
// create branch model
Parameter kappa1 = new Parameter.Default(1, 1);
Parameter kappa2 = new Parameter.Default(1, 1);
HKY hky1 = new HKY(kappa1, freqModel);
HKY hky2 = new HKY(kappa2, freqModel);
HomogeneousBranchModel homogenousBranchSubstitutionModel = new HomogeneousBranchModel(
hky1);
List<SubstitutionModel> substitutionModels = new ArrayList<SubstitutionModel>();
substitutionModels.add(hky1);
substitutionModels.add(hky2);
List<FrequencyModel> freqModels = new ArrayList<FrequencyModel>();
freqModels.add(freqModel);
Parameter epochTimes = new Parameter.Default(1, 20);
// create branch rate model
Parameter rate = new Parameter.Default(1, 0.001);
BranchRateModel branchRateModel = new StrictClockBranchRates(rate);
// create site model
GammaSiteRateModel siteRateModel = new GammaSiteRateModel(
"siteModel");
BranchModel homogeneousBranchModel = new HomogeneousBranchModel(hky1);
BranchModel epochBranchModel = new EpochBranchModel(treeModel, substitutionModels, epochTimes);
// create partition
Partition partition1 = new Partition(treeModel, //
homogenousBranchSubstitutionModel,//
siteRateModel, //
branchRateModel, //
freqModel, //
0, // from
sequenceLength - 1, // to
1 // every
);
partitionsList.add(partition1);
// feed to sequence simulator and generate data
BeagleSequenceSimulator simulator = new BeagleSequenceSimulator(partitionsList
// , sequenceLength
);
Alignment alignment = simulator.simulate(false, false);
BeagleTreeLikelihood nbtl = new BeagleTreeLikelihood(alignment, treeModel, homogeneousBranchModel, siteRateModel, branchRateModel, null, false, PartialsRescalingScheme.DEFAULT, false);
System.out.println("nBTL(homogeneous) = " + nbtl.getLogLikelihood());
nbtl = new BeagleTreeLikelihood(alignment, treeModel, epochBranchModel, siteRateModel, branchRateModel, null, false, PartialsRescalingScheme.DEFAULT, false);
System.out.println("nBTL(epoch) = " + nbtl.getLogLikelihood());
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
} // END: try-catch block
}
public Double getUpdateTimer() {
return Double.valueOf(substitutionModelDelegate.updateTime);
}
public Double getConvolveTimer() {
return Double.valueOf(substitutionModelDelegate.convolveTime);
}
public void getLogScalingFactors(int nodeIndex, double[] buffer) {
if (nodeIndex < tipCount) {
Arrays.fill(buffer, 0.0);
} else {
// final int scaleIndex = scaleBufferHelper.getOffsetIndex(nodeIndex - tipCount);
final int scaleIndex = scaleBufferIndices[nodeIndex - tipCount];
beagle.getLogScaleFactors(scaleIndex, buffer);
}
}
public double[] getSiteLogLikelihoods() {
getLogLikelihood();
double[] siteLogLikelihoods = new double[patternCount];
beagle.getSiteLogLikelihoods(siteLogLikelihoods);
return siteLogLikelihoods;
}
@Override
public Citation.Category getCategory() {
return Citation.Category.FRAMEWORK;
}
@Override
public String getDescription() {
return "Using BEAGLE likelihood calculation library";
}
public List<Citation> getCitations() {
return Collections.singletonList(CommonCitations.AYRES_2012_BEAGLE);
}
}//END: class