/* * NewHamiltonianMonteCarloOperator.java * * Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.operators.hmc; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.Likelihood; import dr.inference.model.Parameter; import dr.inference.operators.AbstractCoercableOperator; import dr.inference.operators.CoercionMode; import dr.inference.operators.GeneralOperator; import dr.math.distributions.NormalDistribution; import dr.util.Transform; /** * @author Max Tolkoff * @author Marc A. Suchard */ public class HamiltonianMonteCarloOperator extends AbstractCoercableOperator { private final GradientWrtParameterProvider gradientProvider; private double stepSize; private final int nSteps; private final NormalDistribution drawDistribution; private final LeapFrogEngine leafFropEngine; public HamiltonianMonteCarloOperator(CoercionMode mode, double weight, GradientWrtParameterProvider gradientProvider, Parameter parameter, Transform transform, double stepSize, int nSteps, double drawVariance) { super(mode); setWeight(weight); setTargetAcceptanceProbability(0.8); // Stan default this.gradientProvider = gradientProvider; this.stepSize = stepSize; this.nSteps = nSteps; this.drawDistribution = new NormalDistribution(0, Math.sqrt(drawVariance)); this.leafFropEngine = (transform != null ? new LeapFrogEngine.WithTransform(parameter, transform) : new LeapFrogEngine.Default(parameter)); } @Override public String getPerformanceSuggestion() { return null; } @Override public String getOperatorName() { return "Vanilla HMC operator"; } private static double getScaledDotProduct(final double[] momentum, final double sigmaSquared) { final int dim = momentum.length; double total = 0.0; for (int i = 0; i < dim; i++) { total += momentum[i] * momentum[i]; } return total / (2 * sigmaSquared); } private static double[] drawInitialMomentum(final NormalDistribution distribution, final int dim) { double[] momentum = new double[dim]; for (int i = 0; i < dim; i++) { momentum[i] = (Double) distribution.nextRandom(); } return momentum; } @Override public double doOperation() { return leafFrog(); } protected double leafFrog() { final int dim = gradientProvider.getDimension(); final double sigmaSquared = drawDistribution.getSD() * drawDistribution.getSD(); double[] momentum = drawInitialMomentum(drawDistribution, dim); double[] position = leafFropEngine.getInitialPosition(); final double prop = getScaledDotProduct(momentum, sigmaSquared) + leafFropEngine.getParameterLogJacobian(); leafFropEngine.updateMomentum(position, momentum, gradientProvider.getGradientLogDensity(), stepSize / 2); for (int i = 0; i < nSteps; i++) { // Leap-frog leafFropEngine.updatePosition(position, momentum, stepSize, sigmaSquared); if (i < (nSteps - 1)) { leafFropEngine.updateMomentum(position, momentum, gradientProvider.getGradientLogDensity(), stepSize); } } // end of loop over steps leafFropEngine.updateMomentum(position, momentum, gradientProvider.getGradientLogDensity(), stepSize / 2); final double res = getScaledDotProduct(momentum, sigmaSquared) + leafFropEngine.getParameterLogJacobian(); return prop - res; } @Override public double getCoercableParameter() { return Math.log(stepSize); } @Override public void setCoercableParameter(double value) { stepSize = Math.exp(value); } @Override public double getRawParameter() { return stepSize; } interface LeapFrogEngine { double[] getInitialPosition(); double getParameterLogJacobian(); void updateMomentum(final double[] position, final double[] momentum, final double[] gradient, final double functionalStepSize); void updatePosition(final double[] position, final double[] momentum, final double functionalStepSize, final double sigmaSquared); class Default implements LeapFrogEngine { final protected Parameter parameter; protected Default(Parameter parameter) { this.parameter = parameter; } @Override public double[] getInitialPosition() { return parameter.getParameterValues(); } @Override public double getParameterLogJacobian() { return 0; } @Override public void updateMomentum(double[] position, double[] momentum, double[] gradient, double functionalStepSize) { final int dim = momentum.length; for (int i = 0; i < dim; ++i) { momentum[i] += functionalStepSize * gradient[i]; } } @Override public void updatePosition(double[] position, double[] momentum, double functionalStepSize, double sigmaSquared) { final int dim = momentum.length; for (int j = 0; j < dim; j++) { position[j] += functionalStepSize * momentum[j] / sigmaSquared; } setParameter(position); // Write back into BEAST model } protected void setParameter(double[] position) { final int dim = position.length; for (int j = 0; j < dim; ++j) { parameter.setParameterValueQuietly(j, position[j]); } parameter.fireParameterChangedEvent(); } } class WithTransform extends Default { final private Transform transform; double[] unTransformedPosition; protected WithTransform(Parameter parameter, Transform transform) { super(parameter); this.transform = transform; } @Override public double getParameterLogJacobian() { return transform.getLogJacobian(unTransformedPosition, 0, unTransformedPosition.length); } @Override public double[] getInitialPosition() { unTransformedPosition = super.getInitialPosition(); return transform.transform(unTransformedPosition, 0, unTransformedPosition.length); } @Override public void updateMomentum(double[] position, double[] momentum, double[] gradient, double functionalStepSize) { gradient = transform.updateGradientLogDensity(gradient, unTransformedPosition, 0, unTransformedPosition.length); super.updateMomentum(position, momentum, gradient, functionalStepSize); } @Override protected void setParameter(double[] position) { unTransformedPosition = transform.inverse(position, 0, position.length); super.setParameter(unTransformedPosition); } } } }