/*
* MultiDimensionalScalingCoreImpl.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.multidimensionalscaling;
import dr.math.distributions.NormalDistribution;
/**
* MultiDimensionalScalingCoreImpl
*
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id$
*
* $HeadURL$
*
* $LastChangedBy$
* $LastChangedDate$
* $LastChangedRevision$
*/
public class MultiDimensionalScalingCoreImpl implements MultiDimensionalScalingCore {

    /*
     * Implementation notes
     * --------------------
     * The core keeps a locationCount x locationCount cache of per-pair "increments"
     * (squared residuals, or scaled squared residuals plus a truncation term when
     * left truncation is enabled) together with their running half-sum.
     *
     * When exactly one location has been changed since the last storeState(), only
     * ROW `updatedLocation` of the cache is recomputed and the previous row is kept
     * in `storedIncrements`. The matching COLUMN is deliberately left untouched so
     * that restoreState() only has to copy one row back; acceptState() mirrors the
     * row into the column once the proposal is kept. Any other change pattern
     * triggers a full recomputation.
     */

    /**
     * Allocates all working storage and resets every cache flag.
     *
     * @param embeddingDimension number of coordinates per location
     * @param locationCount      number of locations
     * @param flags              bit mask; only {@code LEFT_TRUNCATION} is inspected here
     */
    @Override
    public void initialize(int embeddingDimension, int locationCount, long flags) {
        this.embeddingDimension = embeddingDimension;
        this.locationCount = locationCount;
        // number of unordered pairs of distinct locations
        this.observationCount = (locationCount * (locationCount - 1)) / 2;

        observations = new double[locationCount][locationCount];
        increments = new double[locationCount][locationCount];
        storedIncrements = null;
        incrementsKnown = false;
        sumOfIncrementsKnown = false;

        isLeftTruncated = (flags & MultiDimensionalScalingCore.LEFT_TRUNCATION) != 0;

        updatedLocation = -1;

        locations = new double[locationCount][embeddingDimension];
        storedLocations = new double[locationCount][embeddingDimension];
    }

    /**
     * Copies the observed pairwise values into the core.
     *
     * @param observations row-major {@code locationCount * locationCount} matrix
     * @throws IllegalArgumentException if the array does not have exactly
     *                                  {@code locationCount * locationCount} entries
     */
    @Override
    public void setPairwiseData(double[] observations) {
        if (observations.length != (locationCount * locationCount)) {
            throw new IllegalArgumentException("Observation data is not the correct dimension");
        }

        int offset = 0;
        for (int i = 0; i < locationCount; i++) {
            System.arraycopy(observations, offset, this.observations[i], 0, locationCount);
            offset += locationCount;
        }
    }

    /**
     * Returns a freshly allocated row-major copy of the observation matrix.
     *
     * @return array of length {@code locationCount * locationCount}
     */
    @Override
    public double[] getPairwiseData() {
        double[] data = new double[locationCount * locationCount];

        int offset = 0;
        for (int i = 0; i < locationCount; ++i) {
            System.arraycopy(observations[i], 0, data, offset, locationCount);
            offset += locationCount;
        }
        return data;
    }

    /**
     * Sets the model parameters; only {@code parameters[0]} (the precision) is read.
     *
     * @param parameters parameter vector whose first entry is the precision
     */
    @Override
    public void setParameters(double[] parameters) {
        precision = parameters[0];

        // With left truncation the cached increments themselves depend on the
        // precision, so they must be recomputed. Without truncation the cache holds
        // plain squared residuals and stays valid.
        if (isLeftTruncated) {
            incrementsKnown = false;
            sumOfIncrementsKnown = false;
        }
    }

    /**
     * Updates either a single location or all locations at once.
     *
     * @param locationIndex index of the location to replace, or {@code -1} to replace
     *                      every location
     * @param location      {@code embeddingDimension} coordinates for a single update, or
     *                      {@code embeddingDimension * locationCount} coordinates
     *                      (location-major) for a full update
     * @throws IllegalArgumentException       if {@code location} has the wrong length
     * @throws ArrayIndexOutOfBoundsException if {@code locationIndex} is neither {@code -1}
     *                                        nor a valid location index
     */
    @Override
    public void updateLocation(int locationIndex, double[] location) {
        final boolean updateAll = (locationIndex == -1);

        // Validate before touching any state so that a rejected call cannot leave
        // `updatedLocation` or the cache flags in an inconsistent condition.
        if (updateAll) {
            if (location.length != embeddingDimension * locationCount) {
                throw new IllegalArgumentException("Location is not the correct dimension");
            }
        } else {
            if (location.length != embeddingDimension) {
                throw new IllegalArgumentException("Location is not the correct dimension");
            }
            if (locationIndex < 0 || locationIndex >= locationCount) {
                throw new ArrayIndexOutOfBoundsException(locationIndex);
            }
        }

        if (updatedLocation != -1 || updateAll) {
            // more than one location updated - do a full recomputation
            incrementsKnown = false;
            storedIncrements = null;
        }

        if (updateAll) {
            int offset = 0;
            for (int i = 0; i < locationCount; ++i) {
                System.arraycopy(location, offset, locations[i], 0, embeddingDimension);
                offset += embeddingDimension;
            }
        } else {
            updatedLocation = locationIndex;
            System.arraycopy(location, 0, locations[locationIndex], 0, embeddingDimension);
        }

        sumOfIncrementsKnown = false;
    }

    /**
     * Returns the log-likelihood for the current locations and precision, refreshing
     * the cached sum of increments first if it is out of date.
     *
     * @return {@code 0.5 * (log(precision) - log(2*pi)) * observationCount} minus the
     *         residual term ({@code sumOfIncrements} when left truncated, otherwise
     *         {@code 0.5 * precision * sumOfIncrements})
     */
    @Override
    public double calculateLogLikelihood() {
        if (!sumOfIncrementsKnown) {
            // The incremental path needs a valid cache AND a known single updated row.
            // `updatedLocation` is -1 when storeState() ran after a location update but
            // before the likelihood was evaluated; recompute everything in that case
            // instead of indexing increments[-1].
            if (!incrementsKnown || updatedLocation == -1) {
                computeSumOfSquaredResiduals();
            } else {
                updateSumOfSquaredResiduals();
                if (REPORT_ROUNDOFF) {
                    reportRoundOff();
                }
            }
            sumOfIncrementsKnown = true;
        }

        double logLikelihood = 0.5 * (Math.log(precision) - Math.log(2 * Math.PI)) * observationCount;

        if (isLeftTruncated) {
            logLikelihood -= sumOfIncrements; // If truncated, then values on difference scale
        } else {
            logLikelihood -= 0.5 * precision * sumOfIncrements;
        }

        return logLikelihood;
    }

    /**
     * Saves the current sum of increments, locations and precision so that a later
     * {@link #restoreState()} can return to them, and clears the single-update tracking.
     */
    @Override
    public void storeState() {
        // Handle residuals
        storedSumOfIncrements = sumOfIncrements;
        storedIncrements = null;

        // Handle locations
        for (int i = 0; i < locationCount; i++) {
            System.arraycopy(locations[i], 0, storedLocations[i], 0, embeddingDimension);
        }
        updatedLocation = -1;

        // Handle precision
        storedPrecision = precision;
    }

    /**
     * Returns to the values saved by the last {@link #storeState()}.
     * <p>
     * If the last likelihood evaluation used the single-row update, only that row of
     * the increment cache is copied back; otherwise the cache is marked unknown.
     */
    @Override
    public void restoreState() {
        // Handle residuals
        sumOfIncrements = storedSumOfIncrements;
        sumOfIncrementsKnown = true;

        if (storedIncrements != null) {
            // Only the row was overwritten by the incremental update (the column is
            // written in acceptState()), so only the row needs restoring.
            System.arraycopy(storedIncrements, 0, increments[updatedLocation], 0, locationCount);
            incrementsKnown = true;
        } else {
            incrementsKnown = false;
        }

        // Handle locations: swap buffers instead of copying
        double[][] swap = storedLocations;
        storedLocations = locations;
        locations = swap;

        // Handle precision
        precision = storedPrecision;
    }

    /**
     * Commits a single-row incremental update by mirroring the recomputed row of the
     * increment cache into the corresponding column. Does nothing when the last
     * evaluation was a full recomputation.
     */
    @Override
    public void acceptState() {
        if (storedIncrements != null) {
            final double[] updatedRow = increments[updatedLocation];
            for (int j = 0; j < locationCount; ++j) {
                increments[j][updatedLocation] = updatedRow[j];
            }
        }
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void getGradient(double[] location) {
        throw new UnsupportedOperationException("Not yet implemented.");
    }

    /** Invalidates both the increment cache and its sum, forcing a full recomputation. */
    @Override
    public void makeDirty() {
        sumOfIncrementsKnown = false;
        incrementsKnown = false;
    }

    /**
     * Recomputes every entry of the increment cache and stores half of their total
     * (each unordered pair is visited twice) in {@code sumOfIncrements}.
     */
    protected void computeSumOfSquaredResiduals() {
        final double oneOverSd = Math.sqrt(precision);
        final double scale = 0.5 * precision;

        double total = 0.0;
        for (int i = 0; i < locationCount; i++) {
            final double[] row = increments[i];
            for (int j = 0; j < locationCount; j++) {
                final double increment = computeIncrement(i, j, scale, oneOverSd);
                row[j] = increment;
                total += increment;
            }
        }

        sumOfIncrements = total / 2;

        incrementsKnown = true;
        sumOfIncrementsKnown = true;
    }

    /**
     * Recomputes only row {@code updatedLocation} of the increment cache and adjusts
     * {@code sumOfIncrements} by the resulting change. The previous row is kept in
     * {@code storedIncrements} for {@link #restoreState()}.
     * <p>
     * Callers must ensure {@code updatedLocation} is a valid index.
     */
    protected void updateSumOfSquaredResiduals() {
        final double oneOverSd = Math.sqrt(precision);
        final double scale = 0.5 * precision;

        final int i = updatedLocation;
        final double[] row = increments[i];

        // snapshot of the row as it was before this update
        storedIncrements = row.clone();

        double delta = 0.0;
        for (int j = 0; j < locationCount; j++) {
            final double increment = computeIncrement(i, j, scale, oneOverSd);
            delta += increment - row[j];
            // Transposed entry increments[j][i] is intentionally NOT written here;
            // see acceptState() / restoreState().
            row[j] = increment;
        }

        // Row and column each change by `delta`, and the sum is half the matrix total.
        sumOfIncrements += delta;
    }

    /**
     * Euclidean distance between two coordinate vectors, using the first
     * {@code embeddingDimension} entries of each.
     */
    protected double calculateDistance(double[] X, double[] Y) {
        double sum = 0.0;
        for (int i = 0; i < embeddingDimension; i++) {
            double difference = X[i] - Y[i];
            sum += difference * difference;
        }
        return Math.sqrt(sum);
    }

    /**
     * Truncation term added to an off-diagonal increment under left truncation.
     *
     * @param mean      the current distance between the two locations
     * @param precision unused by this implementation; retained for subclasses
     * @param oneOverSd square root of the precision
     * @return {@code NormalDistribution.standardCDF(mean * oneOverSd, true)}
     *         (the {@code true} flag presumably requests the log scale - confirm
     *         against NormalDistribution)
     */
    protected double computeTruncation(double mean, double precision, double oneOverSd) {
        return NormalDistribution.standardCDF(mean * oneOverSd, true); // i.e. standardCDF(mean / sd, true)
    }

    /**
     * Computes the increment for the ordered pair (i, j) from the current locations
     * without touching any cached state.
     */
    private double computeIncrement(int i, int j, double scale, double oneOverSd) {
        final double distance = calculateDistance(locations[i], locations[j]);
        final double residual = distance - observations[i][j];
        double increment = residual * residual;
        if (isLeftTruncated) {
            increment = scale * increment;
            if (i != j) {
                // The truncation term is evaluated at the distance itself, not at the
                // absolute residual (the latter was an older form believed incorrect).
                increment += computeTruncation(distance, precision, oneOverSd);
            }
        }
        return increment;
    }

    /**
     * Diagnostic: compares the incrementally maintained sum against a from-scratch
     * sum and prints both (and their difference) to stderr when they differ by more
     * than 1E-6. Deliberately leaves the increment cache and {@code sumOfIncrements}
     * untouched: overwriting the cache here would also overwrite the column that
     * restoreState() cannot roll back.
     */
    private void reportRoundOff() {
        final double oneOverSd = Math.sqrt(precision);
        final double scale = 0.5 * precision;

        double reference = 0.0;
        for (int i = 0; i < locationCount; i++) {
            for (int j = 0; j < locationCount; j++) {
                reference += computeIncrement(i, j, scale, oneOverSd);
            }
        }
        reference /= 2;

        if (Math.abs(sumOfIncrements - reference) > 1E-6) {
            System.err.println(sumOfIncrements);
            System.err.println(reference);
            System.err.println(sumOfIncrements - reference);
            System.err.println("");
        }
    }

    private int embeddingDimension;
    private boolean isLeftTruncated = false;
    private int locationCount;
    private int observationCount;
    private double precision;
    private double storedPrecision;

    // index of the single location changed since the last storeState(), or -1
    private int updatedLocation = -1;

    private double[][] observations;
    private double[][] locations;
    private double[][] storedLocations;

    private boolean incrementsKnown = false;
    private boolean sumOfIncrementsKnown = false;

    private double[][] increments;
    // previous contents of row `updatedLocation`; null unless an incremental update ran
    private double[] storedIncrements;

    private double sumOfIncrements;
    private double storedSumOfIncrements;

    private static final boolean REPORT_ROUNDOFF = false;
}