/*
* HomologyRecursion.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.indel;
import dr.evolution.alignment.Alignment;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.math.BFloat;
import java.util.HashMap;
public class HomologyRecursion {
/** Alignment columns; each column is a 0-1 vector with one component per sequence. */
IntMathVec[] iAlignment;

/** Ungapped sequences, one row per sequence. Entries lie in [0, iNumNucs). */
int[][] iSequences;

/**
 * Tree topology: iParent[child] is the index of the parent of child; the root's
 * parent is encoded as -1. A parent always has a larger index than its children,
 * and leaf indices coincide with the component indices of the alignment vectors.
 */
private int[] iParent;

/** Length of the branch leading into each node. The root's branch is taken to be infinite. */
private double[] iTau;

/** TKF91 model parameters. */
double iLambda;
double iMu;

/** Alphabet size (number of nucleotides). */
int iNumNucs;

/** Transition probabilities iTrans[node][from][to] for every node except the root. */
double[][][] iTrans;

/** Equilibrium frequencies of the nucleotides. */
double[] iEquil;

/** Per-node TKF91 coefficients; filled by initTKF91() and read by treeRecursion(). */
private double[] iH, iN, iB, iE;

/** TKF91-derived value stored in the first cell of the DP table. */
private double iInitial;

/** Fast (native) implementation of the tree recursion, used when available. */
private NativeTreeLikelihood iNativeMethod;

/** Labels stored in the local iState[] array of recursion(). */
private static final int eFree = 0;
private static final int ePossible = 1;
private static final int eEdgeUsed = 2;
private static final int eUsed = 3;

/** Largest number of simultaneously possible columns (dimension of an unalignable region). */
private static final int cMaxUnalignDimension = 10;

/** Smallest branch length accepted; shorter branches are clamped to this value. */
private static final double MIN_EDGE_LENGTH = 1e-3;

/** Counts how often recursion() gave up because an unalignable region was too large. */
static int sBigUnalignableRegion = 0;
/**
 * Formats an array of doubles as one string, each value followed by a single space.
 *
 * @param d the values to format
 * @return the values in order, each followed by " "; the empty string for an empty array
 */
String PrintDouble(double[] d) {
    // A StringBuilder avoids re-copying the whole partial result on every iteration,
    // which is what repeated String concatenation in a loop does.
    StringBuilder str = new StringBuilder();
    for (double aD : d) {
        str.append(aD).append(' ');
    }
    return str.toString();
}
/**
 * Sanity-checks the state set up by init(): alignment columns, sequences, tree
 * topology and branch lengths.
 *
 * Problems are not thrown; the most recently detected problem (if any) is printed
 * to System.out. As a side effect, every branch length in iTau that is below
 * MIN_EDGE_LENGTH is raised to MIN_EDGE_LENGTH. That clamping is silent and, unlike
 * before, no longer discards a problem detected by one of the earlier checks.
 */
void checkConsistency() {
    int iCols = iAlignment.length;
    int iRows = iAlignment[0].iV.length;
    String iErr = "";

    // Every alignment entry must be 0 or 1; summing the columns yields, per sequence,
    // the number of emitted characters.
    IntMathVec iPos = new IntMathVec( iRows );
    for (int i=0; i<iCols; i++) {
        for (int j=0; j<iRows; j++) {
            if (iAlignment[i].iV[j] != 0 && iAlignment[i].iV[j] != 1) {
                iErr = "Non-0/1 emissions.";
            }
        }
        iPos.add( iAlignment[i] );
    }

    // Each ungapped sequence must be as long as its emission count, with all codes in range.
    for (int j=0; j<iRows; j++) {
        if (iSequences[j].length != iPos.iV[j]) {
            iErr = "Bad sequences length";
        }
        for (int i=0; i<iPos.iV[j]; i++) {
            if ((iSequences[j][i]<0)||(iSequences[j][i]>=iNumNucs)) {
                iErr = "Nucleotide codes in iSequences not in range, " + iSequences[j][i] + " not in [0,"+(iNumNucs-1)+"]";
            }
        }
    }

    // Count children per node: the first iRows nodes must be tips, all others binary.
    int[] iChildren = new int[ iParent.length ];
    for (int i=0; i<iParent.length-1; i++) {
        iChildren[ iParent[i] ]++;
    }
    for (int i=0; i<iRows; i++) {
        if (iChildren[i] != 0) {
            iErr = "Num tips does not correspond to num sequences, or bad tree";
        }
    }
    for (int i=iRows; i<iParent.length; i++) {
        if (iChildren[i] != 2) {
            iErr = "Bad tree - not binary? Or too many tips?";
        }
    }
    if (iParent[iParent.length-1] != -1) {
        iErr = "Bad tree - root not in final position, or not labeled -1";
    }

    // Clamp too-short branches. This is deliberately not reported, and it must not
    // touch iErr: resetting it here would hide any problem found above.
    for (int i=0; i<iParent.length-1; i++) {
        if (iTau[i] < MIN_EDGE_LENGTH) {
            iTau[i] = MIN_EDGE_LENGTH;
        }
    }

    if ( !iErr.equals("") ) {
        System.out.println(iErr);
    }
}
/**
* Creates an instance without setting any field.
* The object must be set up with init() before recursion() is called.
*/
public HomologyRecursion() {
// Intentionally empty: initialisation is done manually through init() below.
}
/**
 * Sets up all internal state from a tree, an alignment and the model parameters.
 *
 * The steps run in a fixed order: tree arrays, alignment columns, ungapped sequences,
 * transition matrices (which need the branch lengths), TKF91 coefficients, a
 * consistency check, and finally the native likelihood helper.
 *
 * NOTE(review): checkConsistency() raises short branches in iTau to the minimum edge
 * length only after the transition matrices and the TKF91 coefficients have been
 * computed from iTau, so that clamping does not influence them - confirm this is intended.
 *
 * @param tree         tree relating the sequences
 * @param alignment    alignment whose taxon ids are looked up in the tree
 * @param substModel   substitution model supplying transition probabilities, frequencies and alphabet
 * @param mutationRate factor applied to node-height differences to obtain branch lengths
 * @param lengthDistr  multiplied by deathRate to give iLambda
 * @param deathRate    stored as iMu
 */
public void init(Tree tree, Alignment alignment, SubstitutionModel substModel,
                 double mutationRate, double lengthDistr, double deathRate) {
    // iParent and iTau
    initTree(tree, mutationRate);

    // Map each alignment row to the index of the same taxon in the tree.
    final int taxonCount = tree.getTaxonCount();
    final int[] rowToTreeIndex = new int[taxonCount];
    for (int row = 0; row < taxonCount; row++) {
        rowToTreeIndex[row] = tree.getTaxonIndex(alignment.getTaxonId(row));
    }

    // iAlignment and iSequences, both indexed by tree index
    initAlignment(alignment, rowToTreeIndex);
    initSequences(alignment, rowToTreeIndex);

    // iTrans and iEquil; relies on iTau filled in by initTree above
    initSubstitutionModel(substModel);

    // TKF91 rates and alphabet size
    iMu = deathRate;
    iLambda = deathRate*lengthDistr;
    iNumNucs = substModel.getDataType().getStateCount();

    // iB, iH, iN, iE and iInitial
    initTKF91();

    checkConsistency();

    // Hand everything to the native implementation
    iNativeMethod = new NativeTreeLikelihood();
    iNativeMethod.init(iNumNucs, cMaxUnalignDimension, iParent, iEquil, iTrans, iSequences, iN, iH, iE, iB);
}
/**
 * Allocates iParent (one entry per node) and iTau (one entry per non-root node) and
 * fills them by walking the tree from the root. Internal nodes are numbered starting
 * at the external node count; the last entry of iParent is set to -1.
 *
 * @param tree         the tree to encode
 * @param mutationRate factor applied to node-height differences to obtain branch lengths
 */
private void initTree(Tree tree, double mutationRate) {
    final int nodeCount = tree.getNodeCount();
    iParent = new int[nodeCount];
    iTau = new double[nodeCount - 1];

    final NodeRef root = tree.getRoot();
    // Single-element array acting as a mutable counter shared across the recursion.
    final int[] nextInternalNumber = { tree.getExternalNodeCount() };
    populate(tree, root, nextInternalNumber, mutationRate);

    iParent[nodeCount - 1] = -1;
}
/**
 * Fills iTrans with one transition matrix per branch and iEquil with the equilibrium
 * frequencies. Must be called after iTau has been populated from the tree.
 *
 * @param model substitution model providing transition probabilities and frequencies
 */
private void initSubstitutionModel(SubstitutionModel model) {
    final int stateCount = model.getDataType().getStateCount();
    final int branchCount = iTau.length;

    iTrans = new double[branchCount][stateCount][stateCount];
    // The model writes a flattened matrix; entry (from, to) is read at from*stateCount + to.
    final double[] flat = new double[stateCount*stateCount];
    for (int branch = 0; branch < branchCount; branch++) {
        model.getTransitionProbabilities(iTau[branch], flat);
        for (int from = 0; from < stateCount; from++) {
            System.arraycopy(flat, from*stateCount, iTrans[branch][from], 0, stateCount);
        }
    }

    // equilibrium distribution
    iEquil = new double[stateCount];
    for (int state = 0; state < stateCount; state++) {
        iEquil[state] = model.getFrequencyModel().getFrequency(state);
    }
}
/**
 * Builds iAlignment: one 0-1 vector per site, with component treeIndex[row] set to 0
 * when the state of that row at the site is at least the alphabet size, and to 1 otherwise.
 *
 * NOTE(review): initSequences additionally requires state >= 0 before counting a
 * character; here a negative state would be recorded as an emission - confirm that
 * negative states cannot occur.
 *
 * @param alignment the source alignment
 * @param treeIndex maps an alignment row to the component used in the column vectors
 */
private void initAlignment(Alignment alignment, int[] treeIndex) {
    final int rowCount = alignment.getSequenceCount();
    final int siteCount = alignment.getSiteCount();
    final int stateCount = alignment.getDataType().getStateCount();

    iAlignment = new IntMathVec[siteCount];
    // Scratch column, overwritten completely for every site before being wrapped.
    final int[] emitted = new int[rowCount];
    for (int site = 0; site < siteCount; site++) {
        for (int row = 0; row < rowCount; row++) {
            if (alignment.getState(row, site) >= stateCount) {
                emitted[treeIndex[row]] = 0;
            } else {
                emitted[treeIndex[row]] = 1;
            }
        }
        iAlignment[site] = new IntMathVec(emitted);
    }
}
/**
 * Builds iSequences: for every alignment row, the states that lie in
 * [0, alphabet size) in site order, stored at index treeIndex[row].
 *
 * @param alignment the source alignment
 * @param treeIndex maps an alignment row to its row in iSequences
 */
private void initSequences(Alignment alignment, int[] treeIndex) {
    final int rowCount = alignment.getSequenceCount();
    final int stateCount = alignment.getDataType().getStateCount();

    iSequences = new int[rowCount][];
    for (int row = 0; row < rowCount; row++) {
        // First pass: count the real characters so the row can be sized exactly.
        int residues = 0;
        for (int site = 0; site < alignment.getSiteCount(); site++) {
            final int state = alignment.getState(row, site);
            if (!(state < 0 || state >= stateCount)) {
                residues++;
            }
        }

        final int[] ungapped = new int[residues];
        iSequences[treeIndex[row]] = ungapped;

        // Second pass: copy the real characters in order.
        int next = 0;
        for (int site = 0; site < alignment.getSiteCount(); site++) {
            final int state = alignment.getState(row, site);
            if (!(state < 0 || state >= stateCount)) {
                ungapped[next] = state;
                next++;
            }
        }
    }
}
/**
 * Recursively fills iParent and iTau for the subtree below the given node.
 *
 * External nodes keep their own node number. Internal nodes are numbered after all
 * of their descendants, using the counter in current[0], which is incremented once
 * per internal node.
 *
 * @param tree         the tree being encoded
 * @param node         root of the subtree to process
 * @param current      single-element counter holding the next free internal index
 * @param mutationRate factor applied to node-height differences to obtain branch lengths
 * @return the index assigned to the given node
 */
private int populate(Tree tree, NodeRef node, int[] current, double mutationRate) {
    if (tree.isExternal(node)) {
        // Tips are addressed by their own number; only the incoming branch is recorded.
        final int tipIndex = node.getNumber();
        iTau[tipIndex] =
                (tree.getNodeHeight(tree.getParent(node)) - tree.getNodeHeight(node)) * mutationRate;
        return tipIndex;
    }

    // Number the whole subtree first, so this node's index exceeds those of its children.
    final int childCount = tree.getChildCount(node);
    final int[] childIndices = new int[childCount];
    for (int c = 0; c < childCount; c++) {
        childIndices[c] = populate(tree, tree.getChild(node, c), current, mutationRate);
    }

    final int ownIndex = current[0];

    // The root has no incoming branch to record.
    if (!tree.isRoot(node)) {
        iTau[ownIndex] =
                (tree.getNodeHeight(tree.getParent(node)) - tree.getNodeHeight(node)) * mutationRate;
    }

    current[0] = ownIndex + 1;

    for (int childIndex : childIndices) {
        iParent[childIndex] = ownIndex;
    }
    return ownIndex;
}
/**
 * Computes the per-node TKF91 coefficients iB, iE, iH and iN from iLambda, iMu and
 * iTau, and sets iInitial to the product of (1 - iB[node]) over all nodes.
 * The last node is treated as the root: its beta is 1/iMu and its iH is 0.
 */
public void initTKF91() {
    final int nodeCount = iParent.length;
    final int root = nodeCount - 1;

    iB = new double[nodeCount];
    iE = new double[nodeCount];
    iH = new double[nodeCount];
    iN = new double[nodeCount];
    iInitial = 1.0;

    for (int node = 0; node < nodeCount; node++) {
        final double beta;
        if (node == root) {
            beta = 1.0/iMu;
            iH[node] = 0.0;
        } else {
            // internal node or tip
            final double growth = Math.exp((iLambda-iMu)*iTau[node]);
            beta = (1.0-growth)/(iMu-iLambda*growth);
            iH[node] = Math.exp(-iMu*iTau[node])*(1.0-iLambda*beta);
        }
        iB[node] = iLambda*beta;
        iE[node] = iMu*beta;
        iN[node] = (1.0-iMu*beta)*(1.0-iB[node]) - iH[node];
        iInitial *= (1.0 - iB[node]);
    }
}
/**
* Calculates the 'conditional factor' associated with the emission encoded by iSignature, at position iPos.
* A 0 entry in iSignature means that leaf emits nothing; any other value is the homology
* family (1,2,...,cMaxUnalignDimension, inclusive) of the nucleotide emitted at that leaf.
* For an emitting leaf i, the nucleotide is read from iSequences[i][iPos.iV[i]].
*
* If the native implementation reports itself available, the call is delegated to it and
* none of the Java code below runs.
*
* @param iSignature per-leaf homology family of the emission (0 = no emission)
* @param iPos per-leaf index into iSequences of the next character
* @return the factor for this emission; 0.0 if two different homology families would have
* to occupy the same tree node
*/
private double treeRecursion(IntMathVec iSignature, IntMathVec iPos) {
// See if we can use C code
if (iNativeMethod.isAvailable()) {
return iNativeMethod.treeRecursion(iSignature, iPos);
}
int iLeaves = iSignature.iV.length; // Dimension of alignment columns, i.e. number of leaves
int iNumNodes = iParent.length; // Total number of tree nodes (iParent has one entry per node, root last)
int[] iHom = new int[ iNumNodes ]; // Homology for every node; 0 if node need not be homologous to an emitted nucleotide
int[] iHomNum = new int[ iNumNodes ]; // Number of homologous emissions accounted for by homologous nucleotide @ this node
int[] iHomMultiplicity = new int[ cMaxUnalignDimension+1 ]; // Number of emissions for each class of homologous nucleotides
int[] iChild1 = new int[ iNumNodes ], iChild2 = new int[ iNumNodes ]; // Left and right children
double[][] iFh = new double[ iNumNodes ][], iFi = new double[ iNumNodes ][]; // Fhomologous and Finhomologous arrays
// Initialise stuff
for (int i=0; i<iNumNodes; i++) {
iFh[i] = new double[ iNumNucs ];
iFi[i] = new double[ iNumNucs + 1]; // Extra position for 'gap' entry
}
// Seed the leaves: record each leaf's homology family and count the members of every family
for (int i=0; i<iLeaves; i++) {
iHomMultiplicity[ iSignature.iV[ i ] ]++;
iHom[ i ] = iSignature.iV[ i ];
if (iSignature.iV[i] == 0)
iHomNum[ i ] = 0;
else
iHomNum[ i ] = 1;
}
// Loop over all nodes except root, and find out which nodes need carry nucleotides of what
// homology class. Also fill iChild* arrays, which point to the two children of every parent.
boolean iClashingHomology = false;
for (int i=0; i<iNumNodes-1; i++) {
if ((iHomNum[i] == iHomMultiplicity[ iHom[i] ]) || (iHom[i] == 0)) {
// This node is the MRCA of the homologous nucleotides iHom[i], or a gap - do nothing
} else {
if (iHom[ iParent[ i ]] == 0 || iHom[ iParent[ i ]] == iHom[i]) {
// This node is not yet MRCA, so node above must be homologous
iHom[iParent[i]] = iHom[i];
// If iHom[i]==0, iHomNum[i]==0.
iHomNum[iParent[i]] += iHomNum[i];
} else {
// Clashing homology; signal
iClashingHomology = true;
}
}
// 0 doubles as the "not yet set" marker for iChild1. When node 0 itself is a child, its
// sibling therefore lands in iChild1 and iChild2 keeps its default 0 - still the right pair.
if (iChild1[ iParent[ i ] ] == 0)
iChild1[ iParent[ i ] ] = i;
else
iChild2[ iParent[ i ] ] = i;
}
// Bail out - cheaper than do this implicitly in the recursion below
if (iClashingHomology) {
return 0.0;
}
// Start recursion. First initialise the leaves
for (int i=0; i<iLeaves; i++) {
if (iSignature.iV[i] == 0) {
// gap
iFi[i][iNumNucs] = 1.0;
} else {
// nucleotide
iFh[i][ iSequences[i][iPos.iV[i]] ] = 1.0;
}
//System.out.println("Fi" + i +" = " + PrintDouble(iFi[i]));
//System.out.println("Fh" + i +" = " + PrintDouble(iFh[i]));
}
// Now do the recursion, bottom-up, on all internal nodes
// (children always have smaller indices than their parent, so they are finished first)
for (int i=iLeaves; i<iNumNodes; i++) {
// Find out whether:
// 1- One homology family spanning tree intersects both child edges, i.e.
// a 'homologous' nucleotide should travel down both edges,
// 2- One homology family Spanning tree intersects one of the child edges, i.e.
// 'homologous' nucleotide must travel down a specific edge
// 3- No homology family's spanning tree intersect either child edge, i.e.
// a 'homologous' nucleotide may travel down either edge and pop out at a leaf,
// or may travel down both edges but has to die in at least one subtree, or
// an 'inhomologous' nucleotide may do whatever it likes here.
if ((iHom[i] != 0) && (iHom[i] == iHom[iChild1[i]]) && (iHom[i] == iHom[iChild2[i]])) {
// case 1
for (int j=0; j<iNumNucs; j++) {
// iL / iR: sum over child nucleotides k of the homologous term of the left / right child,
// weighted by the child's iH coefficient and the j -> k transition probability
double iL = 0.0;
double iR = 0.0;
for (int k=0; k<iNumNucs; k++) {
iL += iFh[iChild1[i]][k] * iH[iChild1[i]] * iTrans[iChild1[i]][j][k];
iR += iFh[iChild2[i]][k] * iH[iChild2[i]] * iTrans[iChild2[i]][j][k];
}
iFh[i][j] = iL*iR;
}
// Others: 0.0
} else if (iHom[i] != 0) {
// case 2. Figure out which is the homologous child, and which the inhomologous one.
int iChildH, iChildI;
if (iHom[i] == iHom[iChild1[i]]) {
iChildH = iChild1[i];
iChildI = iChild2[i];
} else {
iChildH = iChild2[i];
iChildI = iChild1[i];
}
for (int j=0; j<iNumNucs; j++) {
// iL: homologous term along the homologous child edge (same form as case 1).
// iR: everything the other child may do: its 'gap' entry weighted by iE, plus per nucleotide k
// an equilibrium-weighted term and a transition-weighted inhomologous term.
// NOTE(review): the probabilistic reading of the (iN - iE*iB) factor is not derivable from
// this file alone - confirm against the TKF91 derivation before changing it.
double iL = 0.0;
double iR = iE[iChildI] * iFi[iChildI][iNumNucs];
for (int k=0; k<iNumNucs; k++) {
iL += iFh[iChildH][k] * iH[iChildH] * iTrans[iChildH][j][k];
iR +=
(iFh[iChildI][k] + iFi[iChildI][k]) * (iN[iChildI] - iE[iChildI]*iB[iChildI]) * iEquil[k] +
iFi[iChildI][k] * iH[iChildI] * iTrans[iChildI][j][k];
}
iFh[i][j] = iL*iR;
}
// Others: 0.0
} else {
// case 3
int iC1 = iChild1[i];
int iC2 = iChild2[i];
//System.out.println(iC1);
//System.out.println(iC2);
for (int j=0; j<iNumNucs; j++) {
// iL1/iL2 and iR1/iR2 are the case-1 and case-2 style terms, computed for both children
double iL1 = 0.0, iL2 = 0.0;
double iR1 = iE[iC1] * iFi[iC1][iNumNucs];
double iR2 = iE[iC2] * iFi[iC2][iNumNucs];
//System.out.println(iR1);
//System.out.println(iR2);
for (int k=0; k<iNumNucs; k++) {
iL1 += iFh[iC1][k] * iH[iC1] * iTrans[iC1][j][k];
iL2 += iFh[iC2][k] * iH[iC2] * iTrans[iC2][j][k];
iR1 +=
(iFh[iC1][k] + iFi[iC1][k]) * (iN[iC1] - iE[iC1]*iB[iC1]) * iEquil[k] +
iFi[iC1][k] * iH[iC1] * iTrans[iC1][j][k];
iR2 +=
(iFh[iC2][k] + iFi[iC2][k]) * (iN[iC2] - iE[iC2]*iB[iC2]) * iEquil[k] +
iFi[iC2][k] * iH[iC2] * iTrans[iC2][j][k];
}
iFh[i][j] = iL1*iR2 + iL2*iR1; // homology pops out below iC1 + homology pops out below iC2
iFi[i][j] = iR1*iR2; // no homology with j below i.
}
// 'Gap' entry of this node: for each child, its own gap entry minus the iB-weighted,
// equilibrium-weighted sum of its nucleotide entries; the two child values are multiplied.
double iL = iFi[iC1][iNumNucs];
double iR = iFi[iC2][iNumNucs];
for (int j=0; j<iNumNucs; j++) {
iL -= iB[iC1] * (iFi[iC1][j] + iFh[iC1][j]) * iEquil[j];
iR -= iB[iC2] * (iFi[iC2][j] + iFh[iC2][j]) * iEquil[j];
}
iFi[i][iNumNucs] = iL*iR;
}
} // recursion over the internal nodes
// Now calculate the final result: the same 'gap minus iB-weighted nucleotide sum' form, applied to the root
int iRoot = iNumNodes-1;
double iResult = iFi[iRoot][iNumNucs];
for (int i=0; i<iNumNucs; i++)
iResult -= (iFi[iRoot][i] + iFh[iRoot][i]) * iB[iRoot] * iEquil[i];
return iResult;
} // treeRecursion
/**
 * Computes the likelihood of the alignment given the tree, substitution model and
 * TKF91 parameters, by dynamic programming over positions (sums of already used
 * alignment columns). Values are kept in a hash table keyed by position vector.
 *
 * From the current position, the columns that may come next are collected; every
 * non-empty combination of them defines an edge to a new position, whose table entry
 * is increased by the current entry times a factor obtained from treeRecursion().
 *
 * Table keys are always independent copies of the working vectors: the working
 * position vectors are modified in place as the traversal proceeds, and a key that
 * changes after insertion can no longer be looked up reliably.
 *
 * @return the value of log() on the table entry of the final position (per the
 *         original documentation, the log-likelihood of the alignment), or
 *         Double.NEGATIVE_INFINITY if more than cMaxUnalignDimension columns are
 *         simultaneously possible from some position
 */
public double recursion() {
    // Initialise
    int iLen = iAlignment.length;
    int iLeaves = iAlignment[0].iV.length;
    int iFirstNotUsed = 0;                         // First not-'used' alignment vector (for efficiency)
    int[] iState = new int[ iLen ];                // Helper array, to traverse the region in the DP table corresp. to the alignment
    IntMathVec iPos = new IntMathVec( iLeaves );   // Current position; sum of all used vectors
    HashMap<IntMathVec, BFloat> iTable = new HashMap<IntMathVec, BFloat>();

    // Calculate correction factor for null emissions ("wing folding", or linear equation solving.)
    double iNullEmissionFac = treeRecursion( iPos, iPos );

    // Enter first probability into DP table. The key is a copy of iPos, because iPos
    // itself is modified in place further down.
    iTable.put( new IntMathVec( iPos ), new BFloat(iInitial / iNullEmissionFac) );

    // Array of possible vector indices, used in inner loop
    int[] iPossibles = new int[cMaxUnalignDimension];

    do {
        // Find all possible vectors from current position, iPos. iMask accumulates the
        // columns seen so far; a column is possible if it shares no sequence with them.
        IntMathVec iMask = new IntMathVec( iLeaves );
        int iPtr;
        int iNumPossible = 0;
        for (iPtr = iFirstNotUsed; iMask.zeroEntry() && iPtr<iLen; iPtr++) {
            if (iState[ iPtr ] != eUsed) {
                if (iMask.innerProduct( iAlignment[iPtr] ) == 0) {
                    iState[ iPtr ] = ePossible;
                    if (iNumPossible == cMaxUnalignDimension) {
                        // This gets too hairy - bail out.
                        sBigUnalignableRegion++;
                        System.err.println("We bailed out cause it was hairy: iNumPossible=" + iNumPossible);
                        return Double.NEGATIVE_INFINITY;
                    }
                    iPossibles[iNumPossible++] = iPtr;
                }
                iMask.add( iAlignment[iPtr] );
            }
        }

        // Loop over all combinations of possible vectors, which define edges from
        // iPos to another possible position, by ordinary binary counting.
        IntMathVec iNewPos = new IntMathVec( iPos );
        IntMathVec iSignature = new IntMathVec( iPos.iV.length );
        int iPosPtr;
        boolean iUnusedPos;
        boolean iFoundNonZero;
        do {
            // Find next combination
            iFoundNonZero = false;
            for (iPosPtr = iNumPossible - 1; iPosPtr >= 0; --iPosPtr) {
                int iCurPtr = iPossibles[ iPosPtr ];
                if (iState[ iCurPtr ] == ePossible) {
                    iState[ iCurPtr ] = eEdgeUsed;
                    iNewPos.add( iAlignment[ iCurPtr ] );
                    // Compute signature vector; column number iPosPtr gets homology family iPosPtr+1
                    iSignature.addMultiple( iAlignment[ iCurPtr ], iPosPtr+1 );
                    // Signal: non-zero combination found, and stop
                    iFoundNonZero = true;
                    iPosPtr = 0;
                } else {
                    // It was eEdgeUsed (i.e., digit == 1), so reset digit and continue
                    iState[ iCurPtr ] = ePossible;
                    iNewPos.subtract( iAlignment[ iCurPtr ] );
                    iSignature.addMultiple( iAlignment[ iCurPtr ], -iPosPtr-1 );
                }
            }

            if (iFoundNonZero) {
                // Work on a copy of the source entry; the target entry is updated in place.
                BFloat iLeft = (BFloat)(iTable.get( iPos )).clone();
                BFloat iRightObj = iTable.get( iNewPos );
                BFloat iRight;
                if (iRightObj == null) {
                    iUnusedPos = true;
                    iRight = new BFloat(0);
                } else {
                    iRight = iRightObj;
                    iUnusedPos = false;
                }
                double iTransFac = (-treeRecursion( iSignature, iPos )) / iNullEmissionFac;
                iLeft.multiply( iTransFac );
                iRight.add( iLeft );
                // And store.
                // If we are storing a value at a previously unused position, make sure we use a fresh key object
                if (iUnusedPos) {
                    iTable.put( iNewPos.clone(), iRight );
                } else {
                    iTable.put( iNewPos, iRight );
                }
            }
        } while (iFoundNonZero);

        // Now find next entry in DP table. Use farthest unused vector
        --iPtr;
        while (iPtr >= 0 && iState[iPtr] != ePossible) {
            // Undo any possible used vector that we encounter
            if (iState[iPtr] == eUsed) {
                iPos.subtract( iAlignment[iPtr] );
                iState[iPtr] = eFree;
            }
            --iPtr;
        }

        if (iPtr == -1) {
            // No more unused vectors, so we also fell through the edge loop above,
            // hence iNewPos contains the final position
            return (iTable.get( iNewPos )).log();
        }

        // Now use this farthest-out possible vector
        iState[iPtr] = eUsed;
        iPos.add( iAlignment[iPtr] );
        if (iPtr <= iFirstNotUsed)
            iFirstNotUsed++;
    } while (true);
} // recursion
}