/*
 * MultivariateNormalGibbsOperator.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.inference.operators;

import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.MultivariateNormalDistributionModel;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Attribute;

import java.util.List;

/**
 * Gibbs operator that redraws the mean parameter of a multivariate normal likelihood.
 * <p>
 * With {@code n} data vectors {@code x_1..x_n} attached to the likelihood, likelihood precision
 * {@code P}, prior mean {@code m0} and prior precision {@code P0}, each move draws a new mean from
 * a multivariate normal with
 * <pre>
 *   precision = P0 + n * P
 *   mean      = precision^-1 * (P * sum(x_i) + P0 * m0)
 * </pre>
 * The prior's {@code getScaleMatrix()} is used as {@code P0}
 * (NOTE(review): confirm that the scale matrix of the prior is a precision, not a variance).
 *
 * @author Max Tolkoff
 */
public class MultivariateNormalGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {

    public static final String MVN_GIBBS = "multivariateNormalGibbsOperator";

    /** Prior precision, captured once at construction time. */
    private final Matrix priorPrecision;
    /** Prior mean, captured once at construction time. */
    private final Vector priorMean;
    /** Precision parameter of the likelihood; read afresh on every move. */
    private final MatrixParameter likelihoodPrecision;
    /** Mean parameter of the likelihood; this is the parameter the operator overwrites. */
    private final Parameter likelihoodMean;
    /** Likelihood whose data list supplies the observations. */
    private final MultivariateDistributionLikelihood likelihood;
    /** Number of values in the likelihood mean at construction time. */
    private final int dim;

    /**
     * Builds the operator from a likelihood and a prior.
     *
     * @param likelihood likelihood whose distribution must be a
     *                   {@link MultivariateNormalDistributionModel}; its mean parameter is the one sampled
     * @param prior      prior whose distribution must be a {@link MultivariateNormalDistribution}
     * @param weight     operator weight, forwarded to {@code setWeight}; must not be {@code null}
     * @throws IllegalDimension declared for callers; kept for interface compatibility
     * @throws ClassCastException if either distribution is not of the required type
     */
    public MultivariateNormalGibbsOperator(MultivariateDistributionLikelihood likelihood,
                                           MultivariateDistributionLikelihood prior,
                                           Double weight) throws IllegalDimension {
        MultivariateNormalDistribution tempPrior =
                (MultivariateNormalDistribution) prior.getDistribution();
        this.priorMean = new Vector(tempPrior.getMean());
        this.priorPrecision = new Matrix(tempPrior.getScaleMatrix());

        MultivariateNormalDistributionModel tempLikelihood =
                (MultivariateNormalDistributionModel) likelihood.getDistribution();
        this.likelihoodMean = tempLikelihood.getMeanParameter();
        this.likelihoodPrecision = tempLikelihood.getPrecisionMatrixParameter();
        this.likelihood = likelihood;
        this.dim = likelihoodMean.getValues().length;

        setWeight(weight);
    }

    /**
     * Overwrites every entry of {@code set} with {@code value}, resizing the parameter to match,
     * and fires a single change event once all entries have been written.
     */
    private void setParameterValue(Parameter set, double[] value) {
        set.setDimension(value.length);
        for (int i = 0; i < value.length; i++) {
            set.setParameterValueQuietly(i, value[i]);
        }
        set.fireParameterChangedEvent();
    }

    /**
     * Sums the data vectors attached to the likelihood, component by component.
     *
     * @return array of length {@code dim} holding the per-component sums
     */
    private double[] getMeanSum() {
        double[] answer = new double[dim];
        List<Attribute<double[]>> dataList = likelihood.getDataList();
        for (Attribute<double[]> d : dataList) {
            double[] observation = d.getAttributeValue();
            for (int i = 0; i < observation.length; i++) {
                answer[i] += observation[i];
            }
        }
        return answer;
    }

    /**
     * Computes the precision of the sampling distribution: prior precision plus the current
     * likelihood precision multiplied by the number of data vectors.
     */
    private Matrix getPrecision() throws IllegalDimension {
        Matrix currentPrecision = new Matrix(likelihoodPrecision.getParameterAsMatrix());
        currentPrecision = currentPrecision.product(likelihood.getDataList().size());
        return priorPrecision.add(currentPrecision);
    }

    /**
     * Computes the mean of the sampling distribution: the inverse of {@link #getPrecision()}
     * applied to (likelihood precision * data sum + prior precision * prior mean).
     */
    private Vector getMean() throws IllegalDimension {
        Vector meanSum = new Vector(getMeanSum());
        Matrix workingPrecision = new Matrix(likelihoodPrecision.getParameterAsMatrix());
        Vector meanPart = workingPrecision.product(meanSum);
        meanPart = meanPart.add(priorPrecision.product(priorMean));
        Matrix varPart = getPrecision().inverse();
        return varPart.product(meanPart);
    }

    @Override
    public String getOperatorName() {
        return MVN_GIBBS;
    }

    @Override
    public String getPerformanceSuggestion() {
        return null;
    }

    /**
     * Draws a new mean vector and writes it into the likelihood's mean parameter.
     *
     * @return always {@code 0} (presumably the log Hastings ratio of the move; confirm against
     *         the {@code SimpleMCMCOperator} contract)
     * @throws IllegalStateException if the matrices and vectors involved have incompatible
     *                               dimensions; the underlying {@link IllegalDimension} is the cause
     */
    @Override
    public double doOperation() {
        final double[] draws;
        try {
            draws = MultivariateNormalDistribution.nextMultivariateNormalPrecision(
                    getMean().toComponents(), getPrecision().toComponents());
        } catch (IllegalDimension illegalDimension) {
            // Fail here with the real cause instead of continuing with a null draw,
            // which would only surface later as a NullPointerException.
            throw new IllegalStateException(
                    MVN_GIBBS + ": incompatible dimensions while computing the Gibbs draw",
                    illegalDimension);
        }
        setParameterValue(likelihoodMean, draws);
        return 0;
    }

    @Override
    public int getStepCount() {
        return 1;
    }
}