//Copyright (C) 2010, 2011 Zeno Gantner, Chris Newell
//
//This file is part of MyMediaLite.
//
//MyMediaLite is free software: you can redistribute it and/or modify
//it under the terms of the GNU General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//MyMediaLite is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.hyperparameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.mymedialite.data.IRatings;
import org.mymedialite.data.ISplit;
import org.mymedialite.data.RatingsSimpleSplit;
import org.mymedialite.eval.Ratings;
import org.mymedialite.eval.RatingsCrossValidation;
import org.mymedialite.ratingprediction.BiasedMatrixFactorization;
import org.mymedialite.ratingprediction.MatrixFactorization;
import org.mymedialite.ratingprediction.RatingPredictor;
import org.mymedialite.ratingprediction.UserItemBaseline;
import org.mymedialite.util.Recommender;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.jet.math.Functions;
import cern.jet.math.PlusMult;
/**
* Nelder-Mead algorithm for finding suitable hyperparameters.
* @version 2.03
*/
public class NelderMead {

  // Prevent instantiation: all members are static.
  private NelderMead() {}

  // TODO avoid negative values e.g. for regularization ...
  // TODO make configurable

  /** reflection coefficient */
  static double alpha = 1.0;
  /** expansion coefficient */
  static double gamma = 2.0;
  /** contraction coefficient */
  static double rho = 0.5;
  /** shrink (reduction) coefficient */
  static double sigma = 0.5;
  /** number of simplex iterations to perform */
  static double num_it = 50;
  /** fraction of the ratings held out for validation */
  static double split_ratio = 0.2;

  /**
   * Build a configuration string of the form " name1=value1 name2=value2 ...".
   *
   * @param hp_names the hyperparameter names
   * @param hp_values the hyperparameter values, parallel to hp_names
   * @return the configuration string (note: starts with a space)
   */
  static String createConfigString(List<String> hp_names, double[] hp_values) {
    StringBuilder hp_string = new StringBuilder();
    for (int i = 0; i < hp_names.size(); i++)
      hp_string.append(" ").append(hp_names.get(i)).append("=").append(hp_values[i]);
    return hp_string.toString();
  }

  /**
   * Configure the recommender with the given hyperparameter string and evaluate
   * it by cross-validation on the given split.
   *
   * @param recommender the rating predictor to configure and evaluate
   * @param split the dataset split to use
   * @param hp_string the hyperparameter configuration string
   * @param evaluation_measure the name of the evaluation measure to report
   * @return the cross-validation result for the given measure
   * @throws Exception if configuration or evaluation fails
   */
  static double run(RatingPredictor recommender, ISplit<IRatings> split, String hp_string, String evaluation_measure) throws Exception {
    Recommender.configure(recommender, hp_string);
    // NOTE(review): throws NullPointerException on unboxing if evaluation_measure
    // is not contained in the cross-validation results — verify measure names upstream.
    double result = RatingsCrossValidation.doCrossValidation(recommender, split, null, null).get(evaluation_measure);
    System.err.println("Nelder-Mead: " + hp_string + ": " + result);
    return result;
  }

  /**
   * Compute the centroid of the hyperparameter vectors selected by the result keys.
   *
   * @param results maps configuration strings to evaluation results; its key set
   *                determines which vectors are summed
   * @param hp_values maps configuration strings to hyperparameter vectors
   * @return the centroid (component-wise mean) of the vectors
   * @throws IllegalArgumentException if no vectors are given
   */
  static DoubleMatrix1D computeCenter(Map<String, Double> results, Map<String, DoubleMatrix1D> hp_values) {
    if (hp_values.size() == 0)
      throw new IllegalArgumentException("need at least one vector to build center");

    DoubleMatrix1D center = new DenseDoubleMatrix1D(hp_values.values().iterator().next().size());
    for (String key : results.keySet())
      center.assign(hp_values.get(key), Functions.plus);

    // The centroid of n vectors is their sum divided by n.
    // (An earlier version divided by n - 1, which scaled the center incorrectly.)
    center.assign(Functions.div(hp_values.size()));
    return center;
  }

  /**
   * Find the best hyperparameters (according to an error measure) using Nelder-Mead search.
   *
   * @param error_measure an error measure (lower is better)
   * @param recommender a rating predictor (will be set to the best hyperparameter combination)
   * @return the estimated error of the best hyperparameter combination
   * @throws Exception if evaluation fails
   * @throws IllegalArgumentException if the recommender type is not supported
   */
  public static double findMinimum(String error_measure, RatingPredictor recommender) throws Exception {
    ISplit<IRatings> split = new RatingsSimpleSplit(recommender.getRatings(), split_ratio);

    List<String> hp_names;
    List<DoubleMatrix1D> initial_hp_values;

    // TODO manage this via reflection?
    // Order matters: check the more specific types before MatrixFactorization.
    if (recommender instanceof UserItemBaseline) {
      hp_names = Arrays.asList("reg_u", "reg_i");
      initial_hp_values = new ArrayList<DoubleMatrix1D>();
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 25, 10 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 10, 25 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 2, 5 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 5, 2 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 1, 4 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 4, 1 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 3, 3 }));
    } else if (recommender instanceof BiasedMatrixFactorization) {
      hp_names = Arrays.asList("regularization", "bias_reg");
      initial_hp_values = new ArrayList<DoubleMatrix1D>();
      // TODO reg_u and reg_i (in a second step?)
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.1, 0 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.01, 0 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.0001, 0 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.00001, 0 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.1, 0.0001 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.01, 0.0001 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.0001, 0.0001 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.00001, 0.0001 }));
    } else if (recommender instanceof MatrixFactorization) {
      // TODO normal interval search could be more efficient
      hp_names = Arrays.asList("regularization");
      initial_hp_values = new ArrayList<DoubleMatrix1D>();
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.1 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.01 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.0001 }));
      initial_hp_values.add(new DenseDoubleMatrix1D(new double[] { 0.00001 }));
    } else {
      // TODO kNN-based methods
      throw new IllegalArgumentException("Not prepared for type " + recommender.getClass().getName());
    }

    return findMinimum(error_measure, hp_names, initial_hp_values, recommender, split);
  }

  /**
   * Find the parameters resulting in the minimal results for a given evaluation measure.
   * The recommender will be set to the best parameter values after calling this method.
   *
   * @param evaluation_measure the name of the evaluation measure
   * @param hp_names the names of the hyperparameters to optimize
   * @param initial_hp_values the values of the hyperparameters to try out first
   * @param recommender the recommender
   * @param split the dataset split to use
   * @return the best (lowest) average value for the hyperparameter
   * @throws Exception if evaluation fails
   */
  public static double findMinimum(
      String evaluation_measure,
      List<String> hp_names,
      List<DoubleMatrix1D> initial_hp_values,
      RatingPredictor recommender, // TODO make more general?
      ISplit<IRatings> split) throws Exception {

    // Both maps are keyed by the configuration string of a simplex point.
    Map<String, Double> results = new HashMap<String, Double>();
    Map<String, DoubleMatrix1D> hp_vectors = new HashMap<String, DoubleMatrix1D>();

    // Initialize the simplex by evaluating every initial point.
    for (DoubleMatrix1D hp_values : initial_hp_values) {
      String hp_string = createConfigString(hp_names, hp_values.toArray());
      results.put(hp_string, run(recommender, split, hp_string, evaluation_measure));
      hp_vectors.put(hp_string, hp_values);
    }

    List<String> keys;
    for (int i = 0; i < num_it; i++) {
      if (results.size() != hp_vectors.size())
        throw new IllegalStateException("simplex bookkeeping out of sync: " + results.size() + " vs. " + hp_vectors.size());

      keys = new ArrayList<String>(results.keySet());
      Collections.sort(keys, new ResultsComparator(results));
      String min_key = keys.get(0);
      String max_key = keys.get(keys.size() - 1);

      System.err.println("Nelder-Mead: iteration " + i + " (" + results.get(min_key) + ")");

      // Temporarily take the worst point out of the simplex.
      DoubleMatrix1D worst_vector = hp_vectors.get(max_key);
      Double worst_result = results.get(max_key);
      hp_vectors.remove(max_key);
      results.remove(max_key);

      // Center of the simplex without the worst point.
      DoubleMatrix1D center = computeCenter(results, hp_vectors);
      // Search direction: center - worst_vector.
      // Colt's assign() mutates its receiver, so always operate on copies —
      // the previous code aliased 'center' and computed wrong simplex moves.
      DoubleMatrix1D direction = center.copy().assign(worst_vector, PlusMult.minusMult(1.0));

      // Reflection: center + alpha * (center - worst_vector)
      DoubleMatrix1D reflection = center.copy().assign(direction, PlusMult.plusMult(alpha));
      String ref_string = createConfigString(hp_names, reflection.toArray());
      double ref_result = run(recommender, split, ref_string, evaluation_measure);
      // Accept the reflected point if it is neither the new best nor worse than
      // the second-worst (the worst was removed, so max() is the second-worst).
      if (results.get(min_key) <= ref_result && ref_result < Collections.max(results.values())) {
        results.put(ref_string, ref_result);
        hp_vectors.put(ref_string, reflection);
        continue;
      }

      // Expansion: center + gamma * (center - worst_vector)
      if (ref_result < results.get(min_key)) {
        DoubleMatrix1D expansion = center.copy().assign(direction, PlusMult.plusMult(gamma));
        String exp_string = createConfigString(hp_names, expansion.toArray());
        double exp_result = run(recommender, split, exp_string, evaluation_measure);
        if (exp_result < ref_result) {
          results.put(exp_string, exp_result);
          hp_vectors.put(exp_string, expansion);
        } else {
          results.put(ref_string, ref_result);
          hp_vectors.put(ref_string, reflection);
        }
        continue;
      }

      // Contraction: center + rho * (center - worst_vector)
      DoubleMatrix1D contraction = center.copy().assign(direction, PlusMult.plusMult(rho));
      String con_string = createConfigString(hp_names, contraction.toArray());
      double con_result = run(recommender, split, con_string, evaluation_measure);
      if (con_result < worst_result) {
        results.put(con_string, con_result);
        hp_vectors.put(con_string, contraction);
        continue;
      }

      // Reduction: shrink the remaining points towards the best point:
      // x := best_vector + sigma * (x - best_vector)
      DoubleMatrix1D best_vector = hp_vectors.get(min_key);
      Double best_result = results.get(min_key);
      hp_vectors.remove(min_key);
      results.remove(min_key);
      for (String key : new ArrayList<String>(results.keySet())) {
        DoubleMatrix1D diff = hp_vectors.get(key).copy().assign(best_vector, PlusMult.minusMult(1.0));
        DoubleMatrix1D reduction = best_vector.copy().assign(diff, PlusMult.plusMult(sigma));
        String red_string = createConfigString(hp_names, reduction.toArray());
        double red_result = run(recommender, split, red_string, evaluation_measure);

        // Replace by reduced vector.
        results.remove(key);
        hp_vectors.remove(key);
        results.put(red_string, red_result);
        hp_vectors.put(red_string, reduction);
      }
      results.put(min_key, best_result);
      hp_vectors.put(min_key, best_vector);
      // NOTE(review): the worst point is restored unshrunk here; textbook
      // Nelder-Mead would shrink it towards the best point as well — confirm
      // against the upstream MyMediaLite implementation before changing.
      results.put(max_key, worst_result);
      hp_vectors.put(max_key, worst_vector);
    }

    keys = new ArrayList<String>(results.keySet());
    Collections.sort(keys, new ResultsComparator(results));

    // Set the recommender to the best hyperparameter values found.
    Recommender.configure(recommender, keys.get(0));
    return results.get(keys.get(0));
  }

  /** Orders configuration strings by their evaluation result, ascending (best first). */
  static private class ResultsComparator implements Comparator<String> {
    private final Map<String, Double> results;

    public ResultsComparator(Map<String, Double> results) {
      this.results = results;
    }

    @Override
    public int compare(String k1, String k2) {
      return results.get(k1).compareTo(results.get(k2));
    }
  }
}