/*
* AbstractLikelihoodCore.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.treelikelihood;
import java.util.Arrays;
/**
* AbstractLikelihoodCore - An abstract base class for LikelihoodCores
*
* @author Andrew Rambaut
* @version $Id: AbstractLikelihoodCore.java,v 1.11 2006/08/30 16:02:42 rambaut Exp $
*/
@Deprecated // Switching to BEAGLE
public abstract class AbstractLikelihoodCore implements LikelihoodCore {
protected int stateCount;
protected int nodeCount;
protected int patternCount;
protected int partialsSize;
protected int matrixSize;
protected int matrixCount;
protected boolean integrateCategories;
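// Storage is double-buffered: the leading index (0 or 1) of partials,
// matrices and scalingFactors selects a buffer, and currentPartialsIndices
// and currentMatricesIndices record which buffer is current for each node;
// the other buffer holds the state preserved by storeState()/restoreState().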
protected double[][][] partials;
protected int[][] states;
protected double[][][] matrices;
protected int[] currentMatricesIndices;
protected int[] storedMatricesIndices;
protected int[] currentPartialsIndices;
protected int[] storedPartialsIndices;
protected boolean useScaling = false;
protected double[][][] scalingFactors;
private double scalingThreshold = 1.0E-100;
/**
* Constructor
*
* @param stateCount number of states
*/
public AbstractLikelihoodCore(int stateCount) {
this.stateCount = stateCount;
}
/**
* Initializes the partial likelihood arrays.
*
* @param nodeCount the number of nodes in the tree
* @param patternCount the number of patterns
* @param matrixCount the number of matrices (i.e., number of categories)
* @param integrateCategories whether sites are being integrated over all matrices
*/
public void initialize(int nodeCount, int patternCount, int matrixCount, boolean integrateCategories) {
this.nodeCount = nodeCount;
this.patternCount = patternCount;
this.matrixCount = matrixCount;
this.integrateCategories = integrateCategories;
if (integrateCategories) {
partialsSize = patternCount * stateCount * matrixCount;
} else {
partialsSize = patternCount * stateCount;
}
partials = new double[2][nodeCount][];
currentMatricesIndices = new int[nodeCount];
storedMatricesIndices = new int[nodeCount];
currentPartialsIndices = new int[nodeCount];
storedPartialsIndices = new int[nodeCount];
states = new int[nodeCount][];
for (int i = 0; i < nodeCount; i++) {
partials[0][i] = null;
partials[1][i] = null;
states[i] = null;
}
matrixSize = stateCount * stateCount;
matrices = new double[2][nodeCount][matrixCount * matrixSize];
}
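// Illustrative example (values hypothetical): for a nucleotide model with
// stateCount = 4, 100 patterns and 4 rate categories, integrating over
// categories,
//
//     core.initialize(nodeCount, 100, 4, true);
//
// gives partialsSize = 100 * 4 * 4 = 1600, the length of each per-node
// buffer later allocated by createNodePartials(nodeIndex).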
/**
* Cleans up by releasing references to the allocated arrays so they can be
* garbage collected.
*/
public void finalize() throws java.lang.Throwable {
super.finalize();
nodeCount = 0;
patternCount = 0;
matrixCount = 0;
partials = null;
currentPartialsIndices = null;
storedPartialsIndices = null;
states = null;
matrices = null;
currentMatricesIndices = null;
storedMatricesIndices = null;
scalingFactors = null;
}
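/**
* Turns rescaling of the partial likelihoods on or off. Scaling guards
* against numerical underflow on large trees at some performance cost.
*/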
public void setUseScaling(boolean useScaling) {
this.useScaling = useScaling;
if (useScaling) {
scalingFactors = new double[2][nodeCount][patternCount];
}
}
/**
* Allocates partials for a node
*/
public void createNodePartials(int nodeIndex) {
this.partials[0][nodeIndex] = new double[partialsSize];
this.partials[1][nodeIndex] = new double[partialsSize];
}
/**
* Sets partials for a node
*/
public void setNodePartials(int nodeIndex, double[] partials) {
if (this.partials[0][nodeIndex] == null) {
createNodePartials(nodeIndex);
}
if (partials.length < partialsSize) {
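// A single category's worth of partials was supplied; replicate it
// across all matrixCount categories.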
int k = 0;
for (int i = 0; i < matrixCount; i++) {
System.arraycopy(partials, 0, this.partials[0][nodeIndex], k, partials.length);
k += partials.length;
}
} else {
System.arraycopy(partials, 0, this.partials[0][nodeIndex], 0, partials.length);
}
}
/**
* Allocates states for a node
*/
public void createNodeStates(int nodeIndex) {
this.states[nodeIndex] = new int[patternCount];
}
/**
* Sets states for a node
*/
public void setNodeStates(int nodeIndex, int[] states) {
if (this.states[nodeIndex] == null) {
createNodeStates(nodeIndex);
}
System.arraycopy(states, 0, this.states[nodeIndex], 0, patternCount);
}
/**
* Gets states for a node
*/
public void getNodeStates(int nodeIndex, int[] states) {
System.arraycopy(this.states[nodeIndex], 0, states, 0, patternCount);
}
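/**
* Flips the matrix buffer for a node so that new matrices can be written
* without overwriting those of the stored state.
*/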
public void setNodeMatrixForUpdate(int nodeIndex) {
currentMatricesIndices[nodeIndex] = 1 - currentMatricesIndices[nodeIndex];
}
/**
* Sets probability matrix for a node
*/
public void setNodeMatrix(int nodeIndex, int matrixIndex, double[] matrix) {
System.arraycopy(matrix, 0, matrices[currentMatricesIndices[nodeIndex]][nodeIndex],
matrixIndex * matrixSize, matrixSize);
}
/**
* Gets probability matrix for a node
*/
public void getNodeMatrix(int nodeIndex, int matrixIndex, double[] matrix) {
System.arraycopy(matrices[currentMatricesIndices[nodeIndex]][nodeIndex],
matrixIndex * matrixSize, matrix, 0, matrixSize);
}
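/**
* Flips the partials buffer for a node so that new partials can be written
* without overwriting those of the stored state.
*/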
public void setNodePartialsForUpdate(int nodeIndex) {
currentPartialsIndices[nodeIndex] = 1 - currentPartialsIndices[nodeIndex];
}
/**
* Sets the current (in-update) partials for node nodeIndex. If the supplied
* array covers only a single category, it is copied once per category partition.
*/
public void setCurrentNodePartials(int nodeIndex, double[] partials) {
if (partials.length < partialsSize) {
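// A single category's worth of partials was supplied; replicate it
// across all matrixCount categories.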
int k = 0;
for (int i = 0; i < matrixCount; i++) {
System.arraycopy(partials, 0, this.partials[currentPartialsIndices[nodeIndex]][nodeIndex], k, partials.length);
k += partials.length;
}
} else {
System.arraycopy(partials, 0, this.partials[currentPartialsIndices[nodeIndex]][nodeIndex], 0, partials.length);
}
}
/**
* Calculates partial likelihoods at a node.
*
* @param nodeIndex1 the 'child 1' node
* @param nodeIndex2 the 'child 2' node
* @param nodeIndex3 the 'parent' node
*/
public void calculatePartials(int nodeIndex1, int nodeIndex2, int nodeIndex3) {
if (states[nodeIndex1] != null) {
if (states[nodeIndex2] != null) {
calculateStatesStatesPruning(
states[nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
states[nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3]);
} else {
calculateStatesPartialsPruning(states[nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
partials[currentPartialsIndices[nodeIndex2]][nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3]);
}
} else {
if (states[nodeIndex2] != null) {
calculateStatesPartialsPruning(states[nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex1]][nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3]);
} else {
calculatePartialsPartialsPruning(partials[currentPartialsIndices[nodeIndex1]][nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
partials[currentPartialsIndices[nodeIndex2]][nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3]);
}
}
if (useScaling) {
scalePartials(nodeIndex3);
}
}
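// Illustrative sketch of the pruning recursion (the index variables are
// hypothetical):
//
//     core.setNodePartialsForUpdate(parentIndex);
//     core.calculatePartials(child1Index, child2Index, parentIndex);
//
// Flipping the parent's buffer first ensures the previous partials survive
// a later restoreState().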
/**
* Calculates partial likelihoods at a node when both children have states.
*/
protected abstract void calculateStatesStatesPruning(int[] states1, double[] matrices1,
int[] states2, double[] matrices2,
double[] partials3);
/**
* Calculates partial likelihoods at a node when one child has states and one has partials.
*/
protected abstract void calculateStatesPartialsPruning(int[] states1, double[] matrices1,
double[] partials2, double[] matrices2,
double[] partials3);
/**
* Calculates partial likelihoods at a node when both children have partials.
*/
protected abstract void calculatePartialsPartialsPruning(double[] partials1, double[] matrices1,
double[] partials2, double[] matrices2,
double[] partials3);
/**
* Calculates partial likelihoods at a node.
*
* @param nodeIndex1 the 'child 1' node
* @param nodeIndex2 the 'child 2' node
* @param nodeIndex3 the 'parent' node
* @param matrixMap a map of which matrix to use for each pattern (can be null if integrating over categories)
*/
public void calculatePartials(int nodeIndex1, int nodeIndex2, int nodeIndex3, int[] matrixMap) {
if (states[nodeIndex1] != null) {
if (states[nodeIndex2] != null) {
calculateStatesStatesPruning(
states[nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
states[nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3], matrixMap);
} else {
calculateStatesPartialsPruning(
states[nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
partials[currentPartialsIndices[nodeIndex2]][nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3], matrixMap);
}
} else {
if (states[nodeIndex2] != null) {
calculateStatesPartialsPruning(
states[nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex1]][nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3], matrixMap);
} else {
calculatePartialsPartialsPruning(
partials[currentPartialsIndices[nodeIndex1]][nodeIndex1], matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1],
partials[currentPartialsIndices[nodeIndex2]][nodeIndex2], matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2],
partials[currentPartialsIndices[nodeIndex3]][nodeIndex3], matrixMap);
}
}
if (useScaling) {
scalePartials(nodeIndex3);
}
}
/**
* Calculates partial likelihoods at a node when both children have states.
*/
protected abstract void calculateStatesStatesPruning(int[] states1, double[] matrices1,
int[] states2, double[] matrices2,
double[] partials3, int[] matrixMap);
/**
* Calculates partial likelihoods at a node when one child has states and one has partials.
*/
protected abstract void calculateStatesPartialsPruning(int[] states1, double[] matrices1,
double[] partials2, double[] matrices2,
double[] partials3, int[] matrixMap);
/**
* Calculates partial likelihoods at a node when both children have partials.
*/
protected abstract void calculatePartialsPartialsPruning(double[] partials1, double[] matrices1,
double[] partials2, double[] matrices2,
double[] partials3, int[] matrixMap);
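/**
* Integrates the partials of a node over the rate categories.
*
* @param nodeIndex the node
* @param proportions the proportion of sites in each category
* @param outPartials an array into which the integrated partials will go
*/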
public void integratePartials(int nodeIndex, double[] proportions, double[] outPartials) {
calculateIntegratePartials(partials[currentPartialsIndices[nodeIndex]][nodeIndex], proportions, outPartials);
}
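// Illustrative sketch (assumes the calculateLogLikelihoods method declared
// by the LikelihoodCore interface): at the root, integrate over categories,
// then combine with the equilibrium frequencies:
//
//     double[] rootPartials = new double[patternCount * stateCount];
//     core.integratePartials(rootIndex, categoryProportions, rootPartials);
//     core.calculateLogLikelihoods(rootPartials, frequencies, outLogLikelihoods);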
/**
* Integrates partials across categories.
*
* @param inPartials the partials at the node to be integrated
* @param proportions the proportions of sites in each category
* @param outPartials an array into which the integrated partials will go
*/
protected abstract void calculateIntegratePartials(double[] inPartials, double[] proportions, double[] outPartials);
/**
* Scales the partials at a given node, using the scheme suggested by
* Ziheng Yang in Yang (2000) J. Mol. Evol. 51: 423-432.
* <p/>
* For each pattern this scans the partial likelihoods over all states and
* categories and finds the largest. If that maximum is below the
* scalingThreshold (currently 1.0E-100) the partials for the pattern are
* divided by it (normalizing the largest to 1) and the log of the scaling
* factor is stored; otherwise a log scaling factor of zero is recorded.
* This is called for every internal node after its partials are calculated,
* so it accounts for most of the performance cost of scaling. Yang suggests
* scaling at only a proportion of nodes, but that is awkward to organize,
* and he does not use the threshold, which recovers much of the performance.
*
* @param nodeIndex the node whose partials are to be scaled
*/
protected void scalePartials(int nodeIndex) {
// Partials are stored category-major: the entry for category k, pattern i,
// state j sits at k * patternCount * stateCount + i * stateCount + j.
int u = 0;
for (int i = 0; i < patternCount; i++) {
// Find the largest partial for this pattern over all states and categories.
double scaleFactor = 0.0;
int v = u;
for (int k = 0; k < matrixCount; k++) {
for (int j = 0; j < stateCount; j++) {
if (partials[currentPartialsIndices[nodeIndex]][nodeIndex][v] > scaleFactor) {
scaleFactor = partials[currentPartialsIndices[nodeIndex]][nodeIndex][v];
}
v++;
}
// Jump to the same pattern in the next category.
v += (patternCount - 1) * stateCount;
}
if (scaleFactor < scalingThreshold) {
// Rescale so the largest partial becomes 1 and record log(scaleFactor).
v = u;
for (int k = 0; k < matrixCount; k++) {
for (int j = 0; j < stateCount; j++) {
partials[currentPartialsIndices[nodeIndex]][nodeIndex][v] /= scaleFactor;
v++;
}
v += (patternCount - 1) * stateCount;
}
scalingFactors[currentPartialsIndices[nodeIndex]][nodeIndex][i] = Math.log(scaleFactor);
} else {
scalingFactors[currentPartialsIndices[nodeIndex]][nodeIndex][i] = 0.0;
}
u += stateCount;
}
}
/**
* Returns the total log scaling factor for a pattern, computed by summing
* the log scaling factors recorded at each node. If scaling is off this
* simply returns 0.
*
* @param pattern the pattern index
* @return the log scaling factor
*/
public double getLogScalingFactor(int pattern) {
double logScalingFactor = 0.0;
if (useScaling) {
for (int i = 0; i < nodeCount; i++) {
logScalingFactor += scalingFactors[currentPartialsIndices[i]][i][pattern];
}
}
return logScalingFactor;
}
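/**
* Fills a buffer with the per-pattern log scaling factors for a node, or
* with zeros if scaling is disabled.
*
* @param nodeIndex the node
* @param buffer an array of length patternCount to receive the factors
*/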
public void getLogScalingFactors(int nodeIndex, double[] buffer) {
if (useScaling) {
System.arraycopy(scalingFactors[currentPartialsIndices[nodeIndex]][nodeIndex], 0,
buffer, 0, patternCount);
} else {
Arrays.fill(buffer, 0.0);
}
}
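/**
* @return true if the partials are being rescaled
*/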
public boolean arePartialsRescaled() {
return useScaling;
}
/**
* Gets the partials for a particular node.
*
* @param nodeIndex the node
* @param outPartials an array into which the partials will go
*/
public void getPartials(int nodeIndex, double[] outPartials) {
double[] partials1 = partials[currentPartialsIndices[nodeIndex]][nodeIndex];
System.arraycopy(partials1, 0, outPartials, 0, partialsSize);
}
/**
* Store current state
*/
public void storeState() {
System.arraycopy(currentMatricesIndices, 0, storedMatricesIndices, 0, nodeCount);
System.arraycopy(currentPartialsIndices, 0, storedPartialsIndices, 0, nodeCount);
}
/**
* Restore the stored state
*/
public void restoreState() {
// Rather than copying the stored stuff back, just swap the pointers...
int[] tmp1 = currentMatricesIndices;
currentMatricesIndices = storedMatricesIndices;
storedMatricesIndices = tmp1;
int[] tmp2 = currentPartialsIndices;
currentPartialsIndices = storedPartialsIndices;
storedPartialsIndices = tmp2;
}
}