/*
* ClusterSplitMergeOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.antigenic;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.xml.*;
/**
* An MCMC operator that proposes either splitting one occupied cluster into two
* or merging two occupied clusters into one.
*
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id$
*/
public class ClusterSplitMergeOperator extends SimpleMCMCOperator {

    public final static boolean DEBUG = false;

    public final static String CLUSTER_SPLIT_MERGE_OPERATOR = "clusterSplitMergeOperator";

    // Standard deviation multiplier for the Gaussian jitter applied to cluster
    // locations when a cluster is split. TODO make tunable
    private static final double SPLIT_SCALE = 1.0;

    // Probability that an item of the cluster being split is moved to the new
    // cluster. TODO make tunable
    private static final double NEW_CLUSTER_PROBABILITY = 0.5;

    // the number of items (dimension of the allocation parameter at construction)
    private final int N;

    // one entry per item; each value is read as the (integer) index of the
    // cluster the item is allocated to
    private final Parameter allocationParameter;

    // cluster locations, addressed by cluster index via getParameter(index)
    // NOTE(review): assumed to provide a location for every index in [0, N) - confirm
    private final MatrixParameter clusterLocations;

    /**
     * @param allocationParameter the cluster allocation of each item
     * @param clusterLocations    the location of each cluster
     * @param weight              the operator weight
     */
    public ClusterSplitMergeOperator(Parameter allocationParameter,
                                     MatrixParameter clusterLocations,
                                     double weight) {
        this.allocationParameter = allocationParameter;
        this.clusterLocations = clusterLocations;
        this.N = allocationParameter.getDimension();

        setWeight(weight);
    }

    /**
     * @return the parameter this operator acts on.
     */
    public Parameter getParameter() {
        return allocationParameter;
    }

    /**
     * @return the Variable this operator acts on.
     */
    public Variable getVariable() {
        return allocationParameter;
    }

    /**
     * Proposes either a split of one occupied cluster into two occupied clusters,
     * or a merge of two distinct occupied clusters into one, and writes the
     * changed allocations and cluster locations back to the parameters.
     *
     * A split is always proposed when only one cluster is occupied, a merge is
     * always proposed when every item sits in its own cluster, and otherwise the
     * choice is made 50:50. When there are fewer than two items neither move is
     * possible and nothing is changed.
     *
     * Allocation values are assumed to be cluster indices in [0, N); a value
     * outside that range results in an ArrayIndexOutOfBoundsException.
     *
     * @return the Hastings ratio. TODO: this is not yet computed; a placeholder
     *         of 0.0 is returned for every proposal.
     */
    public final double doOperation() {
        if (N < 2) {
            // a single item can be neither split nor merged: leave the state untouched
            return 0.0;
        }

        // get a copy of the allocations to work with...
        int[] allocations = new int[allocationParameter.getDimension()];

        // number of items in each cluster, indexed by cluster index
        int[] occupancy = new int[N];

        // the indices of the occupied clusters, in order of first appearance
        int[] occupiedIndices = new int[N];

        // for testing, set these to -1 to force out of bounds exception if
        // used but not set
        java.util.Arrays.fill(occupiedIndices, -1);

        int occupiedCount = 0; // the number of occupied clusters
        for (int i = 0; i < allocations.length; i++) {
            allocations[i] = (int) allocationParameter.getParameterValue(i);
            occupancy[allocations[i]] += 1;
            if (occupancy[allocations[i]] == 1) { // first item in cluster
                occupiedIndices[occupiedCount] = allocations[i];
                occupiedCount++;
            }
        }

        // always split when one cluster is occupied, always merge when all N are,
        // otherwise 50:50
        boolean doSplit = occupiedCount == 1 || (occupiedCount != N && MathUtils.nextBoolean());

        if (doSplit) {
            split(allocations, occupancy, occupiedIndices, occupiedCount);
        } else {
            merge(allocations, occupancy, occupiedIndices, occupiedCount);
        }

        // set the final allocations (only for those that have changed)
        for (int i = 0; i < allocations.length; i++) {
            int current = (int) allocationParameter.getParameterValue(i);
            if (allocations[i] != current) {
                allocationParameter.setParameterValue(i, allocations[i]);
            }
        }

        // todo the Hastings ratio
        return 0.0;
    }

    /**
     * Splits one occupied cluster holding at least two items into two occupied
     * clusters, updating allocations and occupancy in place and moving the two
     * cluster locations apart.
     *
     * Requires occupiedCount < N, which guarantees both that an unoccupied
     * cluster index exists and that some cluster holds at least two items.
     */
    private void split(int[] allocations, int[] occupancy, int[] occupiedIndices, int occupiedCount) {
        int cluster1;
        do {
            // pick an occupied cluster
            cluster1 = occupiedIndices[MathUtils.nextInt(occupiedCount)];

            // For reversibility, the merge step requires that both resulting clusters
            // are occupied, so the cluster being split must hold at least two items
        } while (occupancy[cluster1] < 2);

        // find the first unoccupied cluster
        int cluster2 = 0;
        while (occupancy[cluster2] > 0) {
            cluster2++;
        }

        int oldCount = occupancy[cluster1];

        // redistribute the items at random, resampling until neither cluster is empty
        do {
            occupancy[cluster1] = 0;
            occupancy[cluster2] = 0;

            for (int i = 0; i < allocations.length; i++) {
                if (allocations[i] == cluster1 || allocations[i] == cluster2) {
                    boolean putInNewCluster = MathUtils.nextDouble() < NEW_CLUSTER_PROBABILITY;
                    if (putInNewCluster) {
                        allocations[i] = cluster2;
                        occupancy[cluster2]++;
                    } else {
                        allocations[i] = cluster1;
                        occupancy[cluster1]++;
                    }
                }
            }
        } while (occupancy[cluster1] == 0 || occupancy[cluster2] == 0);

        // set both clusters to a location based on the first cluster with some random jitter...
        Parameter param1 = clusterLocations.getParameter(cluster1);
        Parameter param2 = clusterLocations.getParameter(cluster2);

        double[] loc = param1.getParameterValues();
        for (int dim = 0; dim < param1.getDimension(); dim++) {
            // this draw will be needed when the Hastings ratio is implemented
            double draw = MathUtils.nextGaussian();
            param1.setParameterValue(dim, loc[dim] + (draw * SPLIT_SCALE));
            param2.setParameterValue(dim, loc[dim] - (draw * SPLIT_SCALE)); // Move in opposite direction
        }

        if (DEBUG) {
            System.err.println("Split: " + (oldCount - occupancy[cluster1]) + " items from cluster " + cluster1 + " to create cluster " + cluster2);
        }
    }

    /**
     * Merges two distinct occupied clusters: every item of the second cluster is
     * reallocated to the first, and the first cluster's location is set to the
     * mean of the two original locations. The second cluster's location is left
     * unchanged.
     *
     * Requires occupiedCount >= 2 so that two distinct clusters can be drawn.
     */
    private void merge(int[] allocations, int[] occupancy, int[] occupiedIndices, int occupiedCount) {
        // pick 2 occupied clusters
        int cluster1 = occupiedIndices[MathUtils.nextInt(occupiedCount)];
        int cluster2;
        do {
            cluster2 = occupiedIndices[MathUtils.nextInt(occupiedCount)];
            // resample until cluster1 != cluster2 to maintain reversibility,
            // because split assumes they are different
        } while (cluster1 == cluster2);

        for (int i = 0; i < allocations.length; i++) {
            if (allocations[i] == cluster2) {
                allocations[i] = cluster1;

                // keep occupancy up to date (only reported in the debug output)
                occupancy[cluster1]++;
                occupancy[cluster2]--;
            }
        }

        // set the merged cluster to the mean location of the two original clusters
        Parameter loc1 = clusterLocations.getParameter(cluster1);
        Parameter loc2 = clusterLocations.getParameter(cluster2);

        for (int dim = 0; dim < loc1.getDimension(); dim++) {
            // The reverse (split) step would need to draw
            // (loc1 - average) / SPLIT_SCALE in this dimension; that value will be
            // needed when the Hastings ratio is implemented.
            double average = (loc1.getParameterValue(dim) + loc2.getParameterValue(dim)) / 2.0;
            loc1.setParameterValue(dim, average);
        }

        if (DEBUG) {
            System.err.println("Merge: " + occupancy[cluster1] + " items into cluster " + cluster1 + " from " + cluster2);
        }
    }

    //MCMCOperator INTERFACE

    /**
     * @return the operator name followed by the id of the allocation parameter in parentheses.
     */
    public final String getOperatorName() {
        return CLUSTER_SPLIT_MERGE_OPERATOR + "(" + allocationParameter.getId() + ")";
    }

    /**
     * Not supported: this operator has no tunable parameter.
     *
     * @throws RuntimeException always
     */
    public final void optimize(double targetProb) {
        throw new RuntimeException("This operator cannot be optimized!");
    }

    /**
     * @return false, always; this operator is never optimized.
     */
    public boolean isOptimizing() {
        return false;
    }

    /**
     * Not supported: this operator has no tunable parameter.
     *
     * @throws RuntimeException always
     */
    public void setOptimizing(boolean opt) {
        throw new RuntimeException("This operator cannot be optimized!");
    }

    public double getMinimumAcceptanceLevel() {
        return 0.1;
    }

    public double getMaximumAcceptanceLevel() {
        return 0.4;
    }

    public double getMinimumGoodAcceptanceLevel() {
        return 0.20;
    }

    public double getMaximumGoodAcceptanceLevel() {
        return 0.30;
    }

    /**
     * @return an empty string; no tuning advice is produced for this operator.
     */
    public String getPerformanceSuggestion() {
        return "";
    }

    public String toString() {
        return getOperatorName();
    }

    /**
     * XML parser for this operator. It reads the weight attribute, a child
     * Parameter (the allocations) and a "locations" element wrapping a
     * MatrixParameter (the cluster locations).
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String getParserName() {
            return CLUSTER_SPLIT_MERGE_OPERATOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);

            Parameter allocationParameter = (Parameter) xo.getChild(Parameter.class);

            MatrixParameter locationsParameter = (MatrixParameter) xo.getElementFirstChild("locations");

            return new ClusterSplitMergeOperator(allocationParameter, locationsParameter, weight);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "An operator that splits and merges clusters.";
        }

        public Class getReturnType() {
            return ClusterSplitMergeOperator.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule(MCMCOperator.WEIGHT),
                new ElementRule(Parameter.class),
                new ElementRule("locations", new XMLSyntaxRule[] {
                        new ElementRule(MatrixParameter.class)
                })
        };
    };

    public int getStepCount() {
        return 1;
    }
}