/* * SimpleOperatorSchedule.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.operators; import dr.inference.loggers.LogColumn; import dr.inference.loggers.Loggable; import dr.inference.loggers.NumberColumn; import dr.math.MathUtils; import java.util.ArrayList; import java.util.List; import java.util.Vector; import java.util.logging.Logger; /** * This class implements a simple operator schedule. * * @author Alexei Drummond * @version $Id: SimpleOperatorSchedule.java,v 1.5 2005/06/14 10:40:34 rambaut Exp $ */ public class SimpleOperatorSchedule implements OperatorSchedule, Loggable { private final List<MCMCOperator> operators = new ArrayList<MCMCOperator>(); private final List<Integer> availableOperators = new ArrayList<Integer>(); private double totalWeight = 0; private int current = 0; private boolean sequential = false; private OptimizationTransform optimizationSchedule = OptimizationTransform.DEFAULT; int operatorUseThreshold = Integer.MAX_VALUE; // operator use threshold over which an operator may get turned off if ... double operatorAcceptanceThreshold = 0.0; // acceptance rate threshold under which an operator gets turned off public SimpleOperatorSchedule() { } public SimpleOperatorSchedule(int operatorUseThreshold, double operatorAcceptanceThreshold) { this.operatorUseThreshold = operatorUseThreshold; this.operatorAcceptanceThreshold = operatorAcceptanceThreshold; } public void addOperators(List<MCMCOperator> operators) { for (MCMCOperator operator : operators) { this.operators.add(operator); this.availableOperators.add(this.operators.size() - 1); } totalWeight = calculateTotalWeight(); } public void operatorsHasBeenUpdated() { totalWeight = calculateTotalWeight(); } public void addOperator(MCMCOperator op) { operators.add(op); availableOperators.add(operators.size() - 1); totalWeight = calculateTotalWeight(); } private double getWeight(int index) { return operators.get(availableOperators.get(index)).getWeight(); } private double calculateTotalWeight() { double totalWeight = 0.0; for (int i : availableOperators) { totalWeight += operators.get(i).getWeight(); } return totalWeight; } public int getNextOperatorIndex() { checkOperatorAcceptanceRates(); if (sequential) { int index = getWeightedOperatorIndex(current); current += 1; if (current >= totalWeight) { current = 0; } return index; } final double v = MathUtils.nextDouble(); //System.err.println("v=" + v); return getWeightedOperatorIndex(v * totalWeight); } public void setSequential(boolean sequential) { this.sequential = sequential; } private int getWeightedOperatorIndex(double q) { int index = 0; double weight = getWeight(index); while (weight <= q) { index += 1; weight += getWeight(index); } return index; } public MCMCOperator getOperator(int index) { return operators.get(availableOperators.get(index)); } public int getOperatorCount() { return availableOperators.size(); } private void checkOperatorAcceptanceRates() { List<Integer> toRemove = new ArrayList<Integer>(); for (int i : availableOperators) { MCMCOperator op = operators.get(i); if (op.getCount() > operatorUseThreshold) { double acceptanceRate = ((double)op.getAcceptCount()) / op.getCount(); if (acceptanceRate < operatorAcceptanceThreshold) { toRemove.add(i); Logger.getLogger("dr.app.beast").info("Operator " + op.getOperatorName() + " turned off with an acceptance rate of " + acceptanceRate + ", after " + op.getCount() + " tries."); } } } if (!toRemove.isEmpty()) { availableOperators.removeAll(toRemove); totalWeight = calculateTotalWeight(); } } public double getOptimizationTransform(double d) { switch( optimizationSchedule ) { case DEFAULT: case LOG: return Math.log(d); case SQRT: return Math.sqrt(d); case LINEAR: return d; default: throw new UnsupportedOperationException("Unknown enum value"); } } public void setOptimizationSchedule(OptimizationTransform optimizationSchedule) { this.optimizationSchedule = optimizationSchedule; } public long getMinimumAcceptAndRejectCount() { long minCount = Long.MAX_VALUE; for( MCMCOperator op : operators ) { if( op.getAcceptCount() < minCount || op.getRejectCount() < minCount ) { minCount = op.getCount(); } } return minCount; } // ************************************************************** // Loggable IMPLEMENTATION // ************************************************************** /** * @return the log columns. */ public LogColumn[] getColumns() { List<LogColumn> columnList = new ArrayList<LogColumn>(); for (int i = 0; i < getOperatorCount(); i++) { MCMCOperator op = getOperator(i); columnList.add(new OperatorAcceptanceColumn(op.getOperatorName(), op)); if (op instanceof CoercableMCMCOperator) { columnList.add(new OperatorSizeColumn(op.getOperatorName() + "_size", (CoercableMCMCOperator)op)); } } LogColumn[] columns = columnList.toArray(new LogColumn[columnList.size()]); return columns; } private class OperatorAcceptanceColumn extends NumberColumn { private final MCMCOperator op; public OperatorAcceptanceColumn(String label, MCMCOperator op) { super(label); this.op = op; } public double getDoubleValue() { return MCMCOperator.Utils.getAcceptanceProbability(op); } } private class OperatorSizeColumn extends NumberColumn { private final CoercableMCMCOperator op; public OperatorSizeColumn(String label, CoercableMCMCOperator op) { super(label); this.op = op; } public double getDoubleValue() { return op.getRawParameter(); } } }