/*
 * DirichletSiteModel.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.oldevomodel.sitemodel;

import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.*;

import java.util.logging.Logger;

/**
 * DirichletSiteModel - A SiteModel that has a free rate parameter for each category.
 * These rate parameters are constrained to sum to 1.
 * <p>
 * Every category has the same proportion (1 / number of categories). The rate reported
 * for a category is its raw parameter value divided by the proportion-weighted mean of
 * all raw values, so the proportion-weighted mean of the reported rates is 1.
 * <p>
 * The normalised rates are cached; the cache is guarded by this object's monitor and is
 * invalidated whenever the rates parameter changes or state is restored.
 *
 * @author Benjamin Redelings
 */
public class DirichletSiteModel extends AbstractModel implements SiteModel {

    public static final String SUBSTITUTION_MODEL = "substitutionModel";
    public static final String DIRICHLET_SITE_MODEL = "DirichletSiteModel";
    public static final String CATEGORIES = "categories";
    public static final String RATES = "rates";

    /**
     * Creates a site model with {@code n_categories} rate categories, backed by a
     * freshly allocated rates parameter of that dimension.
     *
     * @param substitutionModel the substitution model shared by all categories
     * @param n_categories      the number of rate categories; must be at least 1
     * @throws IllegalArgumentException if {@code n_categories} is less than 1
     */
    public DirichletSiteModel(SubstitutionModel substitutionModel, int n_categories) {
        this(substitutionModel, new Parameter.Default(new double[checkCategoryCount(n_categories)]));
    }

    /**
     * Creates a site model with one rate category per dimension of {@code ratesParameter}.
     * <p>
     * Note that every element of {@code ratesParameter} is overwritten with
     * {@code 1 / dimension}; any values it held before are discarded. Bounds are then
     * added to the parameter.
     *
     * @param substitutionModel the substitution model shared by all categories
     * @param ratesParameter    the unscaled per-category rates; must be non-null with
     *                          dimension at least 1
     * @throws IllegalArgumentException if {@code ratesParameter} is null or has
     *                                  dimension less than 1
     */
    public DirichletSiteModel(SubstitutionModel substitutionModel, Parameter ratesParameter) {

        super(DIRICHLET_SITE_MODEL);

        if (ratesParameter == null) {
            throw new IllegalArgumentException("ratesParameter must not be null");
        }
        final int categoryCount = checkCategoryCount(ratesParameter.getDimension());

        // Register the substitution model so that its changes are forwarded to us.
        this.substitutionModel = substitutionModel;
        addModel(substitutionModel);

        categoryRates = new double[categoryCount];
        categoryProportions = new double[categoryCount];

        // Nothing has been normalised yet.
        ratesKnown = false;

        // Start from equal rates; proportions are equal and never change afterwards.
        // This happens before addVariable(), i.e. before this model listens to the parameter.
        final double uniform = 1.0 / categoryCount;
        for (int i = 0; i < categoryCount; i++) {
            ratesParameter.setParameterValue(i, uniform);
            categoryProportions[i] = uniform;
        }

        this.ratesParameter = ratesParameter;
        addVariable(this.ratesParameter);

        // NOTE(review): arguments are presumably (upper = 1.0, lower = 0.0, dimension) - confirm
        // against Parameter.DefaultBounds.
        this.ratesParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, categoryCount));
    }

    /**
     * Validates a category count.
     *
     * @param categoryCount the requested number of categories
     * @return {@code categoryCount}, unchanged
     * @throws IllegalArgumentException if {@code categoryCount} is less than 1
     */
    private static int checkCategoryCount(int categoryCount) {
        if (categoryCount < 1) {
            throw new IllegalArgumentException(
                    "A DirichletSiteModel needs at least one category, but got " + categoryCount);
        }
        return categoryCount;
    }

    /**
     * @return the parameter holding the unscaled per-category rates.
     */
    public Parameter getRatesParameter() {
        return ratesParameter;
    }

    // *****************************************************************
    // Interface SiteModel
    // *****************************************************************

    public boolean integrateAcrossCategories() {
        return true;
    }

    public int getCategoryCount() {
        return categoryRates.length;
    }

    /**
     * Always fails: this model integrates across categories, so no site belongs
     * to a single category.
     *
     * @param site the site index (ignored)
     * @return never returns normally
     * @throws IllegalArgumentException always
     */
    public int getCategoryOfSite(int site) {
        throw new IllegalArgumentException("Integrating across categories");
    }

    /**
     * Get the normalised rate of a category.
     *
     * @param category the category number
     * @return the rate.
     */
    public double getRateForCategory(int category) {
        calculateCategoryRates();
        return categoryRates[category];
    }

    /**
     * Get the normalised rates of all categories.
     *
     * @return a fresh copy of the rates; modifying it does not affect this model.
     */
    public double[] getCategoryRates() {
        calculateCategoryRates();
        return categoryRates.clone();
    }

    /**
     * Delegates to the substitution model.
     *
     * @param substitutions passed through to the substitution model
     * @param matrix        passed through to the substitution model
     */
    public void getTransitionProbabilities(double substitutions, double[] matrix) {
        substitutionModel.getTransitionProbabilities(substitutions, matrix);
    }

    /**
     * Get the expected proportion of sites in this category.
     *
     * @param category the category number
     * @return the proportion.
     */
    public double getProportionForCategory(int category) {
        // Proportions are fixed at construction, so no rate recalculation is needed.
        return categoryProportions[category];
    }

    /**
     * Get an array of the expected proportion of sites in this category.
     *
     * @return a fresh copy of the proportions; modifying it does not affect this model.
     */
    public double[] getCategoryProportions() {
        // Copy, like getCategoryRates(), so callers cannot corrupt the internal array.
        return categoryProportions.clone();
    }

    /**
     * Computes the proportion-weighted mean of the unscaled rates.
     *
     * @return sum over categories of proportion * unscaled rate
     */
    private double getSubstitutionScale() {
        double scale = 0;
        for (int i = 0; i < categoryRates.length; i++) {
            scale += categoryProportions[i] * ratesParameter.getParameterValue(i);
        }
        return scale;
    }

    /**
     * Calculate the category rates from the unscaled rate variables.
     * <p>
     * Does nothing if the cached rates are still valid. The validity check and the
     * recomputation happen under the same lock, so a caller never observes a
     * partially filled rate array flagged as valid.
     */
    private synchronized void calculateCategoryRates() {
        if (ratesKnown) {
            return;
        }

        // Multiply by the reciprocal once instead of dividing per category.
        final double inverseScale = 1.0 / getSubstitutionScale();
        for (int i = 0; i < categoryRates.length; i++) {
            categoryRates[i] = ratesParameter.getParameterValue(i) * inverseScale;
        }

        ratesKnown = true;
    }

    /**
     * Marks the cached rates as stale, under the same lock used by
     * {@link #calculateCategoryRates()}.
     */
    private synchronized void invalidateCategoryRates() {
        ratesKnown = false;
    }

    /**
     * Get the frequencyModel for this SiteModel.
     *
     * @return the frequencyModel.
     */
    public FrequencyModel getFrequencyModel() {
        return substitutionModel.getFrequencyModel();
    }

    /**
     * Get the substitutionModel for this SiteModel.
     *
     * @return the substitutionModel.
     */
    public SubstitutionModel getSubstitutionModel() {
        return substitutionModel;
    }

    // *****************************************************************
    // Interface ModelComponent
    // *****************************************************************

    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // Substitution model has changed so fire model changed event
        listenerHelper.fireModelChanged(this, object, index);
    }

    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        // The rates parameter is the only variable registered by this class.
        if (variable == ratesParameter) {
            invalidateCategoryRates();
        }
        listenerHelper.fireModelChanged(this, variable, index);
    }

    protected void storeState() {
    } // no additional state needs storing

    protected void restoreState() {
        // Cached rates may not match the restored parameter values; recompute lazily.
        invalidateCategoryRates();
    }

    protected void acceptState() {
    } // no additional state needs accepting

    /**
     * XML parser that builds a {@link DirichletSiteModel} from a
     * {@code substitutionModel} element and a {@code rates} element.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String[] getParserNames() {
            return new String[]{getParserName(), "beast_" + getParserName()};
        }

        public String getParserName() {
            return DIRICHLET_SITE_MODEL;
        }

        /**
         * Builds the site model from its XML description.
         *
         * @param xo the XML element being parsed
         * @return a new {@link DirichletSiteModel}
         * @throws XMLParseException if the {@code rates} element is missing or does
         *                           not contain a parameter
         */
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            SubstitutionModel substitutionModel = (SubstitutionModel) xo.getElementFirstChild(SUBSTITUTION_MODEL);

            // Fail with a parse error rather than a NullPointerException in the constructor.
            if (!xo.hasChildNamed(RATES)) {
                throw new XMLParseException(
                        "Missing '" + RATES + "' element in " + DIRICHLET_SITE_MODEL);
            }

            XMLObject cxo = (XMLObject) xo.getChild(RATES);
            Parameter ratesParameter = (Parameter) cxo.getChild(Parameter.class);
            if (ratesParameter == null) {
                throw new XMLParseException(
                        "The '" + RATES + "' element of " + DIRICHLET_SITE_MODEL + " must contain a parameter");
            }

            Logger.getLogger("dr.evomodel").info(
                    "Creating site model: " + "\n  with " + ratesParameter.getDimension() + " categories.");

            return new DirichletSiteModel(substitutionModel, ratesParameter);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "A SiteModel that has Dirichlet distributed rates across sites";
        }

        public Class getReturnType() {
            return DirichletSiteModel.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(SUBSTITUTION_MODEL, new XMLSyntaxRule[]{
                        new ElementRule(SubstitutionModel.class)
                }),
                new ElementRule(RATES, new XMLSyntaxRule[]{
                        new ElementRule(Parameter.class)
                })
        };
    };

    /**
     * the substitution model for these sites
     */
    private final SubstitutionModel substitutionModel;

    /**
     * the unscaled substitution rates, one per category; intended to sum to 1
     * (the sum is not enforced by this class)
     */
    private final Parameter ratesParameter;

    /**
     * whether {@link #categoryRates} reflects the current parameter values;
     * guarded by this object's monitor after construction
     */
    private boolean ratesKnown;

    /**
     * cached normalised rates, valid only while {@link #ratesKnown} is true
     */
    private final double[] categoryRates;

    /**
     * per-category proportions, all equal and fixed at construction
     */
    private final double[] categoryProportions;
}