/*
* DirichletProcessGibbsOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.antigenic;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.xml.*;
/**
* A Gibbs operator for allocation of items to clusters under a Dirichlet process.
*
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id: DirichletProcessGibbsOperator.java,v 1.16 2005/06/14 10:40:34 rambaut Exp $
*/
public class DirichletProcessGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {
public final static String DIRICHLET_PROCESS_OPERATOR = "dirichletProcessOperator";
private final int N;
private final int K;
private final Parameter chiParameter;
private final Likelihood modelLikelihood;
public DirichletProcessGibbsOperator(Parameter clusteringParameter,
Parameter chiParameter,
Likelihood modelLikelihood,
double weight) {
this.clusteringParameter = clusteringParameter;
this.N = clusteringParameter.getDimension();
this.chiParameter = chiParameter;
this.modelLikelihood = modelLikelihood;
this.K = this.N; // TODO number of potential clusters should be much less than N
setWeight(weight);
}
/**
* @return the parameter this operator acts on.
*/
public Parameter getParameter() {
return (Parameter) clusteringParameter;
}
/**
* @return the Variable this operator acts on.
*/
public Variable getVariable() {
return clusteringParameter;
}
/**
* change the parameter and return the hastings ratio.
*/
public final double doOperation() {
int index = MathUtils.nextInt(clusteringParameter.getDimension());
int[] occupancy = new int[N];
// construct cluster occupancy vector excluding the selected item and count
// the unoccupied clusters.
int X = K; // X = number of unoccupied clusters
for (int i = 0; i < clusteringParameter.getDimension(); i++) {
int j = (int) clusteringParameter.getParameterValue(i);
if (i != index) {
occupancy[j] += 1;
if (occupancy[j] == 1) { // first item in cluster
X -= 1; // one fewer unoccupied
}
}
}
double chi = chiParameter.getParameterValue(0);
double p1 = chi / ((N - 1 + chi) * X);
double[] P = new double[K];
for (int i = 0; i < K; i++) {
double p;
if (occupancy[i] == 0) {
p = p1;
} else {
p = occupancy[i] / (N - 1 + chi);
}
P[i] = Math.log(p); // Store in log-scale for addition with conditionalLogLikelihood
}
if (modelLikelihood != null) {
for (int k = 0; k < K; ++k) {
clusteringParameter.setParameterValue(index, k);
P[k] += modelLikelihood.getLogLikelihood();;
}
}
this.rescale(P); // Improve numerical stability
this.exp(P); // Transform back to probability-scale
int k = MathUtils.randomChoicePDF(P);
((Parameter) clusteringParameter).setParameterValue(index, k);
return 0.0;
}
private void exp(double[] logX) {
for (int i = 0; i < logX.length; ++i) {
logX[i] = Math.exp(logX[i]);
}
}
private void rescale(double[] logX) {
double max = this.max(logX);
for (int i = 0; i < logX.length; ++i) {
logX[i] -= max;
}
}
private double max(double[] x) {
double max = x[0];
for (double xi : x) {
if (xi > max) {
max = xi;
}
}
return max;
}
//MCMCOperator INTERFACE
public final String getOperatorName() {
return DIRICHLET_PROCESS_OPERATOR+"(" + clusteringParameter.getId() + "|" + chiParameter.getId() + ")";
}
public final void optimize(double targetProb) {
throw new RuntimeException("This operator cannot be optimized!");
}
public boolean isOptimizing() {
return false;
}
public void setOptimizing(boolean opt) {
throw new RuntimeException("This operator cannot be optimized!");
}
public double getMinimumAcceptanceLevel() {
return 0.1;
}
public double getMaximumAcceptanceLevel() {
return 0.4;
}
public double getMinimumGoodAcceptanceLevel() {
return 0.20;
}
public double getMaximumGoodAcceptanceLevel() {
return 0.30;
}
public String getPerformanceSuggestion() {
if (Utils.getAcceptanceProbability(this) < getMinimumAcceptanceLevel()) {
return "";
} else if (Utils.getAcceptanceProbability(this) > getMaximumAcceptanceLevel()) {
return "";
} else {
return "";
}
}
public String toString() {
return DIRICHLET_PROCESS_OPERATOR + "(" + clusteringParameter.getId() + ")";
}
//PRIVATE STUFF
private Parameter clusteringParameter = null;
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public final static String CHI = "chi";
public final static String LIKELIHOOD = "likelihood";
public String getParserName() {
return DIRICHLET_PROCESS_OPERATOR;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);
Parameter clusteringParameter = (Parameter) xo.getChild(Parameter.class);
Parameter chiParameter = (Parameter)xo.getElementFirstChild(CHI);
Likelihood likelihood = null;
if (xo.hasChildNamed(LIKELIHOOD)) {
likelihood = (Likelihood)xo.getElementFirstChild(LIKELIHOOD);
}
return new DirichletProcessGibbsOperator(clusteringParameter, chiParameter, likelihood, weight);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "An operator that picks a new allocation of an item to a cluster under the Dirichlet process.";
}
public Class getReturnType() {
return DirichletProcessGibbsOperator.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
AttributeRule.newDoubleRule(MCMCOperator.WEIGHT),
new ElementRule(CHI, new XMLSyntaxRule[] {
new ElementRule(Parameter.class),
}),
new ElementRule(LIKELIHOOD, new XMLSyntaxRule[] {
new ElementRule(Likelihood.class),
}, true),
new ElementRule(Parameter.class)
};
};
public int getStepCount() {
return 1;
}
}