/* * MomentDistributionModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.distribution; import dr.inference.model.*; import dr.inference.model.Parameter; import dr.inferencexml.distribution.MomentDistributionModelParser; import dr.math.distributions.RandomGenerator; //@author Max Tolkoff public class MomentDistributionModel extends AbstractModelLikelihood implements ParametricMultivariateDistributionModel, RandomGenerator { public MomentDistributionModel(Parameter mean, Parameter precision, Parameter cutoff, Parameter data) { super(MomentDistributionModelParser.MOMENT_DISTRIBUTION_MODEL); this.mean=mean; this.precision=precision; // this.mean = new DuplicatedParameter(mean); // this.mean.addDuplicationParameter(new Parameter.Default(cutoff.getDimension())); // DuplicatedParameter precTemp= new DuplicatedParameter(precision); // precTemp.addDuplicationParameter(new Parameter.Default(cutoff.getDimension())); // this.precision=new DiagonalMatrix(precTemp); addVariable(mean); mean.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 1)); addVariable(precision); // precision.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); this.cutoff=cutoff; if(cutoff!=null){ addVariable(cutoff); int boundsDim; if(cutoff.getBounds()!=null) boundsDim = cutoff.getBounds().getBoundsDimension(); else boundsDim = cutoff.getDimension(); cutoff.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, boundsDim));} addVariable(data); this.data=data; untruncated=new NormalDistributionModel(mean, precision, true); sumKnown=false; untruncatedKnown=false; } private final Parameter mean; private final Parameter precision; // private final DuplicatedParameter mean; // private final DiagonalMatrix precision; private final Parameter cutoff; private NormalDistributionModel untruncated; private double sum; private boolean sumKnown; private boolean storedSumKnown; private double storedSum; private boolean untruncatedKnown; private boolean storedUntruncatedKnown; private NormalDistributionModel storedUntruncated; private Parameter data; public double logPdf(Parameter data) { // untruncatedKnown=false; // sumKnown=false; checkDistribution(); if(sumKnown) return sum; else { sum=0; } if(cutoff!=null){ if(data.getDimension()!=cutoff.getDimension()){ throw new RuntimeException("Incorrect number of cutoffs"); } for (int i = 0; i <data.getDimension() ; i++) { if (Math.sqrt(cutoff.getParameterValue(i)) - .05 > Math.abs(data.getParameterValue(i)) && data.getParameterValue(i)!=0){ return Double.NEGATIVE_INFINITY; } else if(data.getParameterValue(i)==0) sum+=-1000-Math.log(precision.getParameterValue(0)); else sum+=untruncated.logPdf(data.getParameterValue(i));//(2*untruncated.logPdf(cutoff.getParameterValue(i))); } } else{ for (int i = 0; i <data.getDimension() ; i++) { sum+= untruncated.logPdf(data.getParameterValue(i)) + 2 * StrictMath.log(Math.abs(data.getParameterValue(i))) + StrictMath.log(precision.getParameterValue(0)); } } sumKnown=true; return sum; } @Override public double logPdf(double[] x) { return 0; } public Parameter getCutoff(){return cutoff;} @Override public double[][] getScaleMatrix() { double[][] temp=new double[1][1]; temp[0][0]=precision.getParameterValue(0); return temp; // return precision.getParameterAsMatrix(); } @Override public double[] getMean() { return mean.getParameterValues(); } @Override public String getType() { return "Moment Distribution Model"; } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { sumKnown=false; if(variable==mean || variable==precision) {untruncatedKnown=false;} } @Override protected void storeState() { storedSumKnown=sumKnown; storedSum=sum; storedUntruncated=untruncated; storedUntruncatedKnown=untruncatedKnown; } @Override protected void restoreState() { sumKnown=storedSumKnown; sum=storedSum; untruncated=storedUntruncated; untruncatedKnown=storedUntruncatedKnown; } @Override protected void acceptState() { } private NormalDistributionModel createNewDistribution() { return new NormalDistributionModel(mean, precision, true); // return new NormalDistributionModel(new Parameter.Default(mean.getParameterValue(0)), new Parameter.Default(precision.getParameterValue(0)), true); } private void checkDistribution() { if (!untruncatedKnown) { untruncated = createNewDistribution(); untruncatedKnown = true; } } @Override public double[] nextRandom() { return new double[0]; } @Override public double logPdf(Object x) { if(x instanceof Parameter) return logPdf((Parameter) x); else return 0; } @Override public Model getModel() { return this; } @Override public double getLogLikelihood() { return logPdf(data); } @Override public void makeDirty() { sumKnown=false; untruncatedKnown=false; } // ***************************************************************** // Interface DensityModel // ***************************************************************** @Override public Variable<Double> getLocationVariable() { throw new UnsupportedOperationException("Not implemented"); } }