/*
* NormalKDEDistribution.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.math.distributions;
import dr.math.ComplexArray;
import dr.math.FastFourierTransform;
import dr.stats.DiscreteStatistics;
import dr.util.HeapSort;
import java.util.Random;
/**
 * Kernel density estimator with a Gaussian kernel, tabulated on a regular grid via FFT convolution.
 * <p>
 * The construction mirrors R's {@code density.default}:
 * <ul>
 * <li>the sample is linearly binned onto {@code gridSize} equally spaced points spanning {@code [lo, up]};</li>
 * <li>the binned masses are convolved with a normal kernel using a zero-padded FFT of length {@code 2 * gridSize};</li>
 * <li>the density at an arbitrary point is obtained by linear interpolation on that grid (0 outside it).</li>
 * </ul>
 * The estimator is unbounded: {@link #processBounds(Double, Double)} rejects any finite bound.
 *
 * @author Marc A. Suchard
 */
public class NormalKDEDistribution extends KernelDensityEstimatorDistribution {

    /** Smallest number of grid points; larger requests are rounded up to a power of two. */
    public static final int MINIMUM_GRID_SIZE = 512;

    /** Largest supported grid size, so that the padded FFT length {@code 2 * gridSize} fits in an int. */
    private static final int MAXIMUM_GRID_SIZE = 1 << 29;

    /**
     * Creates an unbounded estimator with {@code cut = 3} and the minimum grid size.
     * The bandwidth argument passed on is {@code null}, for which {@link #setBandWidth(Double)}
     * selects {@link #bandwidthNRD(double[])} (assuming the superclass constructor forwards it there).
     *
     * @param sample observed values
     */
    public NormalKDEDistribution(Double[] sample) {
        this(sample, null, null, null);
    }

    /**
     * Creates an estimator with {@code cut = 3} and the minimum grid size.
     *
     * @param sample     observed values
     * @param lowerBound must be {@code null} or negative infinity
     * @param upperBound must be {@code null} or positive infinity
     * @param bandWidth  kernel standard deviation, or {@code null} for the normal-reference default
     */
    public NormalKDEDistribution(Double[] sample, Double lowerBound, Double upperBound, Double bandWidth) {
        this(sample, lowerBound, upperBound, bandWidth, 3.0, MINIMUM_GRID_SIZE);
    }

    /**
     * Creates an estimator with {@code cut = 3} and a requested grid size.
     *
     * @param sample     observed values
     * @param lowerBound must be {@code null} or negative infinity
     * @param upperBound must be {@code null} or positive infinity
     * @param bandWidth  kernel standard deviation, or {@code null} for the normal-reference default
     * @param n          requested number of grid points (see {@link #MINIMUM_GRID_SIZE})
     */
    public NormalKDEDistribution(Double[] sample, Double lowerBound, Double upperBound, Double bandWidth,
                                 int n) {
        this(sample, lowerBound, upperBound, bandWidth, 3.0, n);
    }

    /**
     * Creates an estimator.
     *
     * @param sample     observed values
     * @param lowerBound must be {@code null} or negative infinity
     * @param upperBound must be {@code null} or positive infinity
     * @param bandWidth  kernel standard deviation, or {@code null} for the normal-reference default
     * @param cut        the reported range extends {@code cut} bandwidths beyond the sample extremes
     * @param n          requested number of grid points; raised to at least {@link #MINIMUM_GRID_SIZE}
     *                   and, above that, rounded up to a power of two
     * @throws IllegalArgumentException if {@code n} exceeds the supported maximum grid size
     */
    public NormalKDEDistribution(Double[] sample, Double lowerBound, Double upperBound, Double bandWidth,
                                 double cut, int n) {
        super(sample, lowerBound, upperBound, bandWidth);
        this.gridSize = roundUpGridSize(n);
        this.cut = cut;
        from = DiscreteStatistics.min(super.sample) - this.cut * this.bandWidth;
        to = DiscreteStatistics.max(super.sample) + this.cut * this.bandWidth;
        // The binning grid reaches a further 4 bandwidths beyond [from, to], as in R's density()
        lo = from - 4.0 * this.bandWidth;
        up = to + 4.0 * this.bandWidth;
        densityKnown = false;
    }

    /**
     * Computes the grid size: {@code max(n, MINIMUM_GRID_SIZE)}, rounded up to the next power of two
     * when it exceeds the minimum. Integer bit arithmetic avoids the rounding error of a
     * {@code log}-based computation, which can overshoot exact powers of two.
     */
    private static int roundUpGridSize(int n) {
        if (n <= MINIMUM_GRID_SIZE) {
            return MINIMUM_GRID_SIZE;
        }
        if (n > MAXIMUM_GRID_SIZE) {
            throw new IllegalArgumentException("Grid size too large: " + n + " (maximum " + MAXIMUM_GRID_SIZE + ")");
        }
        final int highest = Integer.highestOneBit(n);
        return highest == n ? n : highest << 1;
    }

    /** @return the lower end of the reported range: sample minimum minus {@code cut} bandwidths */
    public double getFromPoint() {
        return from;
    }

    /** @return the upper end of the reported range: sample maximum plus {@code cut} bandwidths */
    public double getToPoint() {
        return to;
    }

    /**
     * Linearly interpolates the tabulated function (x, y) at {@code pt}.
     *
     * @param x    abscissae, assumed sorted in increasing order
     * @param y    ordinates, same length as {@code x}
     * @param pt   evaluation point
     * @param low  value returned when {@code pt < x[0]}
     * @param high value returned when {@code pt > x[x.length - 1]}
     * @return the interpolated ordinate at {@code pt}
     */
    private double linearApproximate(double[] x, double[] y, double pt, double low, double high) {
        int i = 0;
        int j = x.length - 1;
        if (pt < x[i]) {
            return low;
        }
        if (pt > x[j]) {
            return high;
        }
        // Bisection search for the bracketing interval x[i] <= pt <= x[j]
        while (i < j - 1) {
            final int ij = (i + j) >>> 1;
            if (pt < x[ij]) {
                j = ij;
            } else {
                i = ij;
            }
        }
        if (pt == x[j]) {
            return y[j];
        }
        if (pt == x[i]) {
            return y[i];
        }
        return y[i] + (y[j] - y[i]) * ((pt - x[i]) / (x[j] - x[i]));
    }

    /**
     * Keeps the first half of an inverse-FFT result, divides by the transform length and clamps
     * negative values (FFT rounding noise) to zero.
     *
     * @param x real part of the inverse transform, of length {@code 2 * gridSize}
     * @return the density values on the {@code gridSize} grid points
     */
    private double[] rescaleAndTrim(double[] x) {
        final int length = x.length / 2;
        final double scale = 1.0 / x.length;
        double[] out = new double[length];
        for (int i = 0; i < length; ++i) {
            out[i] = Math.max(0.0, x[i] * scale);
        }
        return out;
    }

    /**
     * Linearly bins the sample onto {@code ny} equally spaced points spanning [xlow, xhigh], in the
     * manner of R's BinDist. Each finite sample carries mass {@code 1 / (number of finite samples)},
     * split between its two neighbouring grid points.
     * <p>
     * Non-finite samples are skipped: a NaN would otherwise be cast to grid index 0 and poison every
     * density value through the convolution.
     *
     * @return an array of length {@code 2 * ny}; the upper half stays zero and pads the FFT convolution
     */
    private double[] massdist(double[] x, double xlow, double xhigh, int ny) {
        final double[] y = new double[ny * 2]; // zero-initialised by Java
        final int ixmin = 0;
        final int ixmax = ny - 2;
        final double xdelta = (xhigh - xlow) / (ny - 1);

        int finiteCount = 0;
        for (double xi : x) {
            if (!Double.isNaN(xi) && !Double.isInfinite(xi)) {
                finiteCount++;
            }
        }
        if (finiteCount == 0) {
            return y;
        }
        final double xmi = 1.0 / finiteCount;

        for (double xi : x) {
            if (Double.isNaN(xi) || Double.isInfinite(xi)) {
                continue;
            }
            final double xpos = (xi - xlow) / xdelta;
            final int ix = (int) Math.floor(xpos);
            final double fx = xpos - ix;
            if (ixmin <= ix && ix <= ixmax) {
                y[ix] += (1 - fx) * xmi;
                y[ix + 1] += fx * xmi;
            } else if (ix == -1) {
                // Just below the grid: only the share falling on the first point is kept
                y[0] += fx * xmi;
            } else if (ix == ixmax + 1) {
                // Exactly on / just beyond the last point: only the share on the last point is kept
                y[ix] += (1 - fx) * xmi;
            }
        }
        return y;
    }

    /**
     * Replaces each real ordinate (a signed distance) with the kernel value at that distance.
     * This implementation uses the normal density with mean 0 and standard deviation
     * {@code bandWidth}. Override for different kernels.
     *
     * @param ordinates the points in complex space; only the real parts are read and written
     * @param bandWidth predetermined bandwidth
     */
    protected void fillKernelOrdinates(ComplexArray ordinates, double bandWidth) {
        final int length = ordinates.length;
        final double a = 1.0 / (Math.sqrt(2.0 * Math.PI) * bandWidth);
        final double precision = -0.5 / (bandWidth * bandWidth);
        for (int i = 0; i < length; i++) {
            final double x = ordinates.real[i];
            ordinates.real[i] = a * Math.exp(x * x * precision);
        }
    }

    /** Builds the kernel transform and x grid, then convolves the binned sample to tabulate the density. */
    protected void computeDensity() {
        makeOrdinates();
        transformData();
        densityKnown = true;
    }

    /**
     * Convolves the binned sample masses with the kernel via FFT and stores the trimmed, rescaled
     * result in {@code densityPoints}.
     */
    private void transformData() {
        ComplexArray Y = new ComplexArray(massdist(this.sample, lo, up, this.gridSize));
        FastFourierTransform.fft(Y, false);
        ComplexArray product = Y.product(kOrdinates);
        FastFourierTransform.fft(product, true);
        densityPoints = rescaleAndTrim(product.real);
    }

    /**
     * Builds the (conjugated) FFT of the kernel on a circular grid of length {@code 2 * gridSize},
     * and the {@code gridSize} x points spanning {@code [lo, up]}.
     */
    private void makeOrdinates() {
        final int length = 2 * gridSize;
        // Always start from a fresh, purely real array. The in-place FFT below overwrites the
        // imaginary part, so reusing the previous array on recomputation (e.g. after setBandWidth)
        // would feed that stale imaginary residue into the next transform.
        kOrdinates = new ComplexArray(new double[length]);

        // Signed distances on the circular grid: 0, inc, ..., gridSize * inc, then mirrored negatives
        final double inc = 2.0 * (up - lo) / (length - 1);
        for (int i = 0; i <= gridSize; i++) {
            kOrdinates.real[i] = i * inc;
        }
        for (int i = gridSize + 1; i < length; i++) {
            kOrdinates.real[i] = -kOrdinates.real[length - i];
        }
        fillKernelOrdinates(kOrdinates, bandWidth);
        FastFourierTransform.fft(kOrdinates, false);
        kOrdinates.conjugate();

        // x grid; computed as lo + i * delta to avoid accumulating rounding error
        xPoints = new double[gridSize];
        final double delta = (up - lo) / (gridSize - 1);
        for (int i = 0; i < gridSize; i++) {
            xPoints[i] = lo + i * delta;
        }
    }

    /**
     * Evaluates the density estimate at {@code x}, tabulating it on first use.
     *
     * @return the linearly interpolated density, or 0 outside {@code [lo, up]}
     */
    @Override
    protected double evaluateKernel(double x) {
        if (!densityKnown) {
            computeDensity();
        }
        return linearApproximate(xPoints, densityPoints, x, 0.0, 0.0);
    }

    /**
     * Rejects any finite bound; only unbounded (null or infinite) bounds are allowed.
     *
     * @throws RuntimeException if either bound is finite
     */
    @Override
    protected void processBounds(Double lowerBound, Double upperBound) {
        if ((lowerBound != null && lowerBound != Double.NEGATIVE_INFINITY) ||
                (upperBound != null && upperBound != Double.POSITIVE_INFINITY)) {
            throw new RuntimeException("NormalKDEDistribution must be unbounded");
        }
    }

    /**
     * Sets the kernel bandwidth and invalidates the tabulated density.
     *
     * @param bandWidth the kernel standard deviation, or {@code null} to use {@link #bandwidthNRD(double[])}
     */
    @Override
    protected void setBandWidth(Double bandWidth) {
        if (bandWidth == null) {
            // Default bandwidth
            this.bandWidth = bandwidthNRD(sample);
        } else {
            this.bandWidth = bandWidth;
        }
        densityKnown = false;
    }

    /**
     * Normal-reference rule-of-thumb bandwidth, as in R's {@code stats::bw.nrd}:
     * <pre>
     * bw.nrd <- function (x) {
     *     r <- quantile(x, c(0.25, 0.75))
     *     h <- (r[2L] - r[1L])/1.34
     *     1.06 * min(sqrt(var(x)), h) * length(x)^(-1/5)
     * }
     * </pre>
     * (MASS's {@code bandwidth.nrd} is 4 times larger and is not what is computed here.)
     * <p>
     * When this scale is zero or undefined (zero IQR, constant or single-element sample), a zero
     * bandwidth would make the Gaussian kernel divide by zero. The scale then falls back, in the
     * order R's {@code bw.nrd0} uses, to the standard deviation, then {@code |x[0]|}, then 1.
     *
     * @param x sample values
     * @return the bandwidth (kernel standard deviation)
     */
    public double bandwidthNRD(double[] x) {
        int[] indices = new int[x.length];
        HeapSort.sort(x, indices);
        final double h =
                (DiscreteStatistics.quantile(0.75, x, indices) - DiscreteStatistics.quantile(0.25, x, indices)) / 1.34;
        final double sd = Math.sqrt(DiscreteStatistics.variance(x));
        double scale = Math.min(sd, h);
        if (!(scale > 0.0)) { // also catches NaN
            if (sd > 0.0) {
                scale = sd;
            } else if (x.length > 0 && Math.abs(x[0]) > 0.0) {
                scale = Math.abs(x[0]);
            } else {
                scale = 1.0;
            }
        }
        return 1.06 * scale * Math.pow(x.length, -0.2);
    }

    /** Conjugated FFT of the kernel ordinates on the padded circular grid. */
    private ComplexArray kOrdinates;
    /** Grid abscissae spanning [lo, up]. */
    private double[] xPoints;
    /** Tabulated density at each of {@code xPoints}. */
    private double[] densityPoints;
    /** Number of grid points (a power of two, at least {@link #MINIMUM_GRID_SIZE}). */
    private int gridSize;
    /** Number of bandwidths the reported range extends beyond the sample extremes. */
    private double cut;
    /** Lower end of the reported range. */
    private double from;
    /** Upper end of the reported range. */
    private double to;
    /** Lower end of the binning grid. */
    private double lo;
    /** Upper end of the binning grid. */
    private double up;
    /** Whether {@code xPoints}/{@code densityPoints} are up to date. */
    private boolean densityKnown = false;

    /** Ad-hoc timing run: builds a KDE from 10^7 uniform draws and evaluates it at 100 points. */
    public static void main(String[] args) {
        long start = System.currentTimeMillis();
        Random random = new Random(1234);
        Double[] samples = new Double[10000000];
        for (int i = 0; i < samples.length; i++) {
            samples[i] = random.nextDouble();
        }
        NormalKDEDistribution nKDE = new NormalKDEDistribution(samples);
        for (int i = 0; i < 100; i++) {
            nKDE.evaluateKernel(random.nextDouble());
        }
        long end = System.currentTimeMillis();
        System.out.println("Time: " + (end-start));
    }
}