/* * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.optimize; import java.util.Random; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * The implementation of the Nelder-Mead optimisation for the fitting routines. * * It is variously known as the down-hill simplex or amoeba method */ public class NelderMead extends AbstractOptimizer { double accuracy = 0.1; public NelderMead() { } /** * @param accuracy */ public NelderMead(double accuracy) { this.accuracy = accuracy; } public void setAccuracy(double accuracy) { this.accuracy = accuracy; } /** * Setup the logging facilities */ private static final Logger logger = LoggerFactory.getLogger(NelderMead.class); private double alpha = 1.0; private double gamma = 2.0; private double phi = 0.5; private double sigma = 0.5; private double startingSpread = 1.0; private int maxIterations = 2000; private Random random = new Random(0); /** * @return Returns the alpha. */ public double getAlpha() { return alpha; } /** * @param alpha * The alpha to set. */ public void setAlpha(double alpha) { this.alpha = alpha; } /** * @return Returns the gamma. */ public double getGamma() { return gamma; } /** * @param gamma * The gamma to set. */ public void setGamma(double gamma) { this.gamma = gamma; } /** * @return Returns the phi. */ public double getPhi() { return phi; } /** * @param phi * The phi to set. */ public void setPhi(double phi) { this.phi = phi; } /** * @return Returns the sigma. */ public double getSigma() { return sigma; } /** * @param sigma * The sigma to set. */ public void setSigma(double sigma) { this.sigma = sigma; } /** * @return Returns the startingSpread. */ public double getStartingSpread() { return startingSpread; } /** * @param startingSpread * The startingSpread to set. */ public void setStartingSpread(double startingSpread) { this.startingSpread = startingSpread; } @Override void internalOptimize() { double[] best = optimise(getParameterValues(), accuracy); setParameterValues(best); } /** * The main optimisation method * * @param parameters * @param finishCriteria * @return a double array of the parameters for the optimisation */ private double[] optimise(double[] parameters, double finishCriteria) { double[] solution = parameters; // TODO make this more adaptive than just running 5 times, should work to a tolerance // However 5 times gives a reasonable solution for the time being without too much // overhead. for (int p = 0; p < 5; p++) { boolean ok = true; int collapseCount = 0; while (ok && collapseCount < 5) { // this outer loop handles collapsing simplexes Simplex simplex = new Simplex(solution, startingSpread); simplex.generateStartingFitnesses(); int iterationCount = 0; do { simplex.iterate(); } while (simplex.improving(finishCriteria) && ++iterationCount < maxIterations ); solution = simplex.getBestSolution(); ok = simplex.hasCollapsed(); collapseCount += 1; } if (collapseCount > 0) { logger.info("Error, the NelderMead simplex has collapsed, this minimisation may be flawed"); } } return solution; } private class Simplex { private boolean collapsed = false; double[][] points; double[] fitnesses; double oldBest = 0; double oldSpread = 0; int reduceCount = 0; int contractCount = 0; int extendedCount = 0; int reflectedCount = 0; @SuppressWarnings("unused") public int getReduceCount() { return reduceCount; } @SuppressWarnings("unused") public int getContractCount() { return contractCount; } @SuppressWarnings("unused") public int getExtendedCount() { return extendedCount; } @SuppressWarnings("unused") public int getReflectedCount() { return reflectedCount; } public Simplex(double[] parameters, double startingSpread) { points = new double[parameters.length + 1][parameters.length]; initialise(parameters, startingSpread); } public boolean hasCollapsed() { return collapsed; } public boolean improving(double finishCriteria) { int[] sortedValues = sortFitnesses(); double best = fitnesses[sortedValues[0]]; double spread = Math.abs(fitnesses[sortedValues[0]] - fitnesses[sortedValues[sortedValues.length - 1]]); if (spread < finishCriteria) { oldBest = best; oldSpread = spread; return false; } if ((best == oldBest) && (spread == oldSpread)) { collapsed = true; // this collapse is fixed by externally restarting search with bigger spread return false; } oldBest = best; oldSpread = spread; return true; } public double[] getBestSolution() { int[] orderedValues = sortFitnesses(); return points[orderedValues[0]]; } public void iterate() { int[] orderedValues = sortFitnesses(); double[] cog = calculateCentreOfGravityIgnoringWorstValue(orderedValues); double[] reflectedPoint = calculateReflectedPoint(cog, orderedValues); double reflectedPointFitness = calculateResidual(reflectedPoint); // if the reflected point is better than the best point if (reflectedPointFitness < fitnesses[orderedValues[0]]) { double[] extendedPoint = calculateExtendedPoint(cog, orderedValues); double extendedPointFitness = calculateResidual(extendedPoint); // if the extended point is better than the reflected point if (extendedPointFitness < reflectedPointFitness) { replaceWorstPointWith(extendedPoint, extendedPointFitness, orderedValues); return; } replaceWorstPointWith(reflectedPoint, reflectedPointFitness, orderedValues); return; } // if the value is worse than the next worst double nextWorstFitness = fitnesses[orderedValues[orderedValues.length - 2]]; if (reflectedPointFitness > nextWorstFitness) { double[] contractedPoint = calculateContractedPoint(cog, orderedValues); double contractedPointFitness = calculateResidual(contractedPoint); // if this is better than the worst point if (contractedPointFitness < fitnesses[orderedValues[orderedValues.length - 1]]) { replaceWorstPointWith(contractedPoint, contractedPointFitness, orderedValues); return; } reducePoints(orderedValues); return; } // if the point is amongst the other points, then simple replace the // worst point with it replaceWorstPointWith(reflectedPoint, reflectedPointFitness, orderedValues); return; } private void reducePoints(int[] orderedValues) { reduceCount++; for (int i = 1; i < orderedValues.length; i++) { for (int j = 0; j < points[i].length; j++) { points[orderedValues[i]][j] = points[orderedValues[0]][j] + sigma * (points[orderedValues[i]][j] - points[orderedValues[0]][j]); } fitnesses[orderedValues[i]] = calculateResidual(points[orderedValues[i]]); } } private double[] calculateContractedPoint(double[] cog, int[] orderedValues) { contractCount++; double[] contractedPoint = new double[cog.length]; int worstPoint = orderedValues[orderedValues.length - 1]; for (int i = 0; i < contractedPoint.length; i++) { contractedPoint[i] = cog[i] - (phi * (cog[i] - points[worstPoint][i])); } return contractedPoint; } private void replaceWorstPointWith(double[] replacingPoint, double replacingPointFitness, int[] orderedValues) { int worstPoint = orderedValues[orderedValues.length - 1]; for (int i = 0; i < replacingPoint.length; i++) { points[worstPoint][i] = replacingPoint[i]; } fitnesses[worstPoint] = replacingPointFitness; } private double[] calculateExtendedPoint(double[] cog, int[] orderedValues) { extendedCount++; double[] extendedPoint = new double[cog.length]; for (int i = 0; i < extendedPoint.length; i++) { extendedPoint[i] = cog[i] + gamma * (cog[i] - points[orderedValues[orderedValues.length - 1]][i]); } return extendedPoint; } private double[] calculateReflectedPoint(double[] cog, int[] orderedValues) { reflectedCount++; double[] reflectedPoint = new double[cog.length]; int worstPoint = orderedValues[orderedValues.length - 1]; for (int i = 0; i < reflectedPoint.length; i++) { reflectedPoint[i] = cog[i] + (alpha * (cog[i] - points[worstPoint][i])); } return reflectedPoint; } private double[] calculateCentreOfGravityIgnoringWorstValue(int[] orderedValues) { double[] cog = new double[points[0].length]; for (int i = 0; i < cog.length; i++) { cog[i] = 0; } for (int i = 0; i < orderedValues.length - 1; i++) { for (int j = 0; j < cog.length; j++) { cog[j] += points[orderedValues[i]][j]; } } for (int i = 0; i < cog.length; i++) { cog[i] = cog[i] / (orderedValues.length - 1); } return cog; } private int[] sortFitnesses() { int[] sortedPositions = new int[fitnesses.length]; bubbleSort(sortedPositions); return sortedPositions; } private void bubbleSort(int[] sortedPositions) { boolean done = false; for (int i = 0; i < sortedPositions.length; i++) { sortedPositions[i] = i; } while (!done) { done = true; for (int i = 0; i < sortedPositions.length - 1; i++) { if (fitnesses[sortedPositions[i]] > fitnesses[sortedPositions[i + 1]]) { int temp = sortedPositions[i]; sortedPositions[i] = sortedPositions[i + 1]; sortedPositions[i + 1] = temp; done = false; } } } } public void generateStartingFitnesses() { // TODO replace fitnesses with tree map <Double,Integer>??? // fitnesses.put(eval(points[i]), i) fitnesses = new double[points.length]; for (int i = 0; i < fitnesses.length; i++) { fitnesses[i] = calculateResidual(points[i]); } // initialise stored best values int[] sortedValues = sortFitnesses(); oldBest = fitnesses[sortedValues[0]]; oldSpread = Math.abs(fitnesses[sortedValues[0]] - fitnesses[sortedValues[sortedValues.length - 1]]); } private void initialise(double[] parameters, double startingSpread) { for (int i = 0; i < points.length; i++) { double[] randomUnitVector = createRandomUnitVector(parameters.length); for (int j = 0; j < parameters.length; j++) { points[i][j] = parameters[j] + randomUnitVector[j] * startingSpread; } } } private double[] createRandomUnitVector(int size) { double[] unitVector = createRandomVector(size); double length = getVectorlength(unitVector); for (int i = 0; i < unitVector.length; i++) { unitVector[i] = unitVector[i] / length; } return unitVector; } private double getVectorlength(double[] vector) { double sum = 0.0; for (int i = 0; i < vector.length; i++) { sum += vector.length * vector.length; } return Math.sqrt(sum); } private double[] createRandomVector(int size) { double[] result = new double[size]; for (int i = 0; i < size; i++) { result[i] = random.nextDouble() - 0.5; } return result; } } }