/*
* MultivariateNormalDistribution.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.math.distributions;
import dr.inference.model.GradientProvider;
import dr.inference.model.Likelihood;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.*;
/**
 * Multivariate normal distribution parameterized by a mean vector and a precision
 * (inverse covariance) matrix.
 *
 * <p>The variance matrix, its Cholesky factor and the log-determinant of the precision
 * matrix are computed lazily on first request and cached. The {@code mean} and
 * {@code precision} arrays given to the constructors are stored by reference (not copied),
 * so callers should not mutate them afterwards; the cached quantities would go stale.
 *
 * @author Marc Suchard
 */
public class MultivariateNormalDistribution implements MultivariateDistribution, GaussianProcessRandomGenerator,
        GradientProvider {

    /** Type identifier returned by {@link #getType()}. */
    public static final String TYPE = "MultivariateNormal";

    /** Per-dimension log normalizing constant of the standard normal: {@code -0.5 * log(2 * pi)}. */
    public static final double logNormalize = -0.5 * Math.log(2.0 * Math.PI);

    private final double[] mean;
    private final double[][] precision;

    // Lazily computed caches; null until first requested.
    private double[][] variance = null;
    private double[][] cholesky = null;
    private Double logDet = null;

    // True when constructed from a scalar precision (precision matrix = singlePrecision * I);
    // logPdf/gradLogPdf then use the cheaper scalar formulas.
    private final boolean hasSinglePrecision;
    private final double singlePrecision;

    /**
     * Creates a distribution with a full precision matrix.
     *
     * @param mean      mean vector (stored by reference)
     * @param precision precision matrix (stored by reference); expected to be square with
     *                  dimension {@code mean.length}
     */
    public MultivariateNormalDistribution(double[] mean, double[][] precision) {
        this.mean = mean;
        this.precision = precision;
        this.hasSinglePrecision = false;
        this.singlePrecision = 1.0;
    }

    /**
     * Creates a distribution whose precision matrix is {@code singlePrecision} times the
     * identity (independent dimensions with equal precision).
     *
     * @param mean            mean vector (stored by reference)
     * @param singlePrecision precision shared by every dimension
     */
    public MultivariateNormalDistribution(double[] mean, double singlePrecision) {
        this.mean = mean;
        this.hasSinglePrecision = true;
        this.singlePrecision = singlePrecision;
        final int dim = mean.length;
        this.precision = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            this.precision[i][i] = singlePrecision;
        }
    }

    /** @return the constant {@link #TYPE} */
    public String getType() {
        return TYPE;
    }

    /**
     * Returns the variance (covariance) matrix, i.e. the inverse of the precision matrix.
     * Computed on first call and cached.
     *
     * @return the cached variance matrix (not a copy)
     */
    public double[][] getVariance() {
        if (variance == null) {
            variance = getInverse(precision);
        }
        return variance;
    }

    /**
     * Returns the lower-triangular Cholesky factor of the variance matrix.
     * Computed on first call and cached.
     *
     * @return the cached Cholesky factor (not a copy)
     */
    public double[][] getCholeskyDecomposition() {
        if (cholesky == null) {
            cholesky = getCholeskyDecomposition(getVariance());
        }
        return cholesky;
    }

    /**
     * Returns the natural log of the determinant of the precision matrix.
     * Computed on first call and cached.
     *
     * @return log |precision|
     */
    public double getLogDet() {
        if (logDet == null) {
            double value = Math.log(calculatePrecisionMatrixDeterminate(precision));
            // The determinant can overflow to +Infinity or underflow to 0 (log -> -Infinity) for
            // large or badly scaled matrices. For a diagonal matrix, summing the log diagonal
            // entries avoids forming the product. Done once here so it is not repeated per call.
            if (Double.isInfinite(value) && isDiagonal(precision)) {
                value = logDetForDiagonal(precision);
            }
            logDet = value;
        }
        return logDet;
    }

    // Checks only the strict upper triangle, so it assumes x is symmetric.
    private boolean isDiagonal(double x[][]) {
        for (int i = 0; i < x.length; ++i) {
            for (int j = i + 1; j < x.length; ++j) {
                if (x[i][j] != 0.0) {
                    return false;
                }
            }
        }
        return true;
    }

    // Log-determinant of a diagonal matrix: sum of the logs of its diagonal entries.
    private double logDetForDiagonal(double x[][]) {
        double sum = 0;
        for (int i = 0; i < x.length; ++i) {
            sum += Math.log(x[i][i]);
        }
        return sum;
    }

    /** @return the precision matrix (the internal array, not a copy) */
    public double[][] getScaleMatrix() {
        return precision;
    }

    /** @return the mean vector (the internal array, not a copy) */
    public double[] getMean() {
        return mean;
    }

    /** Draws one sample from this distribution. */
    public double[] nextMultivariateNormal() {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), 1.0);
    }

    /**
     * Draws one sample centred at {@code x} using this distribution's covariance.
     *
     * @param x mean to use instead of this distribution's mean
     */
    public double[] nextMultivariateNormal(double[] x) {
        return nextMultivariateNormalCholesky(x, getCholeskyDecomposition(), 1.0);
    }

    /**
     * Draws one sample centred at {@code mean} with covariance {@code scale * variance}.
     * Scale lives in variance-space (the noise is multiplied by {@code sqrt(scale)}).
     */
    public double[] nextScaledMultivariateNormal(double[] mean, double scale) {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), Math.sqrt(scale));
    }

    /**
     * Same as {@link #nextScaledMultivariateNormal(double[], double)} but writes the draw into
     * {@code result}. Scale lives in variance-space.
     */
    public void nextScaledMultivariateNormal(double[] mean, double scale, double[] result) {
        nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), Math.sqrt(scale), result);
    }

    /**
     * Computes the determinant of a precision matrix.
     *
     * @throws RuntimeException wrapping the {@code IllegalDimension} raised for a non-square
     *                          matrix (the original exception is kept as the cause)
     */
    public static double calculatePrecisionMatrixDeterminate(double[][] precision) {
        try {
            return new Matrix(precision).determinant();
        } catch (IllegalDimension e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /** Log-density at {@code x}, using the scalar formula when built from a single precision. */
    public double logPdf(double[] x) {
        if (hasSinglePrecision) {
            return logPdf(x, mean, singlePrecision, 1.0);
        } else {
            return logPdf(x, mean, precision, getLogDet(), 1.0);
        }
    }

    /** Gradient of the log-density with respect to {@code x}. */
    public double[] gradLogPdf(double[] x) {
        if (hasSinglePrecision) {
            return gradLogPdf(x, mean, singlePrecision);
        } else {
            return gradLogPdf(x, mean, precision);
        }
    }

    /**
     * Gradient of the log-density for independent dimensions with equal precision:
     * {@code singlePrecision * (mean - x)}.
     */
    public static double[] gradLogPdf(double[] x, double[] mean, double singlePrecision) {
        final int dim = x.length;
        final double[] gradient = new double[dim];
        for (int i = 0; i < dim; ++i) {
            gradient[i] = singlePrecision * (mean[i] - x[i]);
        }
        return gradient;
    }

    /** Gradient of the log-density with a full precision matrix: {@code precision * (mean - x)}. */
    public static double[] gradLogPdf(double[] x, double[] mean, double[][] precision) {
        final int dim = x.length;
        final double[] gradient = new double[dim];
        final double[] delta = new double[dim];
        for (int i = 0; i < dim; ++i) {
            delta[i] = mean[i] - x[i];
        }
        for (int i = 0; i < dim; ++i) {
            double sum = 0;
            for (int j = 0; j < dim; ++j) {
                sum += precision[i][j] * delta[j];
            }
            gradient[i] = sum;
        }
        return gradient;
    }

    /**
     * Log-density of a multivariate normal whose covariance is {@code scale * precision^{-1}}.
     *
     * <p>Scale only modifies the precision; in one dimension this is equivalent to
     * {@code PDF[NormalDistribution[mean, Sqrt[scale]*Sqrt[1/precision]], x]}.
     *
     * @param logDet log-determinant of {@code precision}; if it is negative infinity, negative
     *               infinity is returned immediately
     * @param scale  multiplier applied to the covariance
     */
    public static double logPdf(double[] x, double[] mean, double[][] precision,
                                double logDet, double scale) {
        if (logDet == Double.NEGATIVE_INFINITY)
            return logDet;
        final int dim = x.length;
        final double[] delta = new double[dim];
        final double[] tmp = new double[dim];
        for (int i = 0; i < dim; i++) {
            delta[i] = x[i] - mean[i];
        }
        // tmp = delta^T * precision
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j < dim; j++) {
                tmp[i] += delta[j] * precision[j][i];
            }
        }
        // SSE = delta^T * precision * delta
        double SSE = 0;
        for (int i = 0; i < dim; i++)
            SSE += tmp[i] * delta[i];
        // log|precision / scale| = logDet - dim * log(scale); quadratic form is SSE / scale.
        return dim * logNormalize + 0.5 * (logDet - dim * Math.log(scale) - SSE / scale);
    }

    /**
     * Log-density for independent dimensions with equal {@code precision}, covariance
     * {@code (scale / precision) * I}.
     */
    public static double logPdf(double[] x, double[] mean, double precision, double scale) {
        final int dim = x.length;
        double SSE = 0;
        for (int i = 0; i < dim; i++) {
            double delta = x[i] - mean[i];
            SSE += delta * delta;
        }
        return dim * logNormalize + 0.5 * (dim * (Math.log(precision) - Math.log(scale)) - SSE * precision / scale);
    }

    // Inverse of a symmetric matrix via SymmetricMatrix.
    private static double[][] getInverse(double[][] x) {
        return new SymmetricMatrix(x).inverse().toComponents();
    }

    // Lower-triangular Cholesky factor L of variance (L * L^T = variance).
    private static double[][] getCholeskyDecomposition(double[][] variance) {
        double[][] cholesky;
        try {
            cholesky = (new CholeskyDecomposition(variance)).getL();
        } catch (IllegalDimension illegalDimension) {
            throw new RuntimeException("Attempted Cholesky decomposition on non-square matrix", illegalDimension);
        }
        return cholesky;
    }

    /** Draws one sample given a mean and a precision matrix (inverts the precision each call). */
    public static double[] nextMultivariateNormalPrecision(double[] mean, double[][] precision) {
        return nextMultivariateNormalVariance(mean, getInverse(precision));
    }

    /** Draws one sample given a mean and a variance matrix. */
    public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance) {
        return nextMultivariateNormalVariance(mean, variance, 1.0);
    }

    /** Draws one sample with covariance {@code scale * variance}. */
    public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance, double scale) {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(variance), Math.sqrt(scale));
    }

    /** Draws one sample given a mean and a lower-triangular Cholesky factor of the variance. */
    public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky) {
        return nextMultivariateNormalCholesky(mean, cholesky, 1.0);
    }

    /**
     * Draws one sample {@code mean + cholesky * (sqrtScale * z)} with {@code z} standard normal.
     *
     * @param cholesky lower-triangular factor; only entries with column index &lt;= row index are read
     */
    public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sqrtScale) {
        double[] result = new double[mean.length];
        nextMultivariateNormalCholesky(mean, cholesky, sqrtScale, result);
        return result;
    }

    /**
     * Writes {@code mean + cholesky * (sqrtScale * z)} into {@code result}, with {@code z}
     * standard normal. Only the lower triangle of {@code cholesky} is read.
     */
    public static void nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sqrtScale, double[] result) {
        final int dim = mean.length;
        System.arraycopy(mean, 0, result, 0, dim);
        double[] epsilon = new double[dim];
        for (int i = 0; i < dim; i++)
            epsilon[i] = MathUtils.nextGaussian() * sqrtScale;
        for (int i = 0; i < dim; i++) {
            // caution: decomposition returns lower triangular, so only j <= i contributes
            for (int j = 0; j <= i; j++) {
                result[i] += cholesky[i][j] * epsilon[j];
            }
        }
    }

    /**
     * {@link WrappedVector} variant of {@code nextMultivariateNormalCholesky}.
     *
     * @param epsilon caller-supplied scratch buffer, length at least {@code mean.getDim()};
     *                overwritten with the scaled standard-normal draws
     */
    public static void nextMultivariateNormalCholesky(final WrappedVector mean, final double[][] cholesky,
                                                      final double sqrtScale, final WrappedVector result,
                                                      final double[] epsilon) {
        final int dim = mean.getDim();
        for (int i = 0; i < dim; i++) {
            epsilon[i] = MathUtils.nextGaussian() * sqrtScale;
        }
        for (int i = 0; i < dim; i++) {
            double x = mean.get(i);
            // caution: decomposition returns lower triangular, so only j <= i contributes
            for (int j = 0; j <= i; j++) {
                x += cholesky[i][j] * epsilon[j];
            }
            result.set(i, x);
        }
    }

    /**
     * Offset-array variant of {@code nextMultivariateNormalCholesky}. The dimension is taken
     * from {@code epsilon.length}.
     *
     * @param epsilon caller-supplied scratch buffer; overwritten with the scaled standard-normal draws
     */
    public static void nextMultivariateNormalCholesky(final double[] mean, final int meanOffset, final double[][] cholesky,
                                                      final double sqrtScale, final double[] result, final int resultOffset,
                                                      final double[] epsilon) {
        final int dim = epsilon.length;
        System.arraycopy(mean, meanOffset, result, resultOffset, dim);
        for (int i = 0; i < dim; i++)
            epsilon[i] = MathUtils.nextGaussian() * sqrtScale;
        for (int i = 0; i < dim; i++) {
            // caution: decomposition returns lower triangular, so only j <= i contributes
            for (int j = 0; j <= i; j++) {
                result[resultOffset + i] += cholesky[i][j] * epsilon[j];
            }
        }
    }

    // TODO should be a junit test
    public static void main(String[] args) {
        testPdf();
        testRandomDraws();
    }

    /** Prints log-densities for a fixed 2-D example next to their reference values. */
    public static void testPdf() {
        double[] start = {1, 2};
        double[] stop = {0, 0};
        double[][] precision = {{2, 0.5}, {0.5, 1}};
        double scale = 0.2;
        System.err.println("logPDF = " + logPdf(start, stop, precision, Math.log(calculatePrecisionMatrixDeterminate(precision)), scale));
        System.err.println("Should = -19.94863\n");
        System.err.println("logPDF = " + logPdf(start, stop, 2, 0.2));
        System.err.println("Should = -24.53529\n");
    }

    /** Prints Monte Carlo moments of draws from a 2-D example next to their true values. */
    public static void testRandomDraws() {
        double[] start = {1, 2};
        double[][] precision = {{2, 0.5}, {0.5, 1}};
        int length = 100000;

        System.err.println("Random draws (via precision) ...");
        double[] mean = new double[2];
        double[] SS = new double[2];
        double[] var = new double[2];
        double ZZ = 0;
        for (int i = 0; i < length; i++) {
            double[] draw = nextMultivariateNormalPrecision(start, precision);
            for (int j = 0; j < 2; j++) {
                mean[j] += draw[j];
                SS[j] += draw[j] * draw[j];
            }
            ZZ += draw[0] * draw[1];
        }

        for (int j = 0; j < 2; j++) {
            mean[j] /= length;
            SS[j] /= length;
            var[j] = SS[j] - mean[j] * mean[j];
        }
        ZZ /= length;
        ZZ -= mean[0] * mean[1];

        System.err.println("Mean: " + new Vector(mean));
        System.err.println("TRUE: [ 1 2 ]\n");
        System.err.println("MVar: " + new Vector(var));
        System.err.println("TRUE: [ 0.571 1.14 ]\n");
        System.err.println("Covv: " + ZZ);
        System.err.println("TRUE: -0.286");
    }

    // RandomGenerator interface

    /** @return a fresh draw as a {@code double[]} */
    public Object nextRandom() {
        return nextMultivariateNormal();
    }

    /** Log-density of {@code x}, which must be a {@code double[]}. */
    public double logPdf(Object x) {
        double[] v = (double[]) x;
        return logPdf(v);
    }

    /** This distribution has no associated likelihood; always returns {@code null}. */
    @Override
    public Likelihood getLikelihood() {
        return null;
    }

    /** @return the length of the mean vector */
    @Override
    public int getDimension() {
        return mean.length;
    }

    /** Gradient of the log-density at {@code x}, which must be a {@code double[]}. */
    @Override
    public double[] getGradientLogDensity(Object x) {
        return gradLogPdf((double[]) x);
    }

    /** @return the precision matrix (the internal array, not a copy) */
    @Override
    public double[][] getPrecisionMatrix() {
        return precision;
    }
}