package dr.evomodel.antigenic.phyloClustering.misc.obsolete;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.xml.*;
/**
* A Gibbs operator for allocation of items to clusters under a distance dependent Chinese restaurant process.
*
* @author Charles Cheung
* @author Trevor Bedford
*/
public class TreeClusterSequentialSampling extends SimpleMCMCOperator implements GibbsOperator{
//Parameter locationDrift; // no longer need to know
Parameter virusOffsetsParameter;
private double sigmaSq =1;
private int numdata = 0; //NEED TO UPDATE
//private double[] groupSize;
private MatrixParameter mu = null;
private Parameter clusterLabels = null;
private Parameter K = null;
private MatrixParameter virusLocations = null;
private int maxLabel = 0;
private int[] muLabels = null;
private int[] groupSize;
// public ClusterViruses clusterLikelihood = null;
private double numAcceptMoveMu = 0;
private double numProposeMoveMu = 0;
private double numAcceptMoveC = 0;
private double numProposeMoveC = 0;
private int isMoveMu = -1;
private double[] old_vLoc0 ;
private double[] old_vLoc1 ;
private Parameter clusterOffsetsParameter;
private AGLikelihoodTreeCluster clusterLikelihood = null;
private int groupSelectedChange = -1;
private int virusIndexChange = -1;
private double originalValueChange = -1;
private int dimSelectChange = -1;
private double[] mu0_offset;
private Parameter breakPoints = null;
private int binSize=20;
private Parameter status;
private TreeModel treeModel;
int []membershipToClusterLabelIndexes = null;
private int numNodes;
// private int[] piIndicator = new int[numSites];
//public ClusterAlgorithmOperator(MatrixParameter virusLocations, MatrixParameter mu, Parameter clusterLabels, Parameter K, double weight, Parameter virusOffsetsParameter, Parameter locationDrift_in, Parameter clusterOffsetsParameter) {
public TreeClusterSequentialSampling(MatrixParameter virusLocations, MatrixParameter mu, Parameter clusterLabels, Parameter K, double weight, Parameter virusOffsetsParameter, Parameter clusterOffsetsParameter, Parameter breakPointsParameter, Parameter statusParameter, TreeModel treeModel_in, AGLikelihoodTreeCluster clusterLikelihood_in) {
System.out.println("Loading the constructor for Sequential sampler");
this.clusterLikelihood = clusterLikelihood_in;
this.treeModel= treeModel_in;
this.mu = mu;
this.K = K;
this.clusterLabels = clusterLabels;
// this.clusterLikelihood = clusterLikelihood;
this.virusLocations = virusLocations;
this.virusOffsetsParameter = virusOffsetsParameter;
// this.locationDrift = locationDrift_in; //no longer need
this.clusterOffsetsParameter = clusterOffsetsParameter;
this.breakPoints= breakPointsParameter;
this.status = statusParameter;
numdata = virusOffsetsParameter.getSize();
System.out.println("numdata="+ numdata);
int K_int = (int) K.getParameterValue(0);
System.out.println("K_int=" + K_int);
groupSize = new int[binSize];
for(int i=0; i < binSize; i++){
groupSize[i] = 0;
}
for(int i=0; i < numdata; i++){
//System.out.println("i="+ i);
int index = (int) clusterLabels.getParameterValue(i);
groupSize[ index]++;
}
for(int i=0; i < numdata;i++){
if(maxLabel < (int) clusterLabels.getParameterValue(i)){
maxLabel = (int) clusterLabels.getParameterValue(i);
}
}
//NEED maxGROUP
//for(int i=0; i < K_int; i++){
//System.out.println("groupSize=" + groupSize[i]);
//}
muLabels = new int[binSize];
for(int i=0; i < maxLabel; i++){
int j=0;
if(groupSize[i] >0){
muLabels[j] = i;
j++;
}
}
//muLabels ...
setWeight(weight);
System.out.println("Finished loading the constructor for ClusterAlgorithmOperator");
}
/**
* change the parameter and return the log hastings ratio.
*/
public final double doOperation() {
System.out.println("do operation of sequential sampling");
setMembershipToClusterLabelIndexes(); //run once in case the tree changes.
numNodes = treeModel.getNodeCount();
updateK();
int K_int = (int) K.getParameterValue(0);
//System.out.println("K_int is " + K_int);
int[] cutNode = new int[K_int]; //to store which nodes to add...
int[] oldclusterLabelArray = new int[numNodes]; //everything 0.
int[] clusterLabelArray = null;
int[] onPoints = new int[K_int];
int numOn=0;
for(int i=0; i < binSize; i++){
if( (int)status.getParameterValue(i) ==1 ){
onPoints[numOn] = i;
numOn++;
}
}
//Verify relationship
//P(H|E0=a, Y, mu) = sum P(H|E0=a, E1=e1, Y, mu) x P(E1=e1| E0=a)
cutNode = new int[1];
cutNode[0] = 785;
onPoints = new int[1];
onPoints[0] = 0;
//now translate the cutNodes into the breakpoints that are on [although here we don't use the indicators - elsewhere does
resetStatusAndBreakpointsGivenCutNodes(cutNode, onPoints);//need to reset the status breakpoints by the testCutNode
//use the tree to re-partition according to the change.
clusterLabelArray = setClusterLabelsByTestCutNodeByNodeOrder(cutNode);
relabelClusterLabels(clusterLabelArray, oldclusterLabelArray); //will move it out
for(int i=0; i < numdata; i++){
clusterLabels.setParameterValue(i, clusterLabelArray[i]); //set cluster label parameter
}
//setVirusLocationAndOffsets(); //set virus locations, given the clusterlabels parameter
setVirusLocationAutoCorrelatedModel();
double topLogProb = clusterLikelihood.getLogLikelihood();
System.out.println(topLogProb);
//now test the sum...
cutNode = new int[2];
cutNode[0] = 785;
onPoints = new int[2];
onPoints[0] = 0;
onPoints[1] = 1;
double []logNumerator = new double[numNodes];
//calculate the distribution for calculating introducing an excision point in each node
for(int curTest=0; curTest < numNodes; curTest++){
int hasBeenAdded = checkSiteHasBeenAddedToOnIndicators(curTest); //check if a site has already been added
if(hasBeenAdded ==0){
cutNode = new int[2];
cutNode[0] = 785;
cutNode[1] = curTest;
resetStatusAndBreakpointsGivenCutNodes(cutNode, onPoints);//need to reset the status breakpoints by the testCutNode
//use the tree to re-partition according to the change.
clusterLabelArray = setClusterLabelsByTestCutNodeByNodeOrder(cutNode); //note that instead of using the indicators, it uses the testCutNodes directly
relabelClusterLabels(clusterLabelArray, oldclusterLabelArray); //will move it out
//set cluster label parameter for testing
for(int i=0; i < numdata; i++){
clusterLabels.setParameterValue(i, clusterLabelArray[i]);
}
//setVirusLocationAndOffsets(); //this uses the clusterLabels parameter
setVirusLocationAutoCorrelatedModel(); //which depends on the status and breakpoints
logNumerator[curTest] += clusterLikelihood.getLogLikelihood() ; //Calculate likelihood
}
else{
logNumerator[curTest] = Double.NEGATIVE_INFINITY;
System.out.println("Don't calculate for node" + cutNode[0]);
}
} //finished curTest
double answer = verifyAssumption(topLogProb, logNumerator);
System.out.println("the ratio is " + answer);
System.exit(0);
//switch
double tmpMu1 = mu.getParameter(onPoints[7] +1).getParameterValue(0);
double tmpMu2 = mu.getParameter(onPoints[7] +1).getParameterValue(1);
mu.getParameter(onPoints[7] +1).setParameterValue(0, mu.getParameter(onPoints[8] +1).getParameterValue(0));
mu.getParameter(onPoints[7] +1).setParameterValue(1, mu.getParameter(onPoints[8] +1).getParameterValue(1));
mu.getParameter(onPoints[8] +1).setParameterValue(0, tmpMu1);
mu.getParameter(onPoints[8] +1).setParameterValue(1, tmpMu2);
for(int curNode=8; curNode < K_int; curNode++){
cutNode[0] = 785;
cutNode[1] = 775;
cutNode[2] = 763;
cutNode[3] = 697;
cutNode[4] = 747;
cutNode[5] = 679;
cutNode[6] = 662;
//cutNode[7] = 521; //although 526 will be better than 521
cutNode[7] = 638;
//calculate the conditional distribution of the curNode, given the current set of nodes
double []logNumeratorProb = new double[numNodes];
//calculate the distribution for calculating introducing an excision point in each node
for(int curTest=0; curTest < numNodes; curTest++){
int hasBeenAdded = checkSiteHasBeenAddedToOnIndicators(curTest); //check if a site has already been added
if(hasBeenAdded ==0){
int[] testCutNode = new int[curNode+1]; //create the testCutNode, with adding the current test node
for(int i=0; i < curNode; i++){
testCutNode[i] = cutNode[i];
}
testCutNode[curNode] = curTest;
resetStatusAndBreakpointsGivenCutNodes(testCutNode, onPoints);//need to reset the status breakpoints by the testCutNode
/*
System.out.print("Currently selected: [");
for(int i=0; i < binSize; i++){
if((int) status.getParameterValue(i) ==1){
System.out.print( (int) breakPoints.getParameterValue(i) + ",");
}
}
System.out.println("]");
*/
//I suspect I need to change this as I modify the code..
//set the indicators, based on the cutnodes
/*
int addCount=0;
for(int i=0; i < binSize; i++){
if( (int)status.getParameterValue(i) == 1){
breakPoints.setParameterValue(i, testCutNode[addCount]);
addCount++;
}
if(addCount == (curNode+1)){
break;
}
}
*/
//System.out.println("currently added " + addCount + " nodes");
//use the tree to re-partition according to the change.
clusterLabelArray = setClusterLabelsByTestCutNodeByNodeOrder(testCutNode); //note that instead of using the indicators, it uses the testCutNodes directly
relabelClusterLabels(clusterLabelArray, oldclusterLabelArray); //will move it out
//set cluster label parameter for testing
for(int i=0; i < numdata; i++){
clusterLabels.setParameterValue(i, clusterLabelArray[i]);
}
//setVirusLocationAndOffsets(); //this uses the clusterLabels parameter
setVirusLocationAutoCorrelatedModel(); //which depends on the status and breakpoints
if(curNode == 0 && curTest == 0){
for(int i=0; i < numdata; i++){
Parameter v = virusLocations.getParameter(i);
//v.setParameterValue(0, 0);
//v.setParameterValue(1, 0);
}
}
logNumeratorProb[curTest] = clusterLikelihood.getLogLikelihood(); //Calculate likelihood
//if(curNode == 0 && curTest == 0){
//for(int i=0; i < numdata; i++){
// Parameter v = virusLocations.getParameter(i);
// System.out.print(v.getParameterValue(0) +"," + v.getParameterValue(1)+"\t");
//}
//System.out.println("");
//System.out.println(" * " + logNumeratorProb[curTest]);
//System.exit(0);
//}
}
else{
logNumeratorProb[curTest] = Double.NEGATIVE_INFINITY; //dummy probability
}
} //finished curTest
double []condDistribution = calculateConditionalProbabilityGivenLogNumeratorProb(logNumeratorProb);
for(int i=0; i < numNodes; i++){
// if(condDistribution[i] > 0.0000001){
System.out.println("node " + i + " p=" + condDistribution[i]);
// }
}
System.out.println("===============================");
// System.exit(0);
int site_add = MathUtils.randomChoicePDF(condDistribution); //sample a site, given the conditioanl distribution
//update the cutNode using a temporary array newCutNode - first copy the existing element, then add the new site
int[] newCutNode = new int[curNode+1];
for(int i=0; i < curNode; i++){
newCutNode[i] = cutNode[i];
}
newCutNode[curNode] = site_add; //add the new site.
cutNode = newCutNode;
//now , after adding the node to cutNode, we need to update clusterLabel such that the labeling is consistent with the old cluster labels..
// because ie. cluster i always get mu i.
clusterLabelArray = setClusterLabelsByTestCutNodeByNodeOrder(newCutNode);
relabelClusterLabels(clusterLabelArray, oldclusterLabelArray); //will move it out
oldclusterLabelArray = clusterLabelArray; // keep the oldcluster label to build on top of it..
} //curNode
//====================================================================================================
//After finishing the proposal
//====================================================================================================
//Display:
printCutNode(cutNode);
//now translate the cutNodes into the breakpoints that are on [although here we don't use the indicators - elsewhere does
resetStatusAndBreakpointsGivenCutNodes(cutNode, onPoints);//need to reset the status breakpoints by the testCutNode
/*
int addCount=0;
for(int i=0; i < binSize; i++){
if( (int)status.getParameterValue(i) == 1){
breakPoints.setParameterValue(i, cutNode[addCount]);
addCount++;
}
}
*/
//use the tree to re-partition according to the change.
clusterLabelArray = setClusterLabelsByTestCutNodeByNodeOrder(cutNode);
relabelClusterLabels(clusterLabelArray, oldclusterLabelArray); //will move it out
for(int i=0; i < numdata; i++){
clusterLabels.setParameterValue(i, clusterLabelArray[i]); //set cluster label parameter
}
//setVirusLocationAndOffsets(); //set virus locations, given the clusterlabels parameter
setVirusLocationAutoCorrelatedModel();
System.out.println(clusterLikelihood.getLogLikelihood());
//add manually to test...
cutNode[0] = 785;
cutNode[1] = 775;
cutNode[2] = 763;
cutNode[3] = 697;
cutNode[4] = 747;
cutNode[5] = 679;
cutNode[6] = 662;
//cutNode[7] = 521;
//cutNode[8] = 638;
cutNode[7] = 638;
cutNode[8] = 521;
resetStatusAndBreakpointsGivenCutNodes(cutNode, onPoints);//need to reset the status breakpoints by the testCutNode
//use the tree to re-partition according to the change.
clusterLabelArray = setClusterLabelsByTestCutNodeByNodeOrder(cutNode);
relabelClusterLabels(clusterLabelArray, oldclusterLabelArray); //will move it out
for(int i=0; i < numdata; i++){
clusterLabels.setParameterValue(i, clusterLabelArray[i]); //set cluster label parameter
}
//setVirusLocationAndOffsets(); //set virus locations, given the clusterlabels parameter
setVirusLocationAutoCorrelatedModel();
System.out.println(clusterLikelihood.getLogLikelihood());
System.exit(0);
//NEED TO SET THE OTHER BREAKPOINTS TO MAKE SURE THERE ARE NO DUPLICATES AT THE END...
//SO OTHER PROPOSALS WOULD FUNCTION OK OK
return(Double.POSITIVE_INFINITY); //it should be anything... always accept for the Gibbs move.
}
private void resetStatusAndBreakpointsGivenCutNodes(int[] testCutNode, int[] onPoints) {
for(int i=0; i < binSize; i++){
status.setParameterValue(i, 0);
breakPoints.setParameterValue(i, -1);
}
int numOn = testCutNode.length;
int countOn=0;
for(int i=0; i < binSize; i++){
if(countOn < numOn){
status.setParameterValue(onPoints[countOn], 1);
breakPoints.setParameterValue(onPoints[countOn],testCutNode[countOn]); //reset breakPoints accordingly
countOn++;
}
}
}
private void updateK() {
//K is changed accordingly..
int K_count = 0; //K_int gets updated
for(int i=0; i < binSize; i++){
K_count += (int) status.getParameterValue(i);
}
//System.out.println("K now becomes " + K_count);
K.setParameterValue(0, K_count); //update
}
private void printCutNode(int[] cutNode) {
System.out.print("sampled:\t[");
for(int i=0; i < cutNode.length ; i++){
System.out.print(cutNode[i] + ",");
}
System.out.println("]");
}
private double[] calculateConditionalProbabilityGivenLogNumeratorProb(
double[] logNumeratorProb) {
int numNodes = logNumeratorProb.length;
double maxLogProb = logNumeratorProb[0];
for(int i=0; i < numNodes; i++ ){
if(logNumeratorProb[i] > maxLogProb){
maxLogProb = logNumeratorProb[i];
}
}
double sumLogDenominator = 0;
for(int i=0; i < numNodes; i++){
if(logNumeratorProb[i] != Double.NEGATIVE_INFINITY){
sumLogDenominator += Math.exp((logNumeratorProb[i]-maxLogProb));
}
}
sumLogDenominator = Math.log(sumLogDenominator) + maxLogProb;
double sumProb = 0;
double []condProb = new double[numNodes];
for(int i=0; i < numNodes; i++){
condProb[i] = Math.exp( logNumeratorProb[i] - sumLogDenominator );
//System.out.println("condProb of site " + i + " = " + condProb[i]);
sumProb +=condProb[i];
if(condProb[i] > 0.01){
// System.out.println("**site " + i + " with prob=" + condProb[i] + " steps from previous=" + numStepsFromOrigin[i]);
}
}
return(condProb);
}
//Expect the ratio to be 1.
private double verifyAssumption( double topLogMarginal,double[] logNumeratorProb) {
for(int i=0; i < numNodes; i++){
System.out.println(logNumeratorProb[i]);
}
int numNodes = logNumeratorProb.length;
double maxLogProb = logNumeratorProb[0];
for(int i=0; i < numNodes; i++ ){
if(logNumeratorProb[i] > maxLogProb){
maxLogProb = logNumeratorProb[i];
}
}
System.out.println("maxLogProb = " + maxLogProb);
double sumLogDenominator = 0;
for(int i=0; i < numNodes; i++){
if(logNumeratorProb[i] != Double.NEGATIVE_INFINITY){
sumLogDenominator += Math.exp((logNumeratorProb[i]-maxLogProb));
}
}
System.out.println("tmp sum = " + sumLogDenominator);
sumLogDenominator = Math.log(sumLogDenominator) + maxLogProb;
System.out.println("topLogMarginal = " + topLogMarginal);
System.out.println("sumLogDenominator = " + sumLogDenominator);
double ratio = Math.exp( Math.log(numNodes-1) + topLogMarginal - sumLogDenominator );
return(ratio);
}
private int checkSiteHasBeenAddedToOnIndicators(int curTest){
int hasBeenAdded=0;
for(int i=0; i < binSize; i++){
if((int)status.getParameterValue(i) == 1 ){
if( (int) breakPoints.getParameterValue(i) == curTest){
hasBeenAdded=1;
break;
}
}
}
return(hasBeenAdded);
}
private void setVirusLocationAndOffsets() {
//change the mu in the toBin and fromBIn
//borrow from getLogLikelihood:
double[] meanYear = new double[binSize];
double[] groupCount = new double[binSize];
for(int i=0; i < numdata; i++){
int label = (int) clusterLabels.getParameterValue(i);
double year = 0;
if (virusOffsetsParameter != null) {
// System.out.print("virus Offeset Parameter present"+ ": ");
// System.out.print( virusOffsetsParameter.getParameterValue(i) + " ");
// System.out.print(" drift= " + drift + " ");
year = virusOffsetsParameter.getParameterValue(i); //just want year[i]
//make sure that it is equivalent to double offset = year[virusIndex] - firstYear;
}
else{
System.out.println("virus Offeset Parameter NOT present. We expect one though. Something is wrong.");
}
meanYear[ label] = meanYear[ label] + year;
groupCount[ label ] = groupCount[ label ] +1;
}
for(int i=0; i < binSize; i++){
if(groupCount[i] > 0){
meanYear[i] = meanYear[i]/groupCount[i];
}
//System.out.println(meanYear[i]);
}
mu0_offset = new double[binSize];
//double[] mu1 = new double[maxLabel];
//System.out.println("maxLabel=" + maxLabel);
//now, change the mu..
for(int i=0; i < binSize; i++){
//System.out.println(meanYear[i]*beta);
mu0_offset[i] = meanYear[i];
//System.out.println("group " + i + "\t" + mu0_offset[i]);
}
// System.out.println("=====================");
//Set the vLoc to be the corresponding mu values , and clusterOffsetsParameter to be the corresponding offsets
//virus in the same cluster has the same position
for(int i=0; i < numdata; i++){
int label = (int) clusterLabels.getParameterValue(i);
Parameter vLoc = virusLocations.getParameter(i);
//setting the virus locs to be equal to the corresponding mu
double muValue = mu.getParameter(label).getParameterValue(0);
vLoc.setParameterValue(0, muValue);
double muValue2 = mu.getParameter(label).getParameterValue(1);
vLoc.setParameterValue(1, muValue2);
//System.out.println("vloc="+ muValue + "," + muValue2);
}
for(int i=0; i < numdata; i++){
int label = (int) clusterLabels.getParameterValue(i);
//if we want to apply the mean year virus cluster offset to the cluster
if(clusterOffsetsParameter != null){
//setting the clusterOffsets to be equal to the mean year of the virus cluster
// by doing this, the virus changes cluster AND updates the offset simultaneously
clusterOffsetsParameter.setParameterValue( i , mu0_offset[label]);
}
// System.out.println("mu0_offset[label]=" + mu0_offset[label]);
// System.out.println("clusterOffsets " + i +" now becomes =" + clusterOffsetsParameter.getParameterValue(i) );
}
// System.out.println("===The on nodes===");
// for(int i=0; i < binSize; i++){
// if((int) excisionPoints.getParameterValue(i) == 1){
// System.out.println("Cluster node " + i + " = " + (int) indicators.getParameterValue(i) + "\tstatus=" + (int) excisionPoints.getParameterValue(i));
// }
// }
}
private void setVirusLocationAutoCorrelatedModel() {
int numNodes = treeModel.getNodeCount();
double[][] nodeloc = new double[numNodes][2];
//new - Trevor's autocorrelated model.
// System.out.println("Autocorrelated tree model");
//given mu, excision points, and which ones are on...
int[] nodeStatus = new int[numNodes];
for(int i=0; i < numNodes; i ++){
nodeStatus[i] = -1;
}
//convert to easy process format.
for(int i=0; i < (binSize ); i++){
if((int) status.getParameterValue(i) ==1){
nodeStatus[(int)breakPoints.getParameterValue(i)] = i;
}
}
//Testing:
//muValue[0] = 1;
//muValue2[0] = 1.5;
//nodeStatus[696] = 0;
//muValue[1] = 10;
//muValue2[1] = 20;
//nodeStatus[607] = 1;
//muValue[2] = 200;
//muValue2[2] = 300;
//process the tree and get the vLoc of the viruses..
//breadth first depth first..
NodeRef cNode = treeModel.getRoot();
LinkedList<NodeRef> visitlist = new LinkedList<NodeRef>();
visitlist.add(cNode);
int countProcessed=0;
while(visitlist.size() > 0){
countProcessed++;
//assign value to the current node...
if(treeModel.getParent(cNode) == null){
Parameter curMu = mu.getParameter(0);
nodeloc[cNode.getNumber()][0] = curMu.getParameterValue(0);
nodeloc[cNode.getNumber() ][1] = curMu.getParameterValue(1);
}
else{
nodeloc[cNode.getNumber()][0] = nodeloc[treeModel.getParent(cNode).getNumber()][0];
nodeloc[cNode.getNumber()][1] = nodeloc[treeModel.getParent(cNode).getNumber()][1];
if(nodeStatus[cNode.getNumber()] != -1){
//System.out.println("Run new location");
Parameter curMu = mu.getParameter(nodeStatus[cNode.getNumber()] +1);
nodeloc[cNode.getNumber()][0] += curMu.getParameterValue(0);
nodeloc[cNode.getNumber()][1] += curMu.getParameterValue(1);
}
}
//add all the children to the queue
for(int childNum=0; childNum < treeModel.getChildCount(cNode); childNum++){
NodeRef node= treeModel.getChild(cNode,childNum);
visitlist.add(node);
}
visitlist.pop(); //now that we have finished visiting this node, pops it out of the queue
if(visitlist.size() > 0){
cNode = visitlist.getFirst(); //set the new first node in the queue to visit
}
}
//write the virus locations
for(int i=0; i < numdata; i++){
Parameter vLocParameter = virusLocations.getParameter(i);
vLocParameter.setParameterValue(0, nodeloc[membershipToClusterLabelIndexes[i]][0]);
vLocParameter.setParameterValue(1, nodeloc[membershipToClusterLabelIndexes[i]][1]);
}
//for(int i=0; i < numdata; i++){
//Parameter vLocP= virusLocations.getParameter(i);
//System.out.println("virus " + vLocP.getId() + "\t" + vLocP.getParameterValue(0) + "," + vLocP.getParameterValue(1) );
//}
//System.out.println("Processed " + countProcessed + " nodes");
//System.out.println("Done");
//System.exit(0);
}
private void relabelClusterLabels(int[] clusterLabel, int[] oldclusterLabel) {
int maxOldLabel = 0;
for(int i=0; i < oldclusterLabel.length; i++){
if(maxOldLabel < oldclusterLabel[i]){
maxOldLabel = oldclusterLabel[i];
}
}
Map<Integer, Integer> m = new HashMap<Integer, Integer>();
int[] isOldUsed = new int[ clusterLabel.length ]; //an overkill - basically just need the max label in the old cluster
for(int i=0; i < clusterLabel.length; i++){
if(m.get(new Integer(clusterLabel[i])) == null ){
if(isOldUsed[oldclusterLabel[i]] == 0){
m.put(new Integer(clusterLabel[i]), new Integer(oldclusterLabel[i]));
isOldUsed[oldclusterLabel[i]] = 1;
}
else{
maxOldLabel++;
m.put(new Integer(clusterLabel[i]), new Integer(maxOldLabel));
}
}
clusterLabel[i] = m.get(new Integer( clusterLabel[i])).intValue();
}
}
private int[] setClusterLabelsByTestCutNodeByNodeOrder(int[] testCutNode) {
int []membership = determine_membershipByNodeOrder(treeModel, testCutNode, testCutNode.length); // the time consuming step here.
//The assumption that the first nodes being external node corresponding to the cluster labels IS FALSE
//so I have to search for the matching indexes
// for(int i=0; i < numdata; i++){
// clusterLabels.setParameterValue( i, membership[membershipToClusterLabelIndexes[i]]);
//}
//to speed up the code
int[] clusterLabel = new int[numdata];
for(int i=0; i < numdata; i++){
clusterLabel[i] = membership[membershipToClusterLabelIndexes[i]];
}
return(clusterLabel);
}
private void setMembershipToClusterLabelIndexes(){
int numNodes = treeModel.getNodeCount();
//I suspect this is an expensive operation, so I don't want to do it many times,
//which is also unnecessary - MAY have to update whenever a different tree is used.
membershipToClusterLabelIndexes = new int[numdata];
for(int i=0; i < numdata; i++){
Parameter v = virusLocations.getParameter(i);
String curName = v.getParameterName();
// System.out.println(curName);
int isFound = 0;
for(int j=0; j < numNodes; j++){
String treeId = treeModel.getTaxonId(j);
if(curName.equals(treeId) ){
// System.out.println(" isFound at j=" + j);
membershipToClusterLabelIndexes[i] = j;
isFound=1;
break;
}
}
if(isFound ==0){
System.out.println("not found. Exit now.");
System.exit(0);
}
}
}
//private void setClusterLabelsByTestCutNode(int[] testCutNode) {
private int[] setClusterLabelsByTestCutNode(int[] testCutNode) {
int []membership = determine_membership(treeModel, testCutNode, testCutNode.length); // the time consuming step here.
//The assumption that the first nodes being external node corresponding to the cluster labels IS FALSE
//so I have to search for the matching indexes
// for(int i=0; i < numdata; i++){
// clusterLabels.setParameterValue( i, membership[membershipToClusterLabelIndexes[i]]);
//}
//to speed up the code
int[] clusterLabel = new int[numdata];
for(int i=0; i < numdata; i++){
clusterLabel[i] = membership[membershipToClusterLabelIndexes[i]];
}
return(clusterLabel);
}
//This function uses the breakPoints, but the doOperation doesn't use this anymore..
//instead, it uses the cutNodes explicitly.
private void setClusterLabels(int K_int) {
int numNodes = treeModel.getNodeCount();
int[] cutNodes = new int[K_int];
int cutNum = 0;
String content = "";
for(int i=0; i < binSize; i++){
if( (int) status.getParameterValue( i ) ==1 ){
cutNodes[cutNum] = (int) breakPoints.getParameterValue(i);
content += (int) breakPoints.getParameterValue(i) + ",";
cutNum++;
}
}
// System.out.println(content);
if(cutNum != K_int){
System.out.println("cutNum != K_int. we got a problem");
}
// for(int i=0; i < K_int; i++){
// System.out.println(cutNodes[i]);
// }
//int []membership = determine_membership(treeModel, cutNodes, K_int-1);
int []membership = determine_membership(treeModel, cutNodes, K_int);
double uniqueCode = 0;
for(int i=0; i < numNodes; i++){
uniqueCode += membership[i]*i;
}
// System.out.println(" sum = " + uniqueCode);
// System.out.println("number of nodes = " + treeModel.getNodeCount());
// for(int i=0; i < treeModel.getNodeCount(); i++){
// System.out.println(membership[i]);
// }
//System.out.println("Done");
// for(int i=0; i < numdata; i++){
// Parameter v = virusLocations.getParameter(i);
// String curName = v.getParameterName();
// System.out.println("i=" + i + " = " + curName);
// }
// for(int j=0; j < numdata; j++){
// System.out.println("j=" + j + " = " + treeModel.getTaxonId(j));
// }
// Parameter vv = virusLocations.getParameter(0);
// String curNamev = vv.getParameterName();
// System.out.println(curNamev + " and " +treeModel.getTaxonId(392) );
//System.out.println( curNamev.equals(treeModel.getTaxonId(392) ) );
//System.exit(0);
// System.out.println("numNodes=" + numNodes);
// System.exit(0);
//create dictionary:
//I suspect this is an expensive operation, so I don't want to do it many times,
//which is also unnecessary - MAY have to update whenever a different tree is used.
int []membershipToClusterLabelIndexes = new int[numdata];
for(int i=0; i < numdata; i++){
Parameter v = virusLocations.getParameter(i);
String curName = v.getParameterName();
// System.out.println(curName);
int isFound = 0;
for(int j=0; j < numNodes; j++){
String treeId = treeModel.getTaxonId(j);
if(curName.equals(treeId) ){
// System.out.println(" isFound at j=" + j);
membershipToClusterLabelIndexes[i] = j;
isFound=1;
break;
}
}
if(isFound ==0){
System.out.println("not found. Exit now.");
System.exit(0);
}
}
// System.exit(0);
// for(int i=0; i < numdata; i++){
// System.out.println(membershipToClusterLabelIndexes[i]);
// }
// System.exit(0);
for(int i=0; i < numdata; i++){
//The assumption that the first nodes being external node corresponding to the cluster labels IS FALSE
//so I have to search for the matching indexes
Parameter vloc = virusLocations.getParameter(i);
//must uncomment out because this sets the new partitioning ... now i am doing code testing.
clusterLabels.setParameterValue( i, membership[membershipToClusterLabelIndexes[i]]);
//System.out.println(vloc.getParameterName() + " i="+ i + " membership=" + (int) clusterLabels.getParameterValue(i));
// Parameter v = virusLocations.getParameter(i);
// System.out.println(v.getParameterName());
}
}
private static boolean isCutNode(int number, int cutNodes[], int numCut) {
if(numCut > 0){
for(int i=0; i < numCut; i++){
if(number == cutNodes[i]){
return true;
}
}
}
return false;
}
//traverse down the tree, top down, do calculation
static int[] determine_membership(TreeModel treeModel, int[] cutNodes, int numCuts){
NodeRef root = treeModel.getRoot();
int numClusters = 1;
LinkedList<NodeRef> list = new LinkedList<NodeRef>();
list.addFirst(root);
int[] membership = new int[treeModel.getNodeCount()];
for(int i=0; i < treeModel.getNodeCount(); i++){
membership[i] = -1;
}
membership[root.getNumber()] = 0; //root always given the first cluster
while(!list.isEmpty()){
//do things with the current object
NodeRef curElement = list.pop();
//String content = "node #" + curElement.getNumber() +", taxon=" + treeModel.getNodeTaxon(curElement) + " and parent is = " ;
String content = "node #" + curElement.getNumber() +", taxon= " ;
if(treeModel.getNodeTaxon(curElement)== null){
content += "internal node\t";
}
else{
content += treeModel.getNodeTaxon(curElement).getId() + "\t";
//content += treeModel.getTaxonIndex(treeModel.getNodeTaxon(curElement)) + "\t";
}
if(treeModel.getParent(curElement)== null){
//content += "no parent";
}
else{
//content += "parent node#=" + treeModel.getParent(curElement).getNumber();
}
//cluster assignment:
if(!treeModel.isRoot(curElement)){
if(isCutNode(curElement.getNumber(), cutNodes, numCuts)){
//if(isCutNode(curElement.getNumber())){
numClusters++ ;
membership[ curElement.getNumber() ] = numClusters - 1;
}
else{
//inherit from parent's cluster assignment
membership[curElement.getNumber()] = membership[treeModel.getParent(curElement).getNumber()];
}
}//is not Root
content += " cluster = " + membership[curElement.getNumber()] ;
// System.out.println(content);
for(int childNum=0; childNum < treeModel.getChildCount(curElement); childNum++){
list.addFirst(treeModel.getChild(curElement,childNum));
}
}
return(membership);
}
//traverse down the tree, top down, do calculation
static int[] determine_membershipByNodeOrder(TreeModel treeModel, int[] cutNodes, int numCuts){
Map<Integer, Integer> m = new HashMap<Integer, Integer>();
for(int i=0; i < numCuts; i++){
m.put(new Integer(cutNodes[i]), new Integer(i+1));
// System.out.println(cutNodes[i] + "\t" + (i+1) );
}
NodeRef root = treeModel.getRoot();
int numClusters = 1;
LinkedList<NodeRef> list = new LinkedList<NodeRef>();
list.addFirst(root);
int[] membership = new int[treeModel.getNodeCount()];
for(int i=0; i < treeModel.getNodeCount(); i++){
membership[i] = -1;
}
membership[root.getNumber()] = 0; //root always given the first cluster
while(!list.isEmpty()){
//do things with the current object
NodeRef curElement = list.pop();
//String content = "node #" + curElement.getNumber() +", taxon=" + treeModel.getNodeTaxon(curElement) + " and parent is = " ;
String content = "node #" + curElement.getNumber() +", taxon= " ;
if(treeModel.getNodeTaxon(curElement)== null){
content += "internal node\t";
}
else{
content += treeModel.getNodeTaxon(curElement).getId() + "\t";
//content += treeModel.getTaxonIndex(treeModel.getNodeTaxon(curElement)) + "\t";
}
if(treeModel.getParent(curElement)== null){
//content += "no parent";
}
else{
//content += "parent node#=" + treeModel.getParent(curElement).getNumber();
}
//cluster assignment:
if(!treeModel.isRoot(curElement)){
if(isCutNode(curElement.getNumber(), cutNodes, numCuts)){
//if(isCutNode(curElement.getNumber())){
//numClusters++ ;
//membership[ curElement.getNumber() ] = numClusters - 1;
// System.out.println("get: curElement" + curElement.getNumber() + "\t" + m.get(new Integer( curElement.getNumber())));
membership[ curElement.getNumber()] = m.get(new Integer( curElement.getNumber()));
}
else{
//inherit from parent's cluster assignment
membership[curElement.getNumber()] = membership[treeModel.getParent(curElement).getNumber()];
}
}//is not Root
content += " cluster = " + membership[curElement.getNumber()] ;
// System.out.println(content);
for(int childNum=0; childNum < treeModel.getChildCount(curElement); childNum++){
list.addFirst(treeModel.getChild(curElement,childNum));
}
}
return(membership);
}
public void accept(double deviation) {
super.accept(deviation);
/*
if(isMoveMu==1){
numAcceptMoveMu++;
numProposeMoveMu++;
System.out.println("% accept move Mu = " + numAcceptMoveMu/(double)numProposeMoveMu);
}
else{
numAcceptMoveC++;
numProposeMoveC++;
System.out.println("% accept move C = " + numAcceptMoveC/(double)numProposeMoveC);
}
*/
// if(virusIndexChange <5){
// System.out.println(" - Accepted!");
// }
}
public void reject(){
super.reject();
/*
//manually change mu back..
if(isMoveMu==1){
mu.getParameter(groupSelectedChange).setParameterValue(dimSelectChange, originalValueChange);
}
//manually change all the affected vLoc back...
for(int i=0; i < numdata; i++){
int label = (int) clusterLabels.getParameterValue(i);
Parameter vLoc = virusLocations.getParameter(i);
// double muValue = mu.getParameter(label).getParameterValue(0);
// vLoc.setParameterValue(0, muValue);
// double muValue2 = mu.getParameter(label).getParameterValue(1);
// vLoc.setParameterValue(1, muValue2);
clusterOffsetsParameter.setParameterValue( i , mu0_offset[label]);
}
*/
/*
if(isMoveMu==1){
numProposeMoveMu++;
System.out.println("% accept move Mu = " + numAcceptMoveMu/(double)numProposeMoveMu);
}
else{
numProposeMoveC++;
System.out.println("% accept move C = " + numAcceptMoveC/(double)numProposeMoveC);
}
*/
//if(virusIndexChange < 5){
System.out.println(" * Rejected!");
//}
/*
for(int i=0; i < numdata; i++){
Parameter vLoc = virusLocations.getParameter(i);
if( vLoc.getParameterValue(0) != old_vLoc0[i]){
System.out.println("virus " + i + " is different: " + vLoc.getParameterValue(0) + " and " + old_vLoc0[i]);
}
//System.out.println(old_vLoc0[i] + ", " + old_vLoc1[i]);
vLoc.setParameterValue(0, old_vLoc0[i]);
vLoc.setParameterValue(1, old_vLoc1[i]);
}
*/
//System.exit(0);
}
public final static String TREE_CLUSTERSEQUENTIAL_OPERATOR = "TreeClusterSequentialSampling";
//MCMCOperator INTERFACE
public final String getOperatorName() {
return TREE_CLUSTERSEQUENTIAL_OPERATOR;
}
public final void optimize(double targetProb) {
throw new RuntimeException("This operator cannot be optimized!");
}
public boolean isOptimizing() {
return false;
}
public void setOptimizing(boolean opt) {
throw new RuntimeException("This operator cannot be optimized!");
}
public double getMinimumAcceptanceLevel() {
return 0.1;
}
public double getMaximumAcceptanceLevel() {
return 0.4;
}
public double getMinimumGoodAcceptanceLevel() {
return 0.20;
}
public double getMaximumGoodAcceptanceLevel() {
return 0.30;
}
public String getPerformanceSuggestion() {
if (Utils.getAcceptanceProbability(this) < getMinimumAcceptanceLevel()) {
return "";
} else if (Utils.getAcceptanceProbability(this) > getMaximumAcceptanceLevel()) {
return "";
} else {
return "";
}
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public final static String VIRUSLOCATIONS = "virusLocations";
public final static String MU = "mu";
public final static String CLUSTERLABELS = "clusterLabels";
public final static String K = "k";
public final static String OFFSETS = "offsets";
// public final static String LOCATION_DRIFT = "locationDrift"; //no longer need
public final static String CLUSTER_OFFSETS = "clusterOffsetsParameter";
public final static String INDICATORS = "indicators";
public final static String EXCISION_POINTS = "excisionPoints";
public String getParserName() {
return TREE_CLUSTERSEQUENTIAL_OPERATOR;
}
/* (non-Javadoc)
* @see dr.xml.AbstractXMLObjectParser#parseXMLObject(dr.xml.XMLObject)
*/
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
//System.out.println("Parser run. Exit now");
//System.exit(0);
double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
XMLObject cxo = xo.getChild(VIRUSLOCATIONS);
MatrixParameter virusLocations = (MatrixParameter) cxo.getChild(MatrixParameter.class);
cxo = xo.getChild(MU);
MatrixParameter mu = (MatrixParameter) cxo.getChild(MatrixParameter.class);
cxo = xo.getChild(CLUSTERLABELS);
Parameter clusterLabels = (Parameter) cxo.getChild(Parameter.class);
cxo = xo.getChild(K);
Parameter k = (Parameter) cxo.getChild(Parameter.class);
cxo = xo.getChild(OFFSETS);
Parameter offsets = (Parameter) cxo.getChild(Parameter.class);
// cxo = xo.getChild(LOCATION_DRIFT);
// Parameter locationDrift = (Parameter) cxo.getChild(Parameter.class);
Parameter clusterOffsetsParameter = null;
if (xo.hasChildNamed(CLUSTER_OFFSETS)) {
clusterOffsetsParameter = (Parameter) xo.getElementFirstChild(CLUSTER_OFFSETS);
}
cxo = xo.getChild(INDICATORS);
Parameter indicators = (Parameter) cxo.getChild(Parameter.class);
cxo = xo.getChild(EXCISION_POINTS);
Parameter excisionPoints = (Parameter) cxo.getChild(Parameter.class);
TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
AGLikelihoodTreeCluster agLikelihood = (AGLikelihoodTreeCluster) xo.getChild(AGLikelihoodTreeCluster.class);
//return new ClusterAlgorithmOperator(virusLocations, mu, clusterLabels, k, weight, offsets, locationDrift, clusterOffsetsParameter);
return new TreeClusterSequentialSampling(virusLocations, mu, clusterLabels, k, weight, offsets, clusterOffsetsParameter, indicators, excisionPoints, treeModel, agLikelihood);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "An operator that picks a new allocation of an item to a cluster under the Dirichlet process.";
}
public Class getReturnType() {
return TreeClusterSequentialSampling.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
AttributeRule.newDoubleRule(MCMCOperator.WEIGHT),
new ElementRule(VIRUSLOCATIONS, Parameter.class),
new ElementRule(MU, Parameter.class),
new ElementRule(CLUSTERLABELS, Parameter.class),
new ElementRule(K, Parameter.class),
new ElementRule(OFFSETS, Parameter.class),
// new ElementRule(LOCATION_DRIFT, Parameter.class), //no longer needed
//
new ElementRule(CLUSTER_OFFSETS, Parameter.class, "Parameter of cluster offsets of all virus"), // no longer REQUIRED
new ElementRule(INDICATORS, Parameter.class),
new ElementRule(EXCISION_POINTS, Parameter.class),
new ElementRule(TreeModel.class),
};
};
public int getStepCount() {
return 1;
}
}