//Copyright (C) 2011 Zeno Gantner
//
//This file is part of MyMediaLite.
//
//MyMediaLite is free software: you can redistribute it and/or modify
//it under the terms of the GNU General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//MyMediaLite is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
//
package org.mymedialite.eval;
import java.util.HashMap;
import java.util.Map;
import org.mymedialite.IIterativeModel;
import org.mymedialite.data.IRatings;
import org.mymedialite.data.ISplit;
import org.mymedialite.data.RatingCrossValidationSplit;
import org.mymedialite.ratingprediction.RatingPredictor;
/**
 * Cross-validation for rating prediction.
 * <p>
 * Static utility class; it cannot be instantiated.
 * @version 2.03
 */
public final class RatingsCrossValidation {

  // Prevent instantiation.
  private RatingsCrossValidation() { }

  /**
   * Evaluate on the folds of a dataset split.
   * <p>
   * A {@link RatingCrossValidationSplit} with {@code num_folds} folds is built from the
   * recommender's own ratings, and the work is delegated to
   * {@link #doCrossValidation(RatingPredictor, ISplit, Boolean, Boolean)}.
   * @param recommender a rating predictor
   * @param num_folds the number of folds; {@code null} means 5
   * @param compute_fit if set to true measure fit on the training data as well; {@code null} means false
   * @param show_results if set to true print the results of each fold to STDOUT; {@code null} means false
   * @return a dictionary containing the average results over the different folds of the split
   * @throws Exception if splitting, training or evaluation fails
   */
  public static RatingPredictionEvaluationResults doCrossValidation(RatingPredictor recommender,
      Integer num_folds,
      Boolean compute_fit,
      Boolean show_results) throws Exception {
    if (num_folds == null) num_folds = 5;
    if (compute_fit == null) compute_fit = false;
    if (show_results == null) show_results = false;
    RatingCrossValidationSplit split = new RatingCrossValidationSplit(recommender.getRatings(), num_folds);
    return doCrossValidation(recommender, split, compute_fit, show_results);
  }

  /**
   * Evaluate on the folds of a dataset split.
   * <p>
   * For each fold, a clone of {@code recommender} is trained on the fold's training data and
   * evaluated on its test data. The per-fold values of every reported measure (and of
   * {@code "fit"} when {@code compute_fit} is true) are summed and then divided by the number
   * of folds. The recommender passed in is not retrained.
   * @param recommender a rating predictor
   * @param split a rating dataset split
   * @param compute_fit if set to true measure fit on the training data as well; {@code null} means false
   * @param show_results set to true to print the results of each fold to STDOUT; {@code null} means false
   * @return a dictionary containing the average results over the different folds of the split
   *         (empty if the split has no folds)
   * @throws Exception if training or evaluation of any fold fails; the error message is also
   *         written to STDERR before the exception is rethrown
   */
  public static RatingPredictionEvaluationResults doCrossValidation(
      RatingPredictor recommender,
      ISplit<IRatings> split,
      Boolean compute_fit,
      Boolean show_results) throws Exception {
    if (compute_fit == null) compute_fit = false;
    if (show_results == null) show_results = false;

    int num_folds = split.numberOfFolds();
    RatingPredictionEvaluationResults avg_results = new RatingPredictionEvaluationResults();

    for (int i = 0; i < num_folds; i++) {
      try {
        RatingPredictor split_recommender = recommender.clone(); // to avoid changes in recommender
        split_recommender.setRatings(split.train().get(i));
        split_recommender.train();
        Map<String, Double> fold_results = Ratings.evaluate(split_recommender, split.test().get(i));
        if (compute_fit)
          fold_results.put("fit", Double.valueOf(Ratings.computeFit(split_recommender)));

        // Accumulate per-fold sums; they are turned into averages after all folds are done.
        for (Map.Entry<String, Double> entry : fold_results.entrySet()) {
          String key = entry.getKey();
          if (avg_results.containsKey(key))
            avg_results.put(key, avg_results.get(key) + entry.getValue());
          else
            avg_results.put(key, entry.getValue());
        }

        if (show_results)
          System.out.println("fold " + i + " " + fold_results);
      } catch (Exception e) {
        System.err.println("===> ERROR: " + e.getMessage());
        throw e;
      }
    }

    // Divide every accumulated value, including "fit", by the number of folds. Iterating over
    // the accumulated keys (instead of a fixed list of measures) means no key is left as a sum,
    // and nothing is divided when the split has no folds.
    for (Map.Entry<String, Double> entry : avg_results.entrySet())
      entry.setValue(entry.getValue() / num_folds);

    return avg_results;
  }

  /**
   * Evaluate an iterative recommender on the folds of a dataset split, display results on STDOUT.
   * <p>
   * A {@link RatingCrossValidationSplit} with {@code num_folds} folds is built from the
   * recommender's own ratings, and the work is delegated to
   * {@link #doIterativeCrossValidation(RatingPredictor, ISplit, int, Integer)}.
   * @param recommender a rating predictor; must implement {@link IIterativeModel}
   * @param num_folds the number of folds
   * @param max_iter the maximum number of iterations
   * @param find_iter the report interval; {@code null} means 1
   * @throws Exception if splitting, training or evaluation fails
   */
  public static void doIterativeCrossValidation(RatingPredictor recommender, int num_folds, int max_iter, Integer find_iter) throws Exception {
    RatingCrossValidationSplit split = new RatingCrossValidationSplit(recommender.getRatings(), num_folds);
    doIterativeCrossValidation(recommender, split, max_iter, find_iter);
  }

  /**
   * Evaluate an iterative recommender on the folds of a dataset split, display results on STDOUT.
   * <p>
   * One clone of {@code recommender} per fold is trained and evaluated once. All clones are then
   * advanced one iteration at a time up to {@code max_iter}. Results are printed whenever the
   * iteration number is a multiple of {@code find_iter}. The iteration count is taken from the
   * first fold's model after initial training.
   * @param recommender a rating predictor; must implement {@link IIterativeModel}
   * @param split a rating dataset split
   * @param max_iter the maximum number of iterations
   * @param find_iter the report interval; {@code null} means 1, otherwise must be positive
   * @throws IllegalArgumentException if {@code recommender} is not an {@link IIterativeModel}
   *         or {@code find_iter} is not positive
   * @throws Exception if training or evaluation of any fold fails; the error message is also
   *         written to STDERR before the exception is rethrown
   */
  public static void doIterativeCrossValidation(RatingPredictor recommender, ISplit<IRatings> split, int max_iter, Integer find_iter) throws Exception {
    if (find_iter == null) find_iter = 1;
    if (!(recommender instanceof IIterativeModel))
      throw new IllegalArgumentException("recommender must be of type IIterativeModel");
    // Guard against "it % 0" below.
    if (find_iter <= 0)
      throw new IllegalArgumentException("find_iter must be positive, but was " + find_iter);

    int report_interval = find_iter;
    int num_folds = split.numberOfFolds();
    RatingPredictor[] split_recommenders = new RatingPredictor[num_folds];
    IIterativeModel[] iterative_recommenders = new IIterativeModel[num_folds];

    // Initial training and evaluation
    for (int i = 0; i < num_folds; i++) {
      try {
        split_recommenders[i] = recommender.clone(); // to avoid changes in recommender
        split_recommenders[i].setRatings(split.train().get(i));
        split_recommenders[i].train();
        iterative_recommenders[i] = (IIterativeModel) split_recommenders[i];
        Map<String, Double> fold_results = Ratings.evaluate(split_recommenders[i], split.test().get(i));
        System.out.println("fold " + i + " " + fold_results + " iteration " + iterative_recommenders[i].getNumIter());
      } catch (Exception e) {
        System.err.println("===> ERROR: " + e.getMessage());
        throw e;
      }
    }

    // With no folds there is no model to read the starting iteration from.
    if (num_folds == 0)
      return;

    // Iterative training and evaluation
    for (int it = iterative_recommenders[0].getNumIter() + 1; it <= max_iter; it++) {
      for (int i = 0; i < num_folds; i++) {
        try {
          iterative_recommenders[i].iterate();
          if (it % report_interval == 0) {
            Map<String, Double> fold_results = Ratings.evaluate(split_recommenders[i], split.test().get(i));
            System.out.println("fold " + i + " " + fold_results + " iteration " + it);
          }
        } catch (Exception e) {
          System.err.println("===> ERROR: " + e.getMessage());
          throw e;
        }
      }
    }
  }
}