/* * RandomWalkIntegerOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.operators; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.inferencexml.operators.RandomWalkIntegerOperatorParser; import dr.math.MathUtils; /** * A generic random walk operator for use with a multi-dimensional Integer parameters. * * @author Michael Defoin Platel * @version $Id: RandomWalkIntegerOperator.java$ */ public class RandomWalkIntegerOperator extends SimpleMCMCOperator { public RandomWalkIntegerOperator(Variable parameter, int windowSize, double weight) { this.parameter = parameter; this.windowSize = windowSize; setWeight(weight); } /** * @return the parameter this operator acts on. */ public Parameter getParameter() { return (Parameter) parameter; } public final int getWindowSize() { return windowSize; } /** * change the parameter and return the hastings ratio. */ double logq; public double doOperation() { logq = 0.0; // a random dimension to perturb int index = MathUtils.nextInt(parameter.getSize()); // use getSize(), which = getDimension() if (parameter instanceof Parameter) { int newValue = calculateNewValue(index); ((Parameter) parameter).setParameterValue(index, newValue); //System.out.println("newValue: "+newValue); } else if (parameter instanceof Variable) { // todo this code is improper if we are going to use Variable<Double> int newValue = calculateNewValue(index); ((Variable<Integer>) parameter).setValue(index, newValue); } return logq; } protected int calculateNewValue(int index) { // a random non zero integer around old value within windowSize * 2 int oldValue; int upper; int lower; if (parameter instanceof Parameter) { oldValue = (int) ((Parameter) parameter).getParameterValue(index); upper = (int) (double) ((Parameter) parameter).getBounds().getUpperLimit(index); lower = (int) (double) ((Parameter) parameter).getBounds().getLowerLimit(index); } else if (parameter instanceof Variable) { // todo this code is improper if we are going to use Variable<Double> oldValue = ((Variable<Integer>) parameter).getValue(index); upper = ((Variable<Integer>) parameter).getBounds().getUpperLimit(index); lower = ((Variable<Integer>) parameter).getBounds().getLowerLimit(index); } else { throw new RuntimeException("The parameter (" + parameter.getId() + ") uses invalid class!"); } if (upper == lower) return upper; int maxWindowSize = upper - lower; if(windowSize> maxWindowSize){ windowSize = maxWindowSize; System.err.println("The maximum window size should be smaller than the total number of possible integer values."); } int newValue; int roll = MathUtils.nextInt(2 * windowSize); // windowSize="1"; roll = {0, 1} if (roll >= windowSize) { // roll = 1 //roll - window is the positive step size int step = 1 + (roll - windowSize); newValue = oldValue + step; if (newValue > upper){ newValue = 2 * upper - newValue; //reflect down } } else { // roll = 0 newValue = oldValue - 1 - roll; if (newValue < lower){ newValue = 2 * lower - newValue; //reflect up } } //New and seemingly correct (accoding to the running MCMC with uniform prior) //calculation of the hastings ratio --CHW int newToOldCount = 0; int oldToNewCount = 0; if(newValue != oldValue){ oldToNewCount = oldToNewCount +1; newToOldCount = newToOldCount +1; } int temp = oldValue + windowSize; if(temp > upper){ if((2*upper - temp) <= newValue && newValue != upper){ oldToNewCount = oldToNewCount+1; } } temp = oldValue - windowSize; if(temp < lower){ if((2*lower - temp) >= newValue && newValue != lower){ oldToNewCount = oldToNewCount+1; } } temp = newValue + windowSize; if( temp > upper){ if((2*upper - temp) <= oldValue && oldValue != upper){ newToOldCount = newToOldCount+1; } } temp = newValue - windowSize; if( temp < lower){ if((2*lower - temp) >= oldValue && oldValue != lower){ newToOldCount = newToOldCount+1; } } logq = Math.log(newToOldCount)- Math.log(oldToNewCount); return newValue; } //MCMCOperator INTERFACE public String getOperatorName() { return "randomWalkInteger(" + parameter.getId() + ")"; } public double getTargetAcceptanceProbability() { return 0.234; } public double getMinimumAcceptanceLevel() { return 0.1; } public double getMaximumAcceptanceLevel() { return 0.4; } public double getMinimumGoodAcceptanceLevel() { return 0.20; } public double getMaximumGoodAcceptanceLevel() { return 0.30; } public final String getPerformanceSuggestion() { double prob = Utils.getAcceptanceProbability(this); double targetProb = getTargetAcceptanceProbability(); double maxDelta = 0; if (parameter instanceof Parameter) { maxDelta = ((Parameter) parameter).getParameterValue(0) * 2.0; } else if (parameter instanceof Variable) { maxDelta = ((Variable<Integer>) parameter).getValue(0) * 2.0; } long ws = Math.round(OperatorUtils.optimizeWindowSize(windowSize, maxDelta * 2.0, prob, targetProb)); if (prob < getMinimumGoodAcceptanceLevel()) { if(ws <= 1){ return ""; } return "Try decreasing windowSize to about " + ws; } else if (prob > getMaximumGoodAcceptanceLevel()) { return "Try increasing windowSize to about " + ws; } else return ""; } public String toString() { return RandomWalkIntegerOperatorParser.RANDOM_WALK_INTEGER_OPERATOR + "(" + parameter.getId() + ", " + windowSize + ", " + getWeight() + ")"; } //PRIVATE STUFF protected Variable parameter = null; protected int windowSize = 1; }