/*
* BeagleOperationReport.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.treelikelihood;
import beagle.Beagle;
import dr.evomodelxml.treelikelihood.BeagleOperationParser;
import dr.evomodel.siteratemodel.GammaSiteRateModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evolution.alignment.Alignment;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.math.matrixAlgebra.Vector;
import java.io.PrintWriter;
/**
* BeagleTreeLikelihoodModel - implements a Likelihood Function for sequences on a tree.
*
* @author Andrew Rambaut
* @author Alexei Drummond
* @author Marc Suchard
* @version $Id$
*/
public class BeagleOperationReport extends AbstractSinglePartitionTreeLikelihood {
/**
* the patternList
*/
protected PatternList patternList = null;
protected DataType dataType = null;
/**
* the pattern weights
*/
protected double[] patternWeights;
/**
* the number of patterns
*/
protected int patternCount;
/**
* the number of states in the data
*/
protected int stateCount;
/**
* Flags to specify which patterns are to be updated
*/
protected boolean[] updatePattern = null;
public BeagleOperationReport(TreeModel treeModel, PatternList patternList, BranchRateModel branchRateModel, GammaSiteRateModel siteRateModel, Alignment alignment, PrintWriter branch, PrintWriter operation) {
super(BeagleOperationParser.OPERATION_REPORT, patternList, treeModel);
boolean useAmbiguities = false;
this.branchRateModel = branchRateModel;
this.branchWriter = branch;
this.operationWriter = operation;
this.alignment = alignment;
this.substitutionModel = siteRateModel.getSubstitutionModel();
try {
this.tipCount = treeModel.getExternalNodeCount();
internalNodeCount = nodeCount - tipCount;
int compactPartialsCount = tipCount;
// one partials buffer for each tip and two for each internal node (for store restore)
partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount);
// two eigen buffers for each decomposition for store and restore.
eigenBufferHelper = new BufferIndexHelper(eigenCount, 0);
// two matrices for each node less the root
matrixBufferHelper = new BufferIndexHelper(nodeCount, 0);
for (int i = 0; i < tipCount; i++) {
// Find the id of tip i in the patternList
String id = treeModel.getTaxonId(i);
int index = patternList.getTaxonIndex(id);
if (index == -1) {
throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() +
", is not found in patternList, " + patternList.getId());
} else {
if (useAmbiguities) {
setPartials(beagle, patternList, index, i);
} else {
setStates(beagle, patternList, id, index, i);
}
}
}
} catch (TaxonList.MissingTaxonException mte) {
throw new RuntimeException(mte.toString());
}
hasInitialized = true;
}
public String toString() {
calculateLogLikelihood();
return super.toString();
}
public TreeModel getTreeModel() {
return treeModel;
}
/**
* Sets the partials from a sequence in an alignment.
*
* @param beagle beagle
* @param patternList patternList
* @param sequenceIndex sequenceIndex
* @param nodeIndex nodeIndex
*/
protected final void setPartials(Beagle beagle,
PatternList patternList,
int sequenceIndex,
int nodeIndex) {
double[] partials = new double[patternCount * stateCount * categoryCount];
boolean[] stateSet;
int v = 0;
for (int i = 0; i < patternCount; i++) {
int state = patternList.getPatternState(sequenceIndex, i);
stateSet = dataType.getStateSet(state);
for (int j = 0; j < stateCount; j++) {
if (stateSet[j]) {
partials[v] = 1.0;
} else {
partials[v] = 0.0;
}
v++;
}
}
// if there is more than one category then replicate the partials for each
int n = patternCount * stateCount;
int k = n;
for (int i = 1; i < categoryCount; i++) {
System.arraycopy(partials, 0, partials, k, n);
k += n;
}
System.err.println("TODO Print partials");
// beagle.setPartials(nodeIndex, partials);
}
/**
* Sets the partials from a sequence in an alignment.
*/
protected final void setPartials(Beagle beagle,
TipStatesModel tipStatesModel,
int nodeIndex) {
double[] partials = new double[patternCount * stateCount * categoryCount];
tipStatesModel.getTipPartials(nodeIndex, partials);
// if there is more than one category then replicate the partials for each
int n = patternCount * stateCount;
int k = n;
for (int i = 1; i < categoryCount; i++) {
System.arraycopy(partials, 0, partials, k, n);
k += n;
}
System.err.println("TODO Print partials");
// beagle.setPartials(nodeIndex, partials);
}
public int getPatternCount() {
return patternCount;
}
/**
* Sets the partials from a sequence in an alignment.
*
* @param beagle beagle
* @param patternList patternList
* @param id
* @param sequenceIndex sequenceIndex
* @param nodeIndex nodeIndex
*/
protected final void setStates(Beagle beagle,
PatternList patternList,
String id, int sequenceIndex,
int nodeIndex) {
int i;
StringBuilder sb = new StringBuilder();
sb.append("/* ").append(id).append(" */\n\t\tmSeqs[").append(nodeIndex).append("] = \"");
sb.append(alignment.getAlignedSequenceString(sequenceIndex)).append("\";\n");
int[] states = new int[patternCount];
for (i = 0; i < patternCount; i++) {
states[i] = patternList.getPatternState(sequenceIndex, i);
}
if (alignmentString == null) {
alignmentString = new StringBuilder();
}
alignmentString.append(sb);
}
protected double calculateLogLikelihood() {
if (matrixUpdateIndices == null) {
matrixUpdateIndices = new int[eigenCount][nodeCount];
branchLengths = new double[eigenCount][nodeCount];
branchUpdateCount = new int[eigenCount];
// scaleBufferIndices = new int[internalNodeCount];
// storedScaleBufferIndices = new int[internalNodeCount];
}
if (operations == null) {
operations = new int[numRestrictedPartials + 1][internalNodeCount * Beagle.OPERATION_TUPLE_SIZE];
operationCount = new int[numRestrictedPartials + 1];
}
recomputeScaleFactors = false;
for (int i = 0; i < eigenCount; i++) {
branchUpdateCount[i] = 0;
}
operationListCount = 0;
if (hasRestrictedPartials) {
for (int i = 0; i <= numRestrictedPartials; i++) {
operationCount[i] = 0;
}
} else {
operationCount[0] = 0;
}
System.out.println(alignmentString.toString());
final NodeRef root = treeModel.getRoot();
traverse(treeModel, root, null, false); // Do not flip buffers
// Print out eigendecompositions
for (int i = 0; i < eigenCount; i++) {
if (branchUpdateCount[i] > 0) {
if (DEBUG_BEAGLE_OPERATIONS) {
StringBuilder sb = new StringBuilder();
sb.append("eval = ").append(new Vector(substitutionModel.getEigenDecomposition().getEigenValues())).append("\n");
sb.append("evec = ").append(new Vector(substitutionModel.getEigenDecomposition().getEigenVectors())).append("\n");
sb.append("ivec = ").append(new Vector(substitutionModel.getEigenDecomposition().getInverseEigenVectors())).append("\n");
sb.append("Branch count: ").append(branchUpdateCount[i]);
sb.append("\nNode indices:\n");
if (SINGLE_LINE) {
sb.append("int n[] = {");
}
for (int k = 0; k < branchUpdateCount[i]; ++k) {
if (SINGLE_LINE) {
sb.append(" ").append(matrixUpdateIndices[i][k]);
if (k < (branchUpdateCount[i] - 1)) {
sb.append(",");
}
} else {
sb.append(matrixUpdateIndices[i][k]).append("\n");
}
}
if (SINGLE_LINE) {
sb.append(" };\n");
}
sb.append("\nBranch lengths:\n");
if (SINGLE_LINE) {
sb.append("double b[] = {");
}
for (int k = 0; k < branchUpdateCount[i]; ++k) {
if (SINGLE_LINE) {
sb.append(" ").append(branchLengths[i][k]);
if (k < (branchUpdateCount[i] - 1)) {
sb.append(",");
}
} else {
sb.append(branchLengths[i][k]).append("\n");
}
}
if (SINGLE_LINE) {
sb.append(" };\n");
}
System.out.println(sb.toString());
}
}
}
if (DEBUG_BEAGLE_OPERATIONS) {
StringBuilder sb = new StringBuilder();
sb.append("Operation count: ").append(operationCount[0]);
sb.append("\nOperations:\n");
if (SINGLE_LINE) {
sb.append("BeagleOperation o[] = {");
}
for (int k = 0; k < operationCount[0] * Beagle.OPERATION_TUPLE_SIZE; ++k) {
if (SINGLE_LINE) {
sb.append(" ").append(operations[0][k]);
if (k < (operationCount[0] * Beagle.OPERATION_TUPLE_SIZE - 1)) {
sb.append(",");
}
} else {
sb.append(operations[0][k]).append("\n");
}
}
if (SINGLE_LINE) {
sb.append(" };\n");
}
sb.append("Use scale factors: ").append(useScaleFactors).append("\n");
System.out.println(sb.toString());
}
int rootIndex = partialBufferHelper.getOffsetIndex(root.getNumber());
System.out.println("Root node: " + rootIndex);
return 0.0;
}
/**
* Traverse the tree calculating partial likelihoods.
*
* @param tree tree
* @param node node
* @param operatorNumber operatorNumber
* @param flip flip
* @return boolean
*/
private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber, boolean flip) {
boolean update = false;
int nodeNum = node.getNumber();
NodeRef parent = tree.getParent(node);
if (operatorNumber != null) {
operatorNumber[0] = -1;
}
// First update the transition probability matrix(ices) for this branch
if (parent != null && updateNode[nodeNum]) {
final double branchRate = branchRateModel.getBranchRate(tree, node);
// Get the operational time of the branch
final double branchTime = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node));
if (branchTime < 0.0) {
throw new RuntimeException("Negative branch length: " + branchTime);
}
if (flip) {
// first flip the matrixBufferHelper
matrixBufferHelper.flipOffset(nodeNum);
}
// then set which matrix to update
final int eigenIndex = 0; //branchSubstitutionModel.getBranchIndex(tree, node);
final int updateCount = branchUpdateCount[eigenIndex];
matrixUpdateIndices[eigenIndex][updateCount] = matrixBufferHelper.getOffsetIndex(nodeNum);
branchLengths[eigenIndex][updateCount] = branchTime;
branchUpdateCount[eigenIndex]++;
update = true;
}
// If the node is internal, update the partial likelihoods.
if (!tree.isExternal(node)) {
// Traverse down the two child nodes
NodeRef child1 = tree.getChild(node, 0);
final int[] op1 = {-1};
final boolean update1 = traverse(tree, child1, op1, flip);
NodeRef child2 = tree.getChild(node, 1);
final int[] op2 = {-1};
final boolean update2 = traverse(tree, child2, op2, flip);
// If either child node was updated then update this node too
if (update1 || update2) {
int x = operationCount[operationListCount] * Beagle.OPERATION_TUPLE_SIZE;
if (flip) {
// first flip the partialBufferHelper
partialBufferHelper.flipOffset(nodeNum);
}
final int[] operations = this.operations[operationListCount];
operations[x] = partialBufferHelper.getOffsetIndex(nodeNum);
if (useScaleFactors) {
// get the index of this scaling buffer
int n = nodeNum - tipCount;
if (recomputeScaleFactors) {
// flip the indicator: can take either n or (internalNodeCount + 1) - n
scaleBufferHelper.flipOffset(n);
// store the index
scaleBufferIndices[n] = scaleBufferHelper.getOffsetIndex(n);
operations[x + 1] = scaleBufferIndices[n]; // Write new scaleFactor
operations[x + 2] = Beagle.NONE;
} else {
operations[x + 1] = Beagle.NONE;
operations[x + 2] = scaleBufferIndices[n]; // Read existing scaleFactor
}
} else {
if (useAutoScaling) {
scaleBufferIndices[nodeNum - tipCount] = partialBufferHelper.getOffsetIndex(nodeNum);
}
operations[x + 1] = Beagle.NONE; // Not using scaleFactors
operations[x + 2] = Beagle.NONE;
}
operations[x + 3] = partialBufferHelper.getOffsetIndex(child1.getNumber()); // source node 1
operations[x + 4] = matrixBufferHelper.getOffsetIndex(child1.getNumber()); // source matrix 1
operations[x + 5] = partialBufferHelper.getOffsetIndex(child2.getNumber()); // source node 2
operations[x + 6] = matrixBufferHelper.getOffsetIndex(child2.getNumber()); // source matrix 2
operationCount[operationListCount]++;
update = true;
}
}
return update;
}
// **************************************************************
// INSTANCE VARIABLES
// **************************************************************
private int eigenCount = 1;
private int[][] matrixUpdateIndices;
private double[][] branchLengths;
private int[] branchUpdateCount;
private int[] scaleBufferIndices;
private int[] storedScaleBufferIndices;
private int[][] operations;
private int operationListCount;
private int[] operationCount;
private static final boolean hasRestrictedPartials = false;
private final int numRestrictedPartials = 0;
protected BufferIndexHelper partialBufferHelper;
private final BufferIndexHelper eigenBufferHelper;
protected BufferIndexHelper matrixBufferHelper;
protected BufferIndexHelper scaleBufferHelper;
protected final int tipCount;
protected final int internalNodeCount;
private PartialsRescalingScheme rescalingScheme;
protected boolean useScaleFactors = false;
private boolean useAutoScaling = false;
private boolean recomputeScaleFactors = false;
private boolean everUnderflowed = false;
private int rescalingCount = 0;
private int rescalingCountInner = 0;
protected final BranchRateModel branchRateModel;
protected double[] patternLogLikelihoods = null;
/**
* the number of rate categories
*/
protected int categoryCount;
/**
* an array used to transfer tip partials
*/
protected double[] tipPartials;
/**
* an array used to transfer tip states
*/
protected int[] tipStates;
/**
* the BEAGLE library instance
*/
protected Beagle beagle;
/**
* Flag to specify that the substitution model has changed
*/
protected boolean updateSubstitutionModel;
/**
* Flag to specify that the site model has changed
*/
protected boolean updateSiteModel;
private static final boolean DEBUG_BEAGLE_OPERATIONS = true;
private static final boolean SINGLE_LINE = true;
private StringBuilder alignmentString;
private final PrintWriter branchWriter;
private final PrintWriter operationWriter;
private final SubstitutionModel substitutionModel;
private final Alignment alignment;
/**
* Set update flag for a pattern
*/
protected void updatePattern(int i) {
if (updatePattern != null) {
updatePattern[i] = true;
}
likelihoodKnown = false;
}
/**
* Set update flag for all patterns
*/
protected void updateAllPatterns() {
if (updatePattern != null) {
for (int i = 0; i < patternCount; i++) {
updatePattern[i] = true;
}
}
likelihoodKnown = false;
}
protected class BufferIndexHelper {
/**
* @param maxIndexValue the number of possible input values for the index
* @param minIndexValue the minimum index value to have the mirrored buffers
*/
BufferIndexHelper(int maxIndexValue, int minIndexValue) {
this.maxIndexValue = maxIndexValue;
this.minIndexValue = minIndexValue;
offsetCount = maxIndexValue - minIndexValue;
indexOffsets = new int[offsetCount];
storedIndexOffsets = new int[offsetCount];
}
public int getBufferCount() {
return 2 * offsetCount + minIndexValue;
}
void flipOffset(int i) {
if (i >= minIndexValue) {
indexOffsets[i - minIndexValue] = offsetCount - indexOffsets[i - minIndexValue];
} // else do nothing
}
int getOffsetIndex(int i) {
if (i < minIndexValue) {
return i;
}
return indexOffsets[i - minIndexValue] + i;
}
void getIndices(int[] outIndices) {
for (int i = 0; i < maxIndexValue; i++) {
outIndices[i] = getOffsetIndex(i);
}
}
void storeState() {
System.arraycopy(indexOffsets, 0, storedIndexOffsets, 0, indexOffsets.length);
}
void restoreState() {
int[] tmp = storedIndexOffsets;
storedIndexOffsets = indexOffsets;
indexOffsets = tmp;
}
private final int maxIndexValue;
private final int minIndexValue;
private final int offsetCount;
private int[] indexOffsets;
private int[] storedIndexOffsets;
}
}