/*
* DirichletDistribution.java
*
* Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.math.distributions;
import dr.math.GammaFunction;
/**
 * Dirichlet distribution over vectors of {@code dim} components that are positive and
 * sum to a fixed total {@code countParameterSum}.
 * <p>
 * With a total of 1 this is the standard Dirichlet distribution on the simplex. With a
 * total {@code c != 1} it is a scaled Dirichlet: a point {@code x} is scored through
 * {@code x / c} on the unit simplex, plus a constant scaling correction (see
 * {@link #computeLogNormalizingConstant()}). Instances are immutable after construction.
 *
 * @author Marc A. Suchard
 * @author Guy Baele
 */
public class DirichletDistribution implements MultivariateDistribution {

    public static final String TYPE = "dirichletDistribution";
    public static final boolean DEBUG = false;

    /**
     * Absolute tolerance used when checking that the components passed to
     * {@link #logPdf(double[])} sum to the required total. An exact comparison would
     * reject valid points because floating-point sums are rarely exact
     * (e.g. 4.0 vs 3.9999999999999996). Other BEAST classes use 1E-8 or 1E-6.
     */
    public static final double ACCURACY_THRESHOLD = 1E-12;

    /** Concentration parameters, one per component (private copy of the constructor argument). */
    private final double[] counts;
    /** Sum of all concentration parameters. */
    private final double countSum;
    /** Total that the components of every point in the support must sum to. */
    private final double countParameterSum;
    /** Number of components. */
    private final int dim;
    /** Log of the normalizing constant, computed once at construction. */
    private final double logNormalizingConstant;

    /**
     * Creates a Dirichlet distribution whose support sums either to 1 or to the number
     * of components.
     *
     * @param counts                concentration parameters, one per component
     * @param sumToNumberOfElements if {@code true} the components must sum to
     *                              {@code counts.length}; otherwise they must sum to 1
     */
    public DirichletDistribution(double[] counts, boolean sumToNumberOfElements) {
        this(counts, sumToNumberOfElements ? (double) counts.length : 1.0);
    }

    /**
     * Creates a Dirichlet distribution whose support sums to an arbitrary total.
     *
     * @param counts            concentration parameters, one per component; the array is
     *                          copied, so later changes by the caller have no effect
     * @param countParameterSum total that the components of every point must sum to
     */
    public DirichletDistribution(double[] counts, double countParameterSum) {
        // Copy so that external mutation cannot desynchronise counts from the
        // precomputed countSum and logNormalizingConstant.
        this.counts = counts.clone();
        this.countParameterSum = countParameterSum;
        this.dim = this.counts.length;

        double sum = 0.0;
        for (double count : this.counts) {
            sum += count;
        }
        this.countSum = sum;
        this.logNormalizingConstant = computeLogNormalizingConstant();
    }

    /**
     * Computes {@code lnGamma(countSum) - sum_i lnGamma(counts[i]) - dim * log(countParameterSum)}.
     * <p>
     * The last term is the scaling correction for a total other than 1 and vanishes when
     * {@code countParameterSum == 1}. NOTE(review): it uses {@code dim}, matching the R
     * reference value in {@link #main(String[])} (which divides by 3^3); a change of
     * variables on the (dim-1)-dimensional simplex would use {@code dim - 1}. The two
     * choices differ only by a constant.
     */
    private double computeLogNormalizingConstant() {
        double result = GammaFunction.lnGamma(countSum);
        for (double count : counts) {
            result -= GammaFunction.lnGamma(count);
        }
        return result - dim * Math.log(countParameterSum);
    }

    /**
     * Returns the log density at {@code x}.
     *
     * @param x point to evaluate; must have exactly {@code dim} components
     * @return the log density, or {@link Double#NEGATIVE_INFINITY} if any component is
     *         non-positive or the components do not sum to the required total within
     *         {@link #ACCURACY_THRESHOLD}
     * @throws IllegalArgumentException if {@code x.length} differs from the dimension
     */
    public double logPdf(double[] x) {
        if (x.length != dim) {
            throw new IllegalArgumentException("data array is of the wrong dimension");
        }
        // Each component is scored on the unit simplex, i.e. as x[i] / countParameterSum.
        final double logScale = Math.log(countParameterSum);
        double logPDF = logNormalizingConstant;
        double parameterSum = 0.0;
        for (int i = 0; i < dim; i++) {
            if (x[i] <= 0.0) {
                // Outside the support. Without this guard Math.log yields NaN for negative
                // values, 0 * -Infinity is NaN when counts[i] == 1, and counts[i] < 1 at
                // x[i] == 0 gives +Infinity -- none of which is a valid log density here.
                if (DEBUG) {
                    System.out.println("Invalid parameter value");
                }
                return Double.NEGATIVE_INFINITY;
            }
            logPDF += (counts[i] - 1) * (Math.log(x[i]) - logScale);
            parameterSum += x[i];
        }
        if (Math.abs(parameterSum - countParameterSum) > ACCURACY_THRESHOLD) {
            if (DEBUG) {
                System.out.println("Parameters do not sum to " + countParameterSum);
                for (int i = 0; i < dim; i++) {
                    System.out.println("x[" + i + "] = " + x[i]);
                }
                System.out.println("Current sum = " + parameterSum);
            }
            logPDF = Double.NEGATIVE_INFINITY;
        }
        return logPDF;
    }

    /**
     * Not provided for this distribution.
     *
     * @return always {@code null}
     */
    public double[][] getScaleMatrix() {
        return null;
    }

    /**
     * Returns the mean vector, {@code countParameterSum * counts[i] / countSum}, so that it
     * lies on the same support that {@link #logPdf(double[])} accepts (for a total of 1
     * this is the usual {@code counts[i] / countSum}).
     *
     * @return a newly allocated array of length {@code dim}
     */
    public double[] getMean() {
        double[] mean = new double[dim];
        for (int i = 0; i < dim; i++) {
            mean[i] = countParameterSum * counts[i] / countSum;
        }
        return mean;
    }

    /**
     * @return the type identifier {@link #TYPE}
     */
    public String getType() {
        return TYPE;
    }

    /**
     * Prints log densities for a few hand-checked cases; reference values computed in R
     * are noted next to each case.
     */
    public static void main(String[] args) {
        // Test Dirichlet distribution for the standard n-simplex
        System.out.println("Test Dirichlet distribution for the standard n-simplex");
        // R: log(ddirichlet(c(0.5,0.2,0.3),c(1,2,3))) = 0.07696104
        double[] counts = {1.0, 2.0, 3.0};
        DirichletDistribution dd = new DirichletDistribution(counts, false);
        System.out.println(dd.logPdf(new double[]{0.5, 0.2, 0.3}));

        // Test Scaled Dirichlet distribution
        System.out.println("Test Scaled Dirichlet distribution");
        // R: log(ddirichlet(c(1.5,0.6,0.9)/3,c(1,2,3))/(3^3)) = -3.218876
        dd = new DirichletDistribution(counts, true);
        System.out.println(dd.logPdf(new double[]{1.5, 0.6, 0.9}));
        System.out.println(dd.logPdf(new double[]{1.0, 1.0, 1.0}));

        counts = new double[]{1.0, 1.0, 1.0, 1.0};
        dd = new DirichletDistribution(counts, true);
        System.out.println(dd.logPdf(new double[]{0.5, 1.2, 1.3, 1.0}));
        System.out.println(dd.logPdf(new double[]{1.0, 1.0, 1.0, 1.0}));
    }
}