/* * NegativeBinomialDistribution.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.math.distributions; import dr.math.ErrorFunction; import org.apache.commons.math.MathException; import org.apache.commons.math.special.Beta; import dr.math.*; /** * @author Trevor Bedford * @version $Id$ */ public class NegativeBinomialDistribution implements Distribution { double mean; double alpha; public NegativeBinomialDistribution(double mean, double alpha) { this.mean = mean; this.alpha = alpha; } public double pdf(double x) { return pdf(x, mean, alpha); } public double logPdf(double x) { return logPdf(x, mean, alpha); } public double cdf(double x) { return cdf(x, mean, alpha); } public double quantile(double y) { // TB - I'm having trouble implementing this // LM - A first stab using simple minimisation to invert the function (under absolute loss) // Implementation based on the qnbinom.c function used in R final double stdev = Math.sqrt(mean + (mean * mean * alpha)); final double r = -1 * (mean*mean) / (mean - stdev*stdev); final double p = mean / (stdev*stdev); final double prob = y; final double Q = 1.0 / p; final double P = (1.0 - p) * Q; final double gamma = (Q + P)/stdev; final double z = Math.sqrt(2.0) * ErrorFunction.inverseErf(2.0 * y - 1.0); final double crudeY = mean + stdev * (z + gamma * (z*z - 1) / 6); UnivariateFunction f = new UnivariateFunction() { double tent = Double.NaN; public double evaluate(final double argument) { try { tent = Beta.regularizedBeta(p, r, argument+1); } catch (MathException e) { return Double.NaN; } double score = Math.abs(tent-prob); return score; } public int getNumArguments() { return 1; } public double getLowerBound() { // 20% window should cut it. Probably too large even... return Math.min(crudeY - .2*crudeY, 0); } public double getUpperBound() { return crudeY + .2*crudeY; } }; UnivariateMinimum minimum = new UnivariateMinimum(); double q = minimum.findMinimum(f); return Math.ceil(q); } public double mean() { return mean; } public double variance() { return mean + (mean * mean * alpha); } public UnivariateFunction getProbabilityDensityFunction() { throw new RuntimeException(); } public static double pdf(double x, double mean, double alpha) { if (x < 0) return 0; return Math.exp(logPdf(x, mean, alpha)); } public static double logPdf(double x, double mean, double alpha) { if (x < 0) return Double.NEGATIVE_INFINITY; // double r = -1 * (mean*mean) / (mean - stdev*stdev); // double p = mean / (stdev*stdev); // return Math.log(Math.pow(1-p,x)) + Math.log(Math.pow(p, r)) + GammaFunction.lnGamma(r+x) - GammaFunction.lnGamma(r) - GammaFunction.lnGamma(x+1); double theta = 1.0 / alpha; double p = theta / (theta + mean); return Math.log(1 - p) * x + Math.log(p) * theta + GammaFunction.lnGamma(theta + x) - GammaFunction.lnGamma(theta) - GammaFunction.lnGamma(x+1); } public static double cdf(double x, double mean, double alpha) { double theta = 1.0 / alpha; double p = theta / (theta + mean); try { return Beta.regularizedBeta(p, theta, x+1); } catch (MathException e) { // AR - throwing exceptions deep in numerical code causes trouble. Catching runtime // exceptions is bad. Better to return NaN and let the calling code deal with it. return Double.NaN; // throw MathRuntimeException.createIllegalArgumentException( // "Couldn't calculate beta cdf for alpha = " + alpha + ", beta = " + beta + ": " +e.getMessage()); } } public static void main(String[] args) { System.out.println("Test negative binomial"); double mean = 5; double stdev = 5; // double r = -1 * (mean*mean) / (mean - stdev*stdev); double alpha = (stdev * stdev - mean) / (mean * mean); NegativeBinomialDistribution dist = new NegativeBinomialDistribution(5, alpha); System.out.println("Mean 5, sd 5, x 5, pdf 0.074487, logPdf -2.59713, median 4"); System.out.println("pdf = " + dist.pdf(5)); System.out.println("quantile(0.5) aka median = " + dist.quantile(0.5)); System.out.println("logPdf = " + dist.logPdf(5)); } }