/*
* GammaSiteBMA.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.sitemodel;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.inference.model.*;
import dr.math.distributions.GammaDistribution;
/**
 * @author Chieh-Hsi Wu
 *
 * BSSVS on Gamma site model.
 *
 * A site model whose rate categories consist of one invariant category
 * (rate 0, stored at index 0) followed by a fixed number of variable-rate
 * categories. Two integer indicators held in {@code modelChoose} switch the
 * invariant-sites proportion and the gamma rate heterogeneity on or off;
 * the number of categories stays the same whichever combination is active.
 *
 * The shape parameter is supplied on the log scale and the proportion of
 * invariant sites on the logit scale; both are back-transformed when the
 * category rates are (lazily) recalculated.
 */
public class GammaSiteBMA extends AbstractModel implements SiteModel {

    /** Index into {@code modelChoose} of the indicator for the gamma shape parameter. */
    public static final int SHAPE_INDEX = 0;
    /** Index into {@code modelChoose} of the indicator for the invariant-sites proportion. */
    public static final int INVAR_INDEX = 1;
    /** Indicator value meaning that the corresponding component is included. */
    public static final int PRESENT = 1;
    /** Indicator value meaning that the corresponding component is excluded. */
    public static final int ABSENT = 0;

    private final SubstitutionModel substitutionModel;
    /** Overall rate multiplier; may be null, in which case rates are not scaled. */
    private final Parameter muParameter;
    private final Variable<Integer> modelChoose;
    private final Variable<Double> logShape;
    private final Variable<Double> logitInvar;

    /** Total number of categories: the invariant category plus the variable-rate ones. */
    private final int categoryCount;
    /** Unscaled rate of each category; index 0 is the invariant category. */
    private final double[] categoryRates;
    /** Proportion of sites in each category; index 0 is the invariant category. */
    private final double[] categoryProportions;
    /** False whenever the cached rates/proportions must be recalculated before use. */
    private boolean ratesKnown;

    /**
     * Creates the site model and registers its sub-model and variables.
     *
     * @param substitutionModel the substitution model this site model wraps
     * @param muParameter       overall rate multiplier, or null for no scaling
     * @param logitInvar        proportion of invariant sites on the logit scale
     * @param logShape          gamma shape parameter on the log scale
     * @param categoryCount     number of variable-rate categories (at least 1);
     *                          one extra category is added for invariant sites
     * @param modelChoose       indicators; the values at {@link #SHAPE_INDEX} and
     *                          {@link #INVAR_INDEX} state whether the gamma shape and the
     *                          invariant proportion are present or absent respectively
     * @throws IllegalArgumentException if {@code categoryCount} is less than 1
     */
    public GammaSiteBMA(SubstitutionModel substitutionModel,
                        Parameter muParameter,
                        Variable<Double> logitInvar,
                        Variable<Double> logShape,
                        int categoryCount,
                        Variable<Integer> modelChoose) {
        super("GammaSiteBMA");

        // Without at least one variable-rate category the proportions could not sum to one.
        if (categoryCount < 1) {
            throw new IllegalArgumentException(
                    "categoryCount must be at least 1 but was " + categoryCount);
        }

        this.substitutionModel = substitutionModel;
        addModel(substitutionModel);

        // The rate getters treat a null muParameter as "no scaling", so only
        // register and bound it when one was actually supplied.
        if (muParameter != null) {
            addVariable(muParameter);
            muParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1));
        }
        this.muParameter = muParameter;

        // Log-scale shape is unbounded in both directions.
        addVariable(logShape);
        logShape.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 1));
        this.logShape = logShape;

        // Logit-scale invariant proportion is unbounded in both directions.
        addVariable(logitInvar);
        logitInvar.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 1));
        this.logitInvar = logitInvar;

        // The modelChoose integer variable is of length 2, where indices 0 and 1
        // indicate the presence or absence of alpha and pInvar respectively.
        addVariable(modelChoose);
        modelChoose.addBounds(new Bounds.Int(modelChoose, 0, 1));
        this.modelChoose = modelChoose;

        // One additional category is always reserved for invariant sites.
        this.categoryCount = categoryCount + 1;
        categoryRates = new double[this.categoryCount];
        categoryProportions = new double[this.categoryCount];
        ratesKnown = false;
    }

    /**
     * @return always true; the likelihood is integrated across the rate categories.
     */
    public boolean integrateAcrossCategories() {
        return true;
    }

    /**
     * @return the total number of categories, including the invariant category.
     */
    public int getCategoryCount() {
        return categoryCount;
    }

    /**
     * Not supported, because this model integrates across categories.
     *
     * @param site the site index (unused)
     * @return never returns normally
     * @throws IllegalArgumentException always
     */
    public int getCategoryOfSite(int site) {
        throw new IllegalArgumentException("Integrating across categories");
    }

    /**
     * Get the rate of one category, scaled by mu when a mu parameter is present.
     *
     * @param category the category number
     * @return the rate of that category.
     */
    public double getRateForCategory(int category) {
        ensureRatesKnown();
        return categoryRates[category] * currentMu();
    }

    /**
     * Get the rates of all categories, scaled by mu when a mu parameter is present.
     *
     * @return a newly allocated array of the category rates.
     */
    public double[] getCategoryRates() {
        ensureRatesKnown();
        final double mu = currentMu();
        final double[] rates = new double[categoryRates.length];
        for (int i = 0; i < rates.length; i++) {
            rates[i] = categoryRates[i] * mu;
        }
        return rates;
    }

    /**
     * Delegates to the substitution model to fill in the transition probabilities.
     *
     * @param substitutions passed through to the substitution model
     * @param matrix        array handed to the substitution model to be filled in
     */
    public void getTransitionProbabilities(double substitutions, double[] matrix) {
        substitutionModel.getTransitionProbabilities(substitutions, matrix);
    }

    /**
     * Get the expected proportion of sites in this category.
     *
     * @param category the category number
     * @return the proportion.
     */
    public double getProportionForCategory(int category) {
        ensureRatesKnown();
        return categoryProportions[category];
    }

    /**
     * Get an array of the expected proportion of sites in each category.
     * The returned array is this model's internal array, not a copy, so
     * callers must not modify it.
     *
     * @return an array of the proportions.
     */
    public double[] getCategoryProportions() {
        ensureRatesKnown();
        return categoryProportions;
    }

    /**
     * Recalculates the cached rates and proportions if they are out of date.
     * Synchronized on this instance so that concurrent getters do not
     * recalculate at the same time.
     */
    private synchronized void ensureRatesKnown() {
        if (!ratesKnown) {
            calculateCategoryRates();
        }
    }

    /**
     * @return the current value of mu, or 1.0 when no mu parameter was supplied.
     */
    private double currentMu() {
        return (muParameter != null) ? muParameter.getParameterValue(0) : 1.0;
    }

    /**
     * Fills in the per-category rates and proportions.
     *
     * Category 0 is the invariant category (rate 0). Its proportion is the
     * back-transformed invariant proportion when the invariant indicator is
     * PRESENT, and 0 otherwise. The remaining proportion is divided equally
     * among the other categories. When the gamma indicator is PRESENT their
     * rates come from a discretization of the gamma distribution with equal
     * proportions in each category; otherwise they all share one rate. In
     * both cases rates are scaled so the proportion-weighted mean rate is 1.0.
     */
    private void calculateCategoryRates() {
        // Invariant category: back transform from logit space and multiply by
        // the 0/1 indicator so the proportion vanishes when pInvar is excluded.
        categoryRates[0] = 0.0;
        categoryProportions[0] =
                modelChoose.getValue(INVAR_INDEX) * (1 / (1 + Math.exp(-logitInvar.getValue(0))));

        // Proportion of sites left over for the variable-rate categories.
        final double propVariable = 1.0 - categoryProportions[0];
        // The variable-rate categories start right after the invariant one.
        final int offset = 1;
        final int gammaCatCount = categoryCount - offset;
        final double propPerCategory = propVariable / gammaCatCount;

        if (modelChoose.getValue(SHAPE_INDEX) == PRESENT) {
            // Back transform from log space.
            final double a = Math.exp(logShape.getValue(0));

            double rateSum = 0.0;
            for (int i = 0; i < gammaCatCount; i++) {
                // Quantile taken at the midpoint of each equal-probability bin;
                // the arguments a and 1/a are presumably shape and scale.
                final double rate = GammaDistribution.quantile(
                        (2.0 * i + 1.0) / (2.0 * gammaCatCount), a, 1.0 / a);
                categoryRates[i + offset] = rate;
                rateSum += rate;
                categoryProportions[i + offset] = propPerCategory;
            }

            // Mean rate over all categories (the invariant one contributes 0).
            final double mean = (propVariable * rateSum) / gammaCatCount;
            for (int i = 0; i < gammaCatCount; i++) {
                // Divide by the mean so that the average across all sites equals 1.0.
                categoryRates[i + offset] /= mean;
            }
        } else {
            // No rate heterogeneity: every variable category has the same rate,
            // chosen so that the average across all sites equals 1.0.
            for (int i = 0; i < gammaCatCount; i++) {
                categoryRates[i + offset] = 1.0 / propVariable;
                categoryProportions[i + offset] = propPerCategory;
            }
        }

        ratesKnown = true;
    }

    /**
     * Get the frequencyModel for this SiteModel.
     *
     * @return the frequencyModel.
     */
    public FrequencyModel getFrequencyModel() {
        return substitutionModel.getFrequencyModel();
    }

    /**
     * Get the substitutionModel for this SiteModel.
     *
     * @return the substitutionModel.
     */
    public SubstitutionModel getSubstitutionModel() {
        return substitutionModel;
    }

    // *****************************************************************
    // Interface ModelComponent
    // *****************************************************************

    /**
     * Forwards a change in a sub-model to this model's own listeners.
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // Substitution model has changed so fire model changed event
        listenerHelper.fireModelChanged(this, object, index);
    }

    /**
     * Invalidates the cached rates when a variable that determines them
     * changes, then notifies this model's listeners.
     */
    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        // mu is applied when rates are read, so a change to it does not
        // invalidate the cached unscaled rates.
        if (variable == logShape || variable == logitInvar || variable == modelChoose) {
            ratesKnown = false;
        }
        listenerHelper.fireModelChanged(this, variable, index);
    }

    /**
     * No additional state needs storing.
     */
    protected void storeState() {
    }

    /**
     * Marks the cached rates as out of date so they are recalculated on next use.
     */
    protected void restoreState() {
        ratesKnown = false;
    }

    /**
     * No additional state needs accepting.
     */
    protected void acceptState() {
    }
}