/*
* SankoffParsimony.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evolution.parsimony;
import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.Patterns;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.tree.*;
import dr.evolution.util.Taxon;
import java.util.Set;
import java.util.TreeSet;
import java.util.Iterator;
/**
* Class for reconstructing characters using the Sankoff generalized parsimony methods. This will be
* slower than the Fitch algorithm but it allows Weighted Parsimony.
*
* @version $Id: SankoffParsimony.java,v 1.7 2005/06/29 16:54:18 beth Exp $
*
* @author Andrew Rambaut
* @author Alexei Drummond
*/
public class SankoffParsimony implements ParsimonyCriterion {
private final int stateCount;
private int[][] stateSets;
private double[][][] nodeScores;
private int[][] nodeStates;
private Tree tree = null;
private final PatternList patterns;
private final double[][] costMatrix;
private final boolean compressStates = true;
private boolean hasCalculatedSteps = false;
private boolean hasRecontructedStates = false;
private final double[] siteScores;
public SankoffParsimony(PatternList patterns) {
if (patterns == null) {
throw new IllegalArgumentException("The patterns cannot be null");
}
stateCount = patterns.getDataType().getStateCount();
this.costMatrix = new double[stateCount][stateCount];
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
if (i == j) {
costMatrix[i][j] = 0.0;
} else {
costMatrix[i][j] = 1.0;
}
}
}
this.patterns = patterns;
this.siteScores = new double[patterns.getPatternCount()];
}
public SankoffParsimony(PatternList patterns, double[][] costMatrix) {
if (patterns == null) {
throw new IllegalArgumentException("The patterns cannot be null");
}
stateCount = patterns.getDataType().getStateCount();
if (costMatrix.length != stateCount || costMatrix[0].length != stateCount) {
throw new IllegalArgumentException("The cost matrix is of the wrong dimension: expecting " + stateCount + " square");
}
this.costMatrix = costMatrix;
this.patterns = patterns;
this.siteScores = new double[patterns.getPatternCount()];
}
/**
* Calculates the minimum number of siteScores for the parsimony reconstruction of a
* a set of character patterns on a tree.
* @param tree a tree object to reconstruct the characters on
* @return number of parsimony siteScores
*/
public double[] getSiteScores(Tree tree) {
if (tree == null) {
throw new IllegalArgumentException("The tree cannot be null");
}
if (this.tree == null || this.tree != tree) {
this.tree = tree;
initialize();
}
if (!hasCalculatedSteps) {
calculateSteps(tree, tree.getRoot(), patterns);
if (compressStates) {
for (int i = 0; i < siteScores.length; i++) {
double[] Sr = nodeScores[tree.getRoot().getNumber()][i];
siteScores[i] = minScore(Sr, stateSets[i]);
}
} else {
for (int i = 0; i < siteScores.length; i++) {
double[] Sr = nodeScores[tree.getRoot().getNumber()][i];
siteScores[i] = minScore(Sr);
}
}
hasCalculatedSteps = true;
}
return siteScores;
}
public double getScore(Tree tree) {
getSiteScores(tree);
double score = 0;
for (int i = 0; i < patterns.getPatternCount(); i++) {
score += siteScores[i] * patterns.getPatternWeight(i);
}
return score;
}
/**
* Returns the reconstructed character nodeStates for a given node in the tree. If this method is repeatedly
* called with the same tree and patterns then only the first call will reconstruct the nodeStates and each
* subsequent call will return the stored nodeStates.
* @param tree a tree object to reconstruct the characters on
* @param node the node of the tree
* @return an array containing the reconstructed nodeStates for this node
*/
public int[] getStates(Tree tree, NodeRef node) {
getSiteScores(tree);
if (!hasRecontructedStates) {
if (compressStates) {
for (int i = 0; i < patterns.getPatternCount(); i++) {
nodeStates[tree.getRoot().getNumber()][i] = minState(nodeScores[tree.getRoot().getNumber()][i], stateSets[i]);
}
} else {
for (int i = 0; i < patterns.getPatternCount(); i++) {
nodeStates[tree.getRoot().getNumber()][i] = minState(nodeScores[tree.getRoot().getNumber()][i]);
}
}
reconstructStates(tree, tree.getRoot(), nodeStates[tree.getRoot().getNumber()]);
hasRecontructedStates = true;
}
return nodeStates[node.getNumber()];
}
private void initialize() {
hasCalculatedSteps = false;
hasRecontructedStates = false;
if (compressStates) {
stateSets = new int[patterns.getPatternCount()][];
}
nodeScores = new double[tree.getNodeCount()][patterns.getPatternCount()][];
nodeStates = new int[tree.getNodeCount()][patterns.getPatternCount()];
for (int i = 0; i < patterns.getPatternCount(); i++) {
int[] pattern = patterns.getPattern(i);
if (compressStates) {
Set observedStates = new TreeSet();
for (int j = 0; j < pattern.length; j++) {
boolean[] stateSet = patterns.getDataType().getStateSet(pattern[j]);
for (int k = 0; k < stateSet.length; k++) {
if (stateSet[k]) {
observedStates.add(new Integer(k));
}
}
}
stateSets[i] = new int[observedStates.size()];
Iterator iter = observedStates.iterator();
int j = 0;
while (iter.hasNext()) {
stateSets[i][j] = ((Integer)iter.next()).intValue();
j++;
}
}
for (int j = 0; j < tree.getExternalNodeCount(); j++) {
NodeRef node = tree.getExternalNode(j);
int state = pattern[patterns.getTaxonIndex(tree.getNodeTaxon(node).getId())];
boolean[] stateSet = patterns.getDataType().getStateSet(state);
nodeScores[j][i] = new double[stateCount];
for (int k = 0; k < stateCount; k++) {
if (stateSet[k]) {
nodeScores[j][i][k] = 0.0;
} else {
nodeScores[j][i][k] = Double.POSITIVE_INFINITY;
}
}
}
for (int j = 0; j < tree.getInternalNodeCount(); j++) {
nodeScores[j + tree.getExternalNodeCount()][i] = new double[stateCount];
}
}
}
/**
* This is the first pass of the Fitch algorithm. This calculates the set of nodeStates
* at each node and counts the total number of siteScores (the score). If that is all that
* is required then the second pass is not necessary.
* @param tree
* @param node
* @param patterns
*/
private void calculateSteps(Tree tree, NodeRef node, PatternList patterns) {
if (!tree.isExternal(node)) {
for (int i = 0; i < tree.getChildCount(node); i++) {
calculateSteps(tree, tree.getChild(node, i), patterns);
}
if (compressStates) {
for (int i = 0; i < patterns.getPatternCount(); i++) {
double[] Sc = nodeScores[tree.getChild(node, 0).getNumber()][i];
double[] Sa = nodeScores[node.getNumber()][i];
int[] set = stateSets[i];
for (int k = 0; k < set.length; k++) {
Sa[set[k]] = minCost(k, Sc, costMatrix, set);
}
for (int j = 1; j < tree.getChildCount(node); j++) {
Sc = nodeScores[tree.getChild(node, j).getNumber()][i];
for (int k = 0; k < set.length; k++) {
Sa[set[k]] += minCost(k, Sc, costMatrix, set);
}
}
}
} else {
for (int i = 0; i < patterns.getPatternCount(); i++) {
double[] Sc = nodeScores[tree.getChild(node, 0).getNumber()][i];
double[] Sa = nodeScores[node.getNumber()][i];
for (int k = 0; k < stateCount; k++) {
Sa[k] = minCost(k, Sc, costMatrix);
}
for (int j = 1; j < tree.getChildCount(node); j++) {
Sc = nodeScores[tree.getChild(node, j).getNumber()][i];
for (int k = 0; k < stateCount; k++) {
Sa[k] += minCost(k, Sc, costMatrix);
}
}
}
}
}
}
/**
* The second pass of the algorithm. This reconstructs the ancestral nodeStates at
* each node.
* @param tree
* @param node
* @param parentStates
*/
private void reconstructStates(Tree tree, NodeRef node, int[] parentStates) {
for (int i = 0; i < patterns.getPatternCount(); i++) {
double[] Sa = nodeScores[node.getNumber()][i];
if (compressStates) {
int[] set = stateSets[i];
int minState = set[0];
double minCost = Sa[minState] + costMatrix[parentStates[i]][minState];
for (int j = 1; j < set.length; j++) {
double c = Sa[set[j]] + costMatrix[parentStates[i]][set[j]];
if (c < minCost) {
minState = set[j];
minCost = c;
}
}
nodeStates[node.getNumber()][i] = minState;
} else {
int minState = 0;
double minCost = Sa[minState] + costMatrix[parentStates[i]][minState];
for (int j = 1; j < Sa.length; j++) {
double c = Sa[j] + costMatrix[parentStates[i]][j];
if (c < minCost) {
minState = j;
minCost = c;
}
}
nodeStates[node.getNumber()][i] = minState;
}
}
for (int i = 0; i < tree.getChildCount(node); i++) {
reconstructStates(tree, tree.getChild(node, i), nodeStates[node.getNumber()]);
}
}
private int minState(double[] s1) {
int minState = 0;
for (int j = 1; j < s1.length; j++) {
if (s1[j] < s1[minState]) minState = j;
}
return minState;
}
private double minScore(double[] s1) {
double minScore = s1[0];
for (int j = 1; j < s1.length; j++) {
if (s1[j] < minScore) minScore = s1[j];
}
return minScore;
}
private double minCost(int i, double[] s1, double[][] costMatrix) {
double[] costRow = costMatrix[i];
double minCost = costRow[0] + s1[0];
for (int j = 1; j < s1.length; j++) {
double cost = costRow[j] + s1[j];
if (cost < minCost) minCost = cost;
}
return minCost;
}
private int minState(double[] s1, int[] set) {
int minState = set[0];
for (int j = 1; j < set.length; j++) {
if (s1[set[j]] < s1[minState]) minState = set[j];
}
return minState;
}
private double minScore(double[] s1, int[] set) {
double minScore = s1[set[0]];
for (int j = 1; j < set.length; j++) {
if (s1[set[j]] < minScore) minScore = s1[set[j]];
}
return minScore;
}
private double minCost(int i, double[] s1, double[][] costMatrix, int[] set) {
double[] costRow = costMatrix[set[i]];
double minCost = costRow[set[0]] + s1[set[0]];
for (int j = 1; j < set.length; j++) {
double cost = costRow[set[j]] + s1[set[j]];
if (cost < minCost) minCost = cost;
}
return minCost;
}
public static void main(String[] argv) {
FlexibleNode tip1 = new FlexibleNode(new Taxon("tip1"));
FlexibleNode tip2 = new FlexibleNode(new Taxon("tip2"));
FlexibleNode tip3 = new FlexibleNode(new Taxon("tip3"));
FlexibleNode tip4 = new FlexibleNode(new Taxon("tip4"));
FlexibleNode tip5 = new FlexibleNode(new Taxon("tip5"));
FlexibleNode node1 = new FlexibleNode();
node1.addChild(tip1);
node1.addChild(tip2);
FlexibleNode node2 = new FlexibleNode();
node2.addChild(tip4);
node2.addChild(tip5);
FlexibleNode node3 = new FlexibleNode();
node3.addChild(tip3);
node3.addChild(node2);
FlexibleNode root = new FlexibleNode();
root.addChild(node1);
root.addChild(node3);
FlexibleTree tree = new FlexibleTree(root);
Patterns patterns = new Patterns(Nucleotides.INSTANCE, tree);
//patterns.addPattern(new int[] {1, 0, 1, 2, 2});
//patterns.addPattern(new int[] {2, 1, 1, 1, 2});
patterns.addPattern(new int[] {2, 3, 1, 3, 3});
FitchParsimony fitch = new FitchParsimony(patterns, false);
SankoffParsimony sankoff = new SankoffParsimony(patterns);
for (int i = 0; i < patterns.getPatternCount(); i++) {
double[] scores = fitch.getSiteScores(tree);
System.out.println("Pattern = " + i);
System.out.println("Fitch:");
System.out.println(" No. Steps = " + scores[i]);
System.out.println(" state(node1) = " + fitch.getStates(tree, node1)[i]);
System.out.println(" state(node2) = " + fitch.getStates(tree, node2)[i]);
System.out.println(" state(node3) = " + fitch.getStates(tree, node3)[i]);
System.out.println(" state(root) = " + fitch.getStates(tree, root)[i]);
scores = sankoff.getSiteScores(tree);
System.out.println("Sankoff:");
System.out.println(" No. Steps = " + scores[i]);
System.out.println(" state(node1) = " + sankoff.getStates(tree, node1)[i]);
System.out.println(" state(node2) = " + sankoff.getStates(tree, node2)[i]);
System.out.println(" state(node3) = " + sankoff.getStates(tree, node3)[i]);
System.out.println(" state(root) = " + sankoff.getStates(tree, root)[i]);
System.out.println();
}
}
}