/* * DirichletProcessPriorLogger.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.branchmodel.lineagespecific; import java.util.ArrayList; import java.util.List; import dr.app.bss.Utils; import dr.inference.distribution.ParametricMultivariateDistributionModel; import dr.inference.loggers.LogColumn; import dr.inference.loggers.Loggable; import dr.inference.loggers.NumberColumn; import dr.inference.model.CompoundParameter; import dr.inference.model.Parameter; import dr.math.distributions.NormalDistribution; public class DirichletProcessPriorLogger implements Loggable { // private ParametricMultivariateDistributionModel baseModel; private Parameter precisionParameter; private Parameter categoriesParameter; private CompoundParameter uniquelyRealizedParameters; private int uniqueRealizationCount; private int realizationCount; private double[] categoryProbabilities; private int newCategoryIndex; private double meanForCategory; private double newX; public DirichletProcessPriorLogger( Parameter precisionParameter, // Parameter categoriesParameter, // CompoundParameter uniquelyRealizedParameters // ) { this.precisionParameter = precisionParameter; this.uniquelyRealizedParameters = uniquelyRealizedParameters; this.categoriesParameter = categoriesParameter; this.uniqueRealizationCount = uniquelyRealizedParameters.getDimension(); this.realizationCount = categoriesParameter.getDimension(); }// END: Constructor private double[] getCategoryProbs() { double[] probs = new double[uniqueRealizationCount]; for (int i = 0; i < realizationCount; i++) { probs[(int) categoriesParameter.getParameterValue(i)]++; }// END: N loop // Utils.printArray(probs); for (int i = 0; i < uniqueRealizationCount; i++) { probs[i] = probs[i] / realizationCount; }// END: categoryCount loop // probs = new double[]{0.10, 0.10, 0.10, 0.10, 0.10, // 0.10, 0.10, 0.10, 0.10, 0.10}; return probs; }// END: getCategoryProbs @Override public LogColumn[] getColumns() { // this.categoryProbabilities = getCategoryProbs(); List<LogColumn> columns = new ArrayList<LogColumn>(); columns.add(new NewLogger("x.new")); columns.add(new NewCategoryLogger("category.new")); columns.add(new NewMeanLogger("mean.new")); for (int i = 0; i < uniquelyRealizedParameters.getDimension(); i++) { columns.add(new ProbabilitiesLogger("pi.", i)); } LogColumn[] rtnColumns = new LogColumn[columns.size()]; return columns.toArray(rtnColumns); }// END: getColumns private void getNew() { this.categoryProbabilities = getCategoryProbs(); this.newCategoryIndex = Utils.sample(categoryProbabilities); this.meanForCategory = uniquelyRealizedParameters .getParameterValue(newCategoryIndex); //TODO: generalize double sd = precisionParameter.getParameterValue(0); // System.out.println("FUBAR:" + sd); NormalDistribution nd = new NormalDistribution(meanForCategory, sd); this.newX = (Double) nd.nextRandom(); } private class NewLogger extends NumberColumn { public NewLogger(String label) { super(label); } @Override public double getDoubleValue() { getNew(); return newX; }// END: getDoubleValue }// END: NewLogger class private class NewMeanLogger extends NumberColumn { public NewMeanLogger(String label) { super(label); } @Override public double getDoubleValue() { return meanForCategory; }// END: getDoubleValue }// END: NewCategoryLogger class private class ProbabilitiesLogger extends NumberColumn { private int i; public ProbabilitiesLogger(String label, int i) { super(label + i); this.i = i; } @Override public double getDoubleValue() { return categoryProbabilities[i]; }// END: getDoubleValue }// END: NewCategoryLogger class private class NewCategoryLogger extends NumberColumn { public NewCategoryLogger(String label) { super(label); } @Override public double getDoubleValue() { return newCategoryIndex; }// END: getDoubleValue }// END: NewCategoryLogger class }// END: class