/* * CompoundGradient.java * * Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.hmc; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.CompoundLikelihood; import dr.inference.model.CompoundParameter; import dr.inference.model.Likelihood; import dr.inference.model.Parameter; import java.util.ArrayList; import java.util.List; /** * @author Max Tolkoff * @author Marc A. Suchard */ public class CompoundGradient implements GradientWrtParameterProvider { private final int dimension; private final List<GradientWrtParameterProvider> derivativeList; private final Likelihood likelihood; private final Parameter parameter; public CompoundGradient(List<GradientWrtParameterProvider> derivativeList) { this.derivativeList = derivativeList; if (derivativeList.size() == 1) { likelihood = derivativeList.get(0).getLikelihood(); parameter = derivativeList.get(0).getParameter(); dimension = parameter.getDimension(); } else { List<Likelihood> likelihoodList = new ArrayList<Likelihood>(); CompoundParameter compoundParameter = new CompoundParameter("hmc"); int dim = 0; for (GradientWrtParameterProvider grad : derivativeList) { likelihoodList.add(grad.getLikelihood()); Parameter p = grad.getParameter(); compoundParameter.addParameter(p); dim += p.getDimension(); } likelihood = new CompoundLikelihood(likelihoodList); parameter = compoundParameter; dimension = dim; } } @Override public Likelihood getLikelihood() { return likelihood; } @Override public Parameter getParameter() { return parameter; } @Override public int getDimension() { return dimension; } // @Override // public void getGradientLogDensity(final double[] destination, final int offset) { // double[] grad = getGradientLogDensity(); // System.arraycopy(grad, 0, destination, offset, grad.length); // } @Override public double[] getGradientLogDensity() { double[] result = new double[dimension]; int offset = 0; for (GradientWrtParameterProvider grad : derivativeList) { double[] tmp = grad.getGradientLogDensity(); System.arraycopy(tmp, 0, result, offset, grad.getDimension()); offset += grad.getDimension(); } return result; } }