/* * DirichletProcessOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.branchmodel.lineagespecific; import org.apache.commons.math.MathException; import dr.inference.model.CompoundLikelihood; import dr.inference.model.Likelihood; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.inference.operators.GibbsOperator; import dr.inference.operators.SimpleMCMCOperator; import dr.math.MathUtils; @SuppressWarnings("serial") public class DirichletProcessOperator extends SimpleMCMCOperator implements GibbsOperator { private static final boolean DEBUG = false; private DirichletProcessPrior dpp; private int realizationCount; private int uniqueRealizationCount; private double intensity; private int mhSteps; private Parameter categoriesParameter; private CountableRealizationsParameter allParameters; private Parameter uniqueParameters; private CompoundLikelihood likelihood; public DirichletProcessOperator(DirichletProcessPrior dpp, // Parameter categoriesParameter, // Parameter uniqueParameters, // CountableRealizationsParameter allParameters, Likelihood likelihood, // int mhSteps, // double weight// ) { this.dpp = dpp; this.intensity = dpp.getGamma(); this.uniqueRealizationCount = dpp.getCategoryCount(); this.realizationCount = categoriesParameter.getDimension(); this.categoriesParameter = categoriesParameter; this.allParameters = allParameters; this.uniqueParameters = uniqueParameters; this.likelihood = (CompoundLikelihood) likelihood; // this.likelihood = likelihood; this.mhSteps = mhSteps; setWeight(weight); }// END: Constructor public Parameter getParameter() { return categoriesParameter; }// END: getParameter public Variable getVariable() { return categoriesParameter; }// END: getVariable @Override public double doOperation() { try { // doOperate(); doOp(); } catch (MathException e) { e.printStackTrace(); }// END: try-catch block return 0.0; }// END: doOperation private void doOp() throws MathException { for (int index = 0; index < realizationCount; index++) { int[] occupancy = new int[uniqueRealizationCount]; for (int i = 0; i < realizationCount; i++) { if (i != index) { int j = (int) categoriesParameter.getParameterValue(i); occupancy[j]++; }// END: i check }// END: i loop double[] existingValues = new double[uniqueRealizationCount]; int counter = 0; int singletonIndex = -1; for(int i = 0; i < uniqueRealizationCount;i++){ if(occupancy[i] > 0) { occupancy[counter] = occupancy[i]; existingValues[counter++] = dpp.getUniqueParameter(i) .getParameterValue(0); } else { singletonIndex = i; }//END: occupancy check }//END: i loop // Propose new value(s) double[] baseProposals = new double[realizationCount]; for (int i = 0; i < baseProposals.length; i++) { baseProposals[i] = dpp.baseModel.nextRandom()[0]; } // If a singleton if(singletonIndex > -1) { baseProposals[0] = uniqueParameters.getParameterValue(singletonIndex); } double[] logClusterProbs = new double[uniqueRealizationCount]; // draw existing int i; for(i = 0; i < counter; i++) { logClusterProbs[i] = Math.log(occupancy[i] / (realizationCount - 1 + intensity)); double value = allParameters.getParameterValue(index); double candidate = existingValues[i]; allParameters.setParameterValue(index, candidate); likelihood.makeDirty(); logClusterProbs[i] = logClusterProbs[i] + likelihood.getLikelihood(index) .getLogLikelihood(); // logClusterProbs[i] = logClusterProbs[i] + likelihood .getLogLikelihood(); // System.out.println(likelihood.getLikelihood(index) .getLogLikelihood() + " " + likelihood .getLogLikelihood()); allParameters.setParameterValue(index, value); likelihood.makeDirty(); } // draw new for(; i < logClusterProbs.length; i++){ logClusterProbs[i] = Math.log((intensity) / (realizationCount - 1 + intensity)); // logClusterProbs[i] = Math.log(intensity / uniqueRealizationCount / (realizationCount - 1 + intensity)); double value = allParameters.getParameterValue(index); double candidate = baseProposals[i - counter]; allParameters.setParameterValue(index, candidate); likelihood.makeDirty(); logClusterProbs[i] = logClusterProbs[i] + likelihood.getLikelihood(index).getLogLikelihood(); // logClusterProbs[i] = logClusterProbs[i] + likelihood.getLogLikelihood(); // System.out.println(likelihood.getLikelihood(index) .getLogLikelihood() + " " + likelihood .getLogLikelihood()); allParameters.setParameterValue(index, value); likelihood.makeDirty(); } double smallestVal = logClusterProbs[0]; for(i = 1; i < uniqueRealizationCount; i++){ if(smallestVal > logClusterProbs[i]) { smallestVal = logClusterProbs[i]; } } double[] clusterProbs = new double[uniqueRealizationCount]; for(i = 0; i < clusterProbs.length;i++) { clusterProbs[i] = Math.exp(logClusterProbs[i]-smallestVal); } // dr.app.bss.Utils.printArray(clusterProbs); // System.exit(-1); // sample int sampledCluster = MathUtils.randomChoicePDF(clusterProbs); categoriesParameter.setParameterValue(index, sampledCluster); }//END: index loop }//END: doOp private void doOperate() throws MathException { // int index = 0; for (int index = 0; index < realizationCount; index++) { int[] occupancy = new int[uniqueRealizationCount]; for (int i = 0; i < realizationCount; i++) { if (i != index) { int j = (int) categoriesParameter.getParameterValue(i); occupancy[j]++; }// END: i check }// END: i loop if (DEBUG) { System.out.println("N[-index]: "); dr.app.bss.Utils.printArray(occupancy); } Likelihood clusterLikelihood = (Likelihood) likelihood.getLikelihood(index); // Likelihood clusterLikelihood = likelihood; int category = (int) categoriesParameter.getParameterValue(index); double value = uniqueParameters.getParameterValue(category); double[] clusterProbs = new double[uniqueRealizationCount]; for (int i = 0; i < uniqueRealizationCount; i++) { double logprob = 0; if (occupancy[i] == 0) {// draw new // draw from base model, evaluate at likelihood double candidate = dpp.baseModel.nextRandom()[0]; uniqueParameters.setParameterValue(category, candidate); double loglike = clusterLikelihood.getLogLikelihood(); uniqueParameters.setParameterValue(category, value); logprob = Math.log((intensity) / (realizationCount - 1 + intensity)) + loglike; } else {// draw existing // likelihood for component x_index double candidate = dpp.getUniqueParameter(i) .getParameterValue(0); uniqueParameters.setParameterValue(category, candidate); double loglike = clusterLikelihood.getLogLikelihood(); uniqueParameters.setParameterValue(category, value); logprob = Math.log(occupancy[i]) / (realizationCount - 1 + intensity) + loglike; }// END: occupancy check clusterProbs[i] = logprob; }// END: i loop dr.app.bss.Utils.exponentiate(clusterProbs); if (DEBUG) { System.out.println("P(z[index] | z[-index]): "); dr.app.bss.Utils.printArray(clusterProbs); } // sample int sampledCluster = MathUtils.randomChoicePDF(clusterProbs); categoriesParameter.setParameterValue(index, sampledCluster); if (DEBUG) { System.out .println("sampled category: " + sampledCluster + "\n"); } }// END: index loop }// END: doOperate @Override public String getOperatorName() { return DirichletProcessOperatorParser.DIRICHLET_PROCESS_OPERATOR; } @Override public String getPerformanceSuggestion() { return null; }// END: getPerformanceSuggestion @Override public int getStepCount() { return realizationCount; }// END: getStepCount }// END: class