/* * HamiltonianMonteCarloOperatorParser.java * * Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inferencexml.operators.hmc; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.Parameter; import dr.inference.operators.CoercionMode; import dr.inference.operators.hmc.HamiltonianMonteCarloOperator; import dr.inference.operators.MCMCOperator; import dr.inference.operators.hmc.NoUTurnOperator; import dr.util.Transform; import dr.xml.*; import java.util.List; /** * @author Max Tolkoff * @author Marc A. Suchard */ public class HamiltonianMonteCarloOperatorParser extends AbstractXMLObjectParser { public final static String HMC_OPERATOR = "HamiltonianMonteCarloOperator"; public static final String HMC_OPERATOR2 = "hamiltonianMonteCarloOperator"; public final static String N_STEPS = "nSteps"; public final static String STEP_SIZE = "stepSize"; public final static String DRAW_VARIANCE = "drawVariance"; public static final String MODE = "mode"; @Override public String getParserName() { return HMC_OPERATOR; } @Override public String[] getParserNames() { return new String[] { HMC_OPERATOR, HMC_OPERATOR2 }; } @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT); int nSteps = xo.getIntegerAttribute(N_STEPS); double stepSize = xo.getDoubleAttribute(STEP_SIZE); double drawVariance = xo.getDoubleAttribute(DRAW_VARIANCE); int mode = xo.getAttribute(MODE, 0); GradientWrtParameterProvider derivative = (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class); Parameter parameter = (Parameter) xo.getChild(Parameter.class); Transform transform = null; if (parameter == null) { Transform.Collection collection = (Transform.Collection) xo.getChild(Transform.Collection.class); parameter = collection.getParameter(); transform = collection; } if (derivative.getDimension() != parameter.getDimension()) { throw new XMLParseException("Gradient (" + derivative.getDimension() + ") must be the same dimensions as the parameter (" + parameter.getDimension() + ")"); } if (mode == 0) { return new HamiltonianMonteCarloOperator(CoercionMode.DEFAULT, weight, derivative, parameter, transform, stepSize, nSteps, drawVariance); } else { return new NoUTurnOperator(CoercionMode.DEFAULT, weight, derivative, parameter, stepSize, nSteps, drawVariance); } } @Override public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(MCMCOperator.WEIGHT), AttributeRule.newIntegerRule(N_STEPS), AttributeRule.newDoubleRule(STEP_SIZE), AttributeRule.newDoubleRule(DRAW_VARIANCE), AttributeRule.newIntegerRule(MODE, true), new XORRule( new ElementRule(Parameter.class), new ElementRule(Transform.Collection.class) ), new ElementRule(GradientWrtParameterProvider.class), }; @Override public String getParserDescription() { return "Returns a Hamiltonian Monte Carlo transition kernel"; } @Override public Class getReturnType() { return HamiltonianMonteCarloOperator.class; } }