/*
* MultivariatePolyaDistributionLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.math.distributions;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.*;
import dr.math.GammaFunction;
import dr.xml.*;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
/**
* Package: MultivariatePolyaDistributionLikelihood
* Description:
* this class provides a model for over-dispersed multinomial counts. The model follows Dirichlet-Multinomial distribution with
* multinomial parameters integrated out analytically. This model is also known as Multivariate Polya distribution.
* Standard parametrization involves k intensities a_i's. This implementation uses the standard parametrization internally, but allows
* for re-parametrization as frequencies (k-1 df) and dispersion parameters, where a = sum_i=1^k a_i is dispersion and f_i = a_i/a
* <p/>
* Created by
*
* @author Alexander V. Alekseyenko (alexander.alekseyenko@gmail.com)
* Date: Sep 23, 2010
* Time: 12:25:14 PM
*/
public class MultivariatePolyaDistributionLikelihood extends AbstractModel implements Likelihood {
protected Parameter frequencies;
protected Parameter dispersion;
protected Parameter alphas;
protected boolean usingAlphas;
protected boolean isAlphasKnown;
protected MatrixParameter data;
protected double fixedNorm;
protected double variableNorm;
protected double storedFixedNorm;
protected double storedVariableNorm;
protected double logLikelihood;
protected double storedLogLikelihood;
protected boolean isLogLikelihoodKnown;
protected boolean isFixedNormKnown;
protected boolean isVariableNormKnown;
protected double rowSums[];
public MultivariatePolyaDistributionLikelihood(String modelID, MatrixParameter data, Parameter frequencies, Parameter dispersion) {
super(modelID);
this.frequencies = frequencies;
this.dispersion = dispersion;
this.alphas = new Parameter.Default(frequencies.getDimension());
computeAlphas();
this.data = data;
isFixedNormKnown = false;
isVariableNormKnown = false;
addVariable(this.frequencies);
addVariable(this.dispersion);
addVariable(this.data);
if (this.alphas.getDimension() != data.getColumnDimension()) {
System.err.println("Dimensions of the frequency vector and number of columns do not match!");
}
}
public MultivariatePolyaDistributionLikelihood(String modelID, MatrixParameter data, Parameter alphas) {
super(modelID);
this.alphas = alphas;
isAlphasKnown = true;
usingAlphas = true;
this.frequencies = new Parameter.Default(alphas.getDimension());
this.dispersion = new Parameter.Default(1);
this.data = data;
isFixedNormKnown = false;
isVariableNormKnown = false;
addVariable(this.alphas);
addVariable(this.data);
if (this.alphas.getDimension() != data.getColumnDimension()) {
System.err.println("Dimensions of the frequency vector and number of columns do not match!");
}
}
/* Compute alphas from frequencies and dispersion
*/
protected void computeAlphas(){
double disp=dispersion.getParameterValue(0);
double[] freqs = frequencies.getParameterValues();
for(int i=0; i<alphas.getDimension(); ++i){
alphas.setParameterValueQuietly(i, disp*freqs[i]);
}
alphas.setParameterValueNotifyChangedAll(0, alphas.getParameterValue(0));
isAlphasKnown = true;
}
public MultivariatePolyaDistributionLikelihood(String modelID) {
super(modelID);
}
public double calculateLogLikelihood() {
// R code for this function:
// //assuming X[,1] is row totals, lfactX is log factorial of X //
// logLikes = sapply(1:n, function(i) lfactX[subset,1][i] - sum(lfactX[subset,][i,2:p1]) +
// lgamma(sum(alpha)) - lgamma(X[subset,1][i] + sum(alpha)) + sum(lgamma(X[subset,2:p1][i,] + alpha)) - sum(lgamma(alpha)));
// sum(logLikes)
if(!isAlphasKnown) computeAlphas();
if (!isFixedNormKnown) {
computeFixedNorm();
}
if (!isVariableNormKnown) {
computeVariableNorm();
}
double logP = fixedNorm + variableNorm;
double disp = 0;
double[] a = alphas.getParameterValues();
for(int i = 0; i< alphas.getDimension(); ++i){
disp = disp + a[i];
}
for (int i = 0; i < data.getRowDimension(); ++i) {
for (int j = 0; j < data.getColumnDimension(); ++j) {
logP += GammaFunction.lnGamma(data.getParameterValue(i, j) + a[j]);
}
logP -= GammaFunction.lnGamma(rowSums[i] + disp);
}
return logP;
}
protected void computeFixedNorm() {
rowSums = new double[data.getRowDimension()];
for (int i = 0; i < data.getRowDimension(); ++i) {
rowSums[i] = 0;
for (int j = 0; j < data.getColumnDimension(); ++j) {
rowSums[i] += data.getParameterValue(i, j);
}
}
fixedNorm = 0;
for (int i = 0; i < data.getRowDimension(); ++i) {
for (int j = 0; j < data.getColumnDimension(); ++j) {
fixedNorm -= GammaFunction.lnGamma(data.getParameterValue(i, j) + 1);
}
fixedNorm += GammaFunction.lnGamma(rowSums[i] + 1);
}
isFixedNormKnown = true;
}
protected void computeVariableNorm() {
double disp = 0;
double[] a = alphas.getParameterValues();
for(int i = 0; i< alphas.getDimension(); ++i){
disp = disp + a[i];
}
variableNorm = GammaFunction.lnGamma(disp);
for (int i = 0; i < alphas.getDimension(); ++i) {
variableNorm -= GammaFunction.lnGamma(a[i]);
}
variableNorm *= data.getRowDimension();
}
protected void handleModelChangedEvent(Model model, Object object, int index) {
}
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if (variable.getVariableName().equals(frequencies.getVariableName()) || variable.getVariableName().equals(dispersion.getVariableName())) {
isAlphasKnown = false;
isVariableNormKnown = false;
} else if (variable.getVariableName().equals(data.getVariableName())) {
isFixedNormKnown = false;
}
else if(variable.getVariableName().equals(alphas.getVariableName())){
isVariableNormKnown = false;
}
}
protected void storeState() {
storedVariableNorm = variableNorm;
storedFixedNorm = fixedNorm;
storedLogLikelihood = logLikelihood;
}
protected void restoreState() {
variableNorm = storedVariableNorm;
fixedNorm = storedFixedNorm;
logLikelihood = storedLogLikelihood;
if(!usingAlphas) computeAlphas();
}
protected void acceptState() {
}
public Model getModel() {
return this;
}
public double getLogLikelihood() {
if (!isLogLikelihoodKnown) {
logLikelihood = calculateLogLikelihood();
}
return logLikelihood;
}
public void makeDirty() {
isLogLikelihoodKnown = false;
isVariableNormKnown = false;
isFixedNormKnown = false;
}
public String prettyName() {
return "Multivariate Polya Distribution Likelihood";
}
public boolean evaluateEarly() {
return false;
}
@Override
public Set<Likelihood> getLikelihoodSet() {
return new HashSet<Likelihood>(Arrays.asList(this));
}
public void setUsed() {
}
public LogColumn[] getColumns() {
return new LogColumn[]{
new NumberColumn(this.getId()) {
public double getDoubleValue() {
return getLogLikelihood();
}
}
};
}
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return MVPLIKE;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
MatrixParameter data;
Parameter dispersion;
Parameter frequencies;
Parameter rates;
if (xo.hasChildNamed(DATA)) {
data = (MatrixParameter) xo.getChild(DATA).getChild(MatrixParameter.class);
} else {
throw new XMLParseException("Missing data element!");
}
if (xo.hasChildNamed(RATES)) {
rates = (Parameter) xo.getChild(RATES).getChild(Parameter.class);
if (rates.getDimension() != data.getColumnDimension()) {
throw new XMLParseException("The number of data columns must match the dimension of " + RATES
+ " parameter (" + data.getColumnDimension() + " != " + rates.getDimension() + "!");
}
}
else if (xo.hasChildNamed(FREQ)) {
frequencies = (Parameter) xo.getChild(FREQ).getChild(Parameter.class);
if (xo.hasChildNamed(DISPERSION)) {
dispersion = (Parameter) xo.getChild(DISPERSION).getChild(Parameter.class);
} else {
throw new XMLParseException(DISPERSION + " element has to be specified when using " + FREQ
+" parametrization");
}
if (dispersion.getDimension() != 1) {
throw new XMLParseException("Dispersion parameter must be of dimmension exactly 1!");
}
if (frequencies.getDimension() != data.getColumnDimension()) {
throw new XMLParseException("The number of data columns must match the dimension of "+ FREQ
+ " parameter (" + data.getColumnDimension() + " != " + frequencies.getDimension() + "!");
}
return new MultivariatePolyaDistributionLikelihood(MVPLIKE, data, frequencies, dispersion);
} else {
throw new XMLParseException("Either " + FREQ + " or " + RATES + "element has to be specified!");
}
return new MultivariatePolyaDistributionLikelihood(MVPLIKE, data, rates);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "A matrix parameter constructed from its component parameters.";
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}, false),
new XORRule(new ElementRule(RATES, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, false),
new ElementRule(FREQ, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, false)),
new ElementRule(DISPERSION, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
};
public Class getReturnType() {
return MatrixParameter.class;
}
};
public static final String MVPLIKE = "mvPolyaLikelihood";
public static final String DATA = "data";
public static final String DISPERSION = "dispersion";
public static final String FREQ = "frequencies";
public static final String RATES = "alpha";
}