/*
* MMAlgorithm.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.multidimensionalscaling.mm;
import dr.inference.multidimensionalscaling.MultiDimensionalScalingLikelihood;
import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.xml.*;
/**
 * Skeleton of an iterative fixed-point ("MM") scheme: subclasses supply a single
 * update step via {@link #mmUpdate(double[], double[])} and this class repeats it
 * until two successive iterates are closer than a tolerance (Euclidean distance)
 * or an iteration budget is exhausted.
 * <p>
 * Instances keep the running iteration count in a field, so a single instance
 * must not be used for concurrent {@code findMode} calls.
 * <p>
 * Created by msuchard on 12/15/15.
 */
public abstract class MMAlgorithm {

    /** Tolerance used by {@link #findMode(double[])}. */
    public static final double DEFAULT_TOLERANCE = 1E-1;

    /** Iteration budget used by {@link #findMode(double[])}. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;

    /**
     * Runs the iteration with {@link #DEFAULT_TOLERANCE} and {@link #DEFAULT_MAX_ITERATIONS}.
     *
     * @param startingValue initial iterate; it is copied and never modified
     * @return the final iterate
     * @throws NotConvergedException if the tolerance was not met within the iteration budget
     */
    public double[] findMode(final double[] startingValue) throws NotConvergedException {
        return findMode(startingValue, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Stores the element-wise difference {@code x - y} in {@code dest}.
     * All three arrays must be at least {@code dest.length} long.
     */
    private static void copyDifference(double[] dest, final double[] x, final double[] y) {
        final int len = dest.length;
        for (int i = 0; i < len; ++i) {
            dest[i] = x[i] - y[i];
        }
    }

    /**
     * Repeats {@link #mmUpdate(double[], double[])} starting from {@code startingValue}
     * until the Euclidean distance between successive iterates no longer exceeds
     * {@code tolerance}, or until {@code maxIterations} updates have been performed.
     * At least one update is always performed.
     *
     * @param startingValue initial iterate; it is copied and never modified
     * @param tolerance     convergence threshold on the distance between successive iterates
     * @param maxIterations maximum number of update steps
     * @return the final iterate (a freshly allocated array owned by the caller)
     * @throws NotConvergedException if the iteration budget ran out before the
     *                               convergence threshold was met
     */
    public double[] findMode(final double[] startingValue, final double tolerance,
                             final int maxIterations) throws NotConvergedException {

        if (DEBUG) {
            System.err.println("Starting findMode with " + tolerance + " " + maxIterations);
        }

        // Two buffers that are swapped every step so no allocation happens inside the loop.
        double[] previous = new double[startingValue.length];
        double[] x = new double[startingValue.length];
        System.arraycopy(startingValue, 0, x, 0, startingValue.length);

        iteration = 0;

        // Tracked explicitly: judging convergence from the iteration count alone
        // misreports a run that converges on exactly the last permitted iteration.
        boolean converged = false;

        if (qnQ == 0) {
            // Standard MM
            do {
                // Move x -> previous
                double[] tmp = previous; previous = x; x = tmp;

                if (DEBUG) {
                    System.err.println("Current: " + printArray(previous));
                }

                mmUpdate(previous, x);
                ++iteration;

                if (DEBUG) {
                    System.err.println("Finished iteration " + iteration);
                }

                // Written as a negation so that a NaN criterion ends the loop,
                // exactly as the original "criterion > tolerance" loop test did.
                converged = !(convergenceCriterion(x, previous) > tolerance);
            } while (!converged && iteration < maxIterations);

        } else {
            // Quasi-Newton acceleration.
            // NOTE(review): unreachable while qnQ is fixed at 0. The secant storage below
            // has a single row, so this branch looks like it only supports qnQ == 1 -- confirm
            // before enabling larger values.
            final int J = startingValue.length;

            double[][] secantsU = new double[1][J];
            double[][] secantsV = new double[1][J];

            double[] Fx = new double[J];
            double[] C = new double[J];
            double[] x0 = new double[J];

            // Fill initial secants
            int countU = 0;
            int countV = 0;

            for (int q = 0; q < qnQ; ++q) {
                double[] tmp = previous; previous = x; x = tmp;
                mmUpdate(previous, x);
                ++iteration;

                if (countU == 0) {
                    copyDifference(secantsU[countU], x, previous);
                    ++countU;
                } else if (countV < qnQ - 1) {
                    copyDifference(secantsU[countU], x, previous);
                    System.arraycopy(secantsU[countU], 0, secantsV[countV], 0, J);
                    ++countU;
                    ++countV;
                } else {
                    copyDifference(secantsV[countV], x, previous);
                    ++countV;
                }
            }

            int newestSecant = qnQ - 1;

            boolean done = false;
            while (!done) {

                // Remember where this accelerated step started, for the convergence test.
                System.arraycopy(x, 0, x0, 0, J);

                // 2 cycles for each QN step
                double[] tmp = previous; previous = x; x = tmp;
                mmUpdate(previous, x);
                ++iteration;

                copyDifference(secantsU[newestSecant], x, previous);
                System.arraycopy(x, 0, Fx, 0, J); // TODO Remove Fx?

                tmp = previous; previous = x; x = tmp;
                mmUpdate(previous, x);
                ++iteration;

                copyDifference(secantsV[newestSecant], x, previous);

                // Do QN approximation here (scalar form of the matrix expressions below)
                // auto M = secantsU.transpose() * (secantsU - secantsV);
                // auto Minv = M.inverse();
                // auto A = secantsU.transpose() * secantsU.col(newestSecant);
                // auto B = Minv * A;
                // auto C = secantsV * B;
                // VectorXd xqn = Fx + C;
                double M = 0;
                for (int j = 0; j < J; ++j) {
                    M += secantsU[0][j] * (secantsU[0][j] - secantsV[0][j]);
                }

                double Minv = 1.0 / M;

                double A = 0;
                for (int j = 0; j < J; ++j) {
                    A += secantsU[0][j] * secantsU[newestSecant][j];
                }

                double B = Minv * A;

                for (int j = 0; j < J; ++j) {
                    C[j] = secantsV[0][j] * B;
                }

                // New step
                for (int j = 0; j < J; ++j) {
                    x[j] = Fx[j] + C[j];
                }

                // NOTE(review): hard-coded pin of coordinate 1; requires J >= 2 and looks
                // specific to one model -- confirm before enabling this branch.
                x[1] = 0.0; // Fixed point

                // Get ready for next secant-pair
                newestSecant = (newestSecant + 1) % qnQ;

                converged = convergenceCriterion(x, x0) < tolerance;
                done = converged || iteration > maxIterations;
            }
        }

        if (PROGRESS) {
            System.err.println("Finished in " + iteration + " iterations.");
        }

        if (!converged) {
            throw new NotConvergedException("MM algorithm did not converge within "
                    + maxIterations + " iterations (tolerance = " + tolerance + ")");
        }

        if (DEBUG) {
            System.err.println("Final  : " + printArray(x));
        }

        return x;
    }

    /** Format applied to each element by {@link #printArray(double[])}. */
    private static final String FORMAT = "%5.3e";

    /**
     * Renders an array as comma-separated values in scientific notation.
     *
     * @param x values to render; may be empty
     * @return the formatted values joined by {@code ", "}, or an empty string for an empty array
     */
    protected String printArray(double[] x) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < x.length; ++i) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(String.format(FORMAT, x[i]));
        }
        return sb.toString();
    }

    /**
     * Performs one update step.
     *
     * @param current the present iterate; {@code findMode} expects it to be left untouched
     *                because it is compared against {@code next} afterwards
     * @param next    output buffer for the new iterate; it is a recycled buffer, so on
     *                entry it holds stale values that must be overwritten
     */
    protected abstract void mmUpdate(final double[] current, double[] next);

    /**
     * Euclidean distance between two iterates, iterating over {@code current.length} entries.
     */
    private double convergenceCriterion(final double[] current, final double[] previous) {
        double norm = 0.0;

        for (int i = 0; i < current.length; ++i) {
            norm += (current[i] - previous[i]) * (current[i] - previous[i]);
        }

        double value = Math.sqrt(norm);

        if (DEBUG) {
            System.err.println("Convergence = " + value);
        }

        return value;
    }

    /**
     * Signals that {@code findMode} exhausted its iteration budget without meeting the
     * convergence threshold. Declared static so that a thrown instance does not hold a
     * reference to the algorithm object that raised it.
     */
    static class NotConvergedException extends Exception {
        private static final long serialVersionUID = 1L;

        NotConvergedException() {
            super();
        }

        NotConvergedException(String message) {
            super(message);
        }
    }

    /** Number of secant pairs for quasi-Newton acceleration; 0 selects the plain iteration. */
    final int qnQ = 0;

    /** Number of update steps performed by the most recent {@code findMode} call. */
    int iteration;

    /** Enables per-iteration tracing on standard error. */
    private static final boolean DEBUG = true;

    /** Enables the one-line iteration summary printed when {@code findMode} finishes. */
    private static final boolean PROGRESS = true;
}