import java.io.IOException; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Calendar; import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.List; import org.mymedialite.IItemAttributeAwareRecommender; import org.mymedialite.IItemRelationAwareRecommender; import org.mymedialite.IIterativeModel; import org.mymedialite.IUserAttributeAwareRecommender; import org.mymedialite.IUserRelationAwareRecommender; import org.mymedialite.data.EntityMapping; import org.mymedialite.data.Extensions; import org.mymedialite.data.IEntityMapping; import org.mymedialite.data.IRatings; import org.mymedialite.data.ITimedRatings; import org.mymedialite.data.IdentityMapping; import org.mymedialite.data.RatingCrossValidationSplit; import org.mymedialite.data.RatingType; import org.mymedialite.data.RatingsChronologicalSplit; import org.mymedialite.data.RatingsSimpleSplit; import org.mymedialite.datatype.SparseBooleanMatrix; import org.mymedialite.eval.ItemsCrossValidation; import org.mymedialite.eval.RatingPredictionEvaluationResults; import org.mymedialite.eval.Ratings; import org.mymedialite.eval.RatingsCrossValidation; import org.mymedialite.eval.RatingsOnline; import org.mymedialite.hyperparameter.NelderMead; import org.mymedialite.io.AttributeData; import org.mymedialite.io.ItemDataFileFormat; import org.mymedialite.io.Model; import org.mymedialite.io.MovieLensRatingData; import org.mymedialite.io.RatingData; import org.mymedialite.io.RatingFileFormat; import org.mymedialite.io.RelationData; import org.mymedialite.io.StaticRatingData; import org.mymedialite.io.TimedRatingData; import org.mymedialite.ratingprediction.IIncrementalRatingPredictor; import org.mymedialite.ratingprediction.RatingPredictor; import org.mymedialite.ratingprediction.TimeAwareRatingPredictor; import org.mymedialite.util.Handlers; import org.mymedialite.util.Memory; import org.mymedialite.util.Recommender; import org.mymedialite.util.Utils; /** * Rating prediction program, see usage() method for more information. * @version 2.03 */ public class RatingPrediction { static final String VERSION = "2.03"; static final SimpleDateFormat dateFormat = new SimpleDateFormat(); // Data sets static IRatings training_data; static IRatings test_data; // Recommenders static RatingPredictor recommender = null; // ID mapping objects static IEntityMapping user_mapping = new EntityMapping(); static IEntityMapping item_mapping = new EntityMapping(); static IEntityMapping attribute_mapping = new EntityMapping(); // User and item attributes static SparseBooleanMatrix user_attributes; static SparseBooleanMatrix item_attributes; // Time statistics static ArrayList<Double> training_time_stats = new ArrayList<Double>(); static ArrayList<Double> fit_time_stats = new ArrayList<Double>(); static ArrayList<Double> eval_time_stats = new ArrayList<Double>(); static ArrayList<Double> rmse_eval_stats = new ArrayList<Double>(); // Command line parameters static String training_file; static String test_file; static String save_model_file = null; static String load_model_file = null; static String user_attributes_file; static String item_attributes_file; static String user_relations_file; static String item_relations_file; static String prediction_file; static boolean compute_fit; static RatingFileFormat file_format = RatingFileFormat.DEFAULT; static RatingType rating_type = RatingType.DOUBLE; static int cross_validation; static boolean show_fold_results = false; static double test_ratio; static String chronological_split = null; static double chronological_split_ratio = -1; static Date chronological_split_time = new Date(0); static int find_iter; static boolean online_eval = false; static void showVersion() { System.out.println("MyMediaLite Rating Prediction " + VERSION); System.out.println("Copyright (C) 2010 Zeno Gantner, Steffen Rendle"); System.out.println("Copyright (C) 2011 Zeno Gantner, Chris Newell"); System.out.println("This is free software; see the source for copying conditions. There is NO"); System.out.println("warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE."); System.exit(0); } private static class ErrorHandler implements Recommender.ErrorHandler { public void reportError(String message) { usage(message); } } static void usage(String message) { System.out.println(message); System.out.println(); usage(-1); } static void usage(int exit_code) { System.out.println("MyMediaLite rating prediction " + VERSION); System.out.println(" usage: rating_prediction --training-file=FILE --recommender=METHOD [OPTIONS]"); System.out.println(" method ARGUMENTS have the form name=value"); System.out.println(); System.out.println(" general OPTIONS:\n" + " --recommender=METHOD set recommender method (default: BiasedMatrixFactorization)\n" + " --recommender-options=OPTIONS use OPTIONS as recommender options\n" + " --help display this usage information and exit\n" + " --version display version information and exit\n" + " --random-seed=N initialize the random number generator with N\n" + " --rating-type=float|byte|double store ratings internally as floats or bytes or doubles (default)\n" + "\n" + " files:\n" + " --training-file=FILE read training data from FILE\n" + " --test-file=FILE read test data from FILE\n" + " --file-format=movielens_1m|kddcup_2011|ignore_first_line|default\n" + " --data-dir=DIR load all files from DIR\n" + " --user-attributes=FILE file containing user attribute information, 1 tuple per line\n" + " --item-attributes=FILE file containing item attribute information, 1 tuple per line\n" + " --user-relations=FILE file containing user relation information, 1 tuple per line\n" + " --item-relations=FILE file containing item relation information, 1 tuple per line\n" + " --save-model=FILE save computed model to FILE\n" + " --load-model=FILE load model from FILE\n" + "\n" + " prediction options:\n" + " --prediction-file=FILE write the rating predictions to FILE\n" + " --prediction-line=FORMAT format of the prediction line; {0}, {1}, {2} refer to user ID, item ID,\n" + " and predicted rating, respectively; default instanceof {0}\\t{1}\\t{2}\n" + "\n" + " evaluation options:\n" + " --cross-validation=K perform k-fold cross-validation on the training data\n" + " --show-fold-results show results for individual folds : cross-validation\n" + " --test-ratio=NUM use a ratio of NUM of the training data for evaluation (simple split)\n" + " --chronological-split=NUM|DATETIME use the last ratio of NUM of the training data ratings for evaluation,\n" + " or use the ratings from DATETIME on for evaluation (requires time information\n" + " in the training data)\n" + " --online-evaluation perform online evaluation (use every tested rating for incremental training)\n" + " --search-hp search for good hyperparameter values (experimental)\n" + " --compute-fit display fit on training data\n" + "\n" + " options for finding the right number of iterations (iterative methods)\n" + " --find-iter=N give out statistics every N iterations\n" + " --max-iter=N perform at most N iterations\n" + " --epsilon=NUM abort iterations if RMSE instanceof more than best result plus NUM\n" + " --rmse-cutoff=NUM abort if RMSE instanceof above NUM\n" + " --mae-cutoff=NUM abort if MAE instanceof above NUM\n"); System.exit(exit_code); } public static void main(String[] args) throws Exception { // Handlers for uncaught exceptions and interrupts Thread.setDefaultUncaughtExceptionHandler(new Handlers()); Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { displayStats(); } }); // Recommender arguments String method = null; String recommender_options = ""; // Help/version boolean show_help = false; boolean show_version = false; // Arguments for iteration search int max_iter = 100; double epsilon = 0; double rmse_cutoff = Double.MAX_VALUE; double mae_cutoff = Double.MAX_VALUE; // Data arguments String data_dir = ""; // Other arguments boolean search_hp = false; int random_seed = -1; String prediction_line = "{0}\t{1}\t{2}"; for(String arg : args) { int div = arg.indexOf("=") + 1; String name; String value; if(div > 0) { name = arg.substring(0, div); value = arg.substring(div); } else { name = arg; value = null; } // String-valued options if(name.equals("--training-file=")) training_file = value; else if(name.equals("--test-file=")) test_file = value; else if(name.equals("--recommender=")) method = value; else if(name.equals("--recommender-options=")) recommender_options += " " + value; else if(name.equals("--data-dir=")) data_dir = value; else if(name.equals("--user-attributes=")) user_attributes_file = value; else if(name.equals("--item-attributes=")) item_attributes_file = value; else if(name.equals("--user-relations=")) user_relations_file = value; else if(name.equals("--item-relations=")) item_relations_file = value; else if(name.equals("--save-model=")) save_model_file = value; else if(name.equals("--load-model=")) load_model_file = value; else if(name.equals("--prediction-file=")) prediction_file = value; else if(name.equals("--prediction-line=")) prediction_line = value; else if(name.equals("--chronological-split=")) chronological_split = value; // Integer-valued options else if(name.equals("--find-iter=")) find_iter = Integer.parseInt(value); else if(name.equals("--max-iter=")) max_iter = Integer.parseInt(value); else if(name.equals("--random-seed=")) random_seed = Integer.parseInt(value); else if(name.equals("--cross-validation=")) cross_validation = Integer.parseInt(value); // Double-valued options else if(name.equals("--epsilon=")) epsilon = Double.parseDouble(value); else if(name.equals("--rmse-cutoff=")) rmse_cutoff = Double.parseDouble(value); else if(name.equals("--mae-cutoff=")) mae_cutoff = Double.parseDouble(value); else if(name.equals("--test-ratio=")) test_ratio = Double.parseDouble(value); // Enum options else if(name.equals("--rating-type=")) rating_type = RatingType.valueOf(value); else if(name.equals("--file-format=")) file_format = RatingFileFormat.valueOf(value); // Boolean options else if(name.equals("--compute-fit")) compute_fit = true; else if(name.equals("--online-evaluation")) online_eval = true; else if(name.equals("--show-fold-results")) show_fold_results = true; else if(name.equals("--search-hp")) search_hp = true; else if(name.equals("--help")) show_help = true; else if(name.equals("--version")) show_version = true; } // ... some more command line parameter actions ... boolean no_eval = true; if (test_ratio > 0 || test_file != null || chronological_split != null) no_eval = false; if (show_version) showVersion(); if (show_help) usage(0); if (random_seed != -1) org.mymedialite.util.Random.initInstance(random_seed); // Set up recommender if (load_model_file != null) recommender = (RatingPredictor) Model.load(load_model_file); else if (method != null) recommender = Recommender.createRatingPredictor(method); else recommender = Recommender.createRatingPredictor("BiasedMatrixFactorization"); // In case something went wrong ... if (recommender == null && method != null) usage("Unknown rating prediction method: " + method); if (recommender == null && load_model_file != null) usage("Could not load model from file "+ load_model_file); checkParameters(); try { recommender = Recommender.configure(recommender, recommender_options, new ErrorHandler()); } catch (IllegalAccessException e) { System.err.println("Unable to instantiate recommender: " + recommender.toString()); System.exit(0); } // ID mapping objects if (file_format == RatingFileFormat.KDDCUP_2011) { user_mapping = new IdentityMapping(); item_mapping = new IdentityMapping(); } // Load all the data loadData(data_dir, user_attributes_file, item_attributes_file, user_relations_file, item_relations_file, !online_eval); System.out.println("Ratings range: " + recommender.getMinRating() + ", " + recommender.getMaxRating()); if (test_ratio > 0) { RatingsSimpleSplit split = new RatingsSimpleSplit(training_data, test_ratio); // TODO check training_data = split.train().get(0); recommender.setRatings(training_data); // TODO check test_data = split.test().get(0); System.out.println("Test ratio: " + test_ratio); } if (chronological_split != null) { RatingsChronologicalSplit split = chronological_split_ratio != -1 ? new RatingsChronologicalSplit((ITimedRatings) training_data, chronological_split_ratio) : new RatingsChronologicalSplit((ITimedRatings) training_data, chronological_split_time); training_data = split.train().get(0); recommender.setRatings(training_data); test_data = split.test().get(0); if (test_ratio != -1) System.out.println("Test ratio (chronological): " + chronological_split_ratio); else System.out.println("Split time:" + chronological_split_time); } System.out.print(Extensions.statistics(training_data, test_data, user_attributes, item_attributes, false)); if (find_iter != 0) { if (!(recommender instanceof IIterativeModel) ) usage("Only iterative recommenders (interface IIterativeModel) support --find-iter=N."); System.out.println("Recommender: " + recommender.toString()); if (cross_validation > 1) { RatingsCrossValidation.doIterativeCrossValidation(recommender, cross_validation, max_iter, find_iter); } else { IIterativeModel iterative_recommender = (IIterativeModel) recommender; if (load_model_file == null) recommender.train(); if (compute_fit) System.out.println("Fit " + Ratings.evaluate(recommender, training_data) + " iteration " + iterative_recommender.getNumIter()); System.out.println(Ratings.evaluate(recommender, test_data) + " iteration " + iterative_recommender.getNumIter()); for (int it = iterative_recommender.getNumIter() + 1; it <= max_iter; it++) { long start = Calendar.getInstance().getTimeInMillis(); iterative_recommender.iterate(); training_time_stats.add((double)(Calendar.getInstance().getTimeInMillis() - start) / 1000); if (it % find_iter == 0) { if (compute_fit) { start = Calendar.getInstance().getTimeInMillis(); System.out.println("Fit " + Ratings.evaluate(recommender, training_data) + " iteration " + it); fit_time_stats.add((double)(Calendar.getInstance().getTimeInMillis() - start) / 1000); } HashMap<String, Double> results = null; start = Calendar.getInstance().getTimeInMillis(); results = Ratings.evaluate(recommender, test_data); eval_time_stats.add((double)(Calendar.getInstance().getTimeInMillis() - start) / 1000); rmse_eval_stats.add(results.get("RMSE")); System.out.println(results + " iteration " + it); Model.save(recommender, save_model_file, it); if (prediction_file != null) org.mymedialite.ratingprediction.Extensions.writePredictions(recommender, test_data, prediction_file + "-it-" + it, user_mapping, item_mapping, prediction_line); if (epsilon > 0.0 && results.get("RMSE") - Collections.min(rmse_eval_stats) > epsilon) { System.out.println(results.get("RMSE") + " >> " + Collections.min(rmse_eval_stats)); System.out.println("Reached convergence on training/validation data after " + it + " iterations."); break; } if (results.get("RMSE") > rmse_cutoff || results.get("MAE") > mae_cutoff) { System.out.println("Reached cutoff after " + it + " iterations."); break; } } } // for } } else { long start = Calendar.getInstance().getTimeInMillis(); System.out.println("Recommender: " + recommender); if (load_model_file == null) { if (cross_validation > 1) { RatingPredictionEvaluationResults results = RatingsCrossValidation.doCrossValidation(recommender, cross_validation, compute_fit, show_fold_results); System.out.println(results); no_eval = true; } else { if (search_hp) { double result = NelderMead.findMinimum("RMSE", recommender); System.out.println("Estimated quality (on split): " + result); } recommender.train(); System.out.println("Training time: " + (double)(Calendar.getInstance().getTimeInMillis() - start) / 1000 + " seconds"); } } if (!no_eval) { start = Calendar.getInstance().getTimeInMillis(); if (online_eval) System.out.println(RatingsOnline.evaluateOnline(recommender, test_data)); else System.out.println(Ratings.evaluate(recommender, test_data)); System.out.println("Testing time: " + (double)(Calendar.getInstance().getTimeInMillis() - start) / 1000 + " seconds"); if (compute_fit) { System.out.print("Fit:"); start = Calendar.getInstance().getTimeInMillis(); System.out.print(Ratings.evaluate(recommender, training_data)); System.out.println(" fit time: " + (double)(Calendar.getInstance().getTimeInMillis() - start) / 1000); } if (prediction_file != null) { System.out.print("Predict:"); start = Calendar.getInstance().getTimeInMillis(); org.mymedialite.ratingprediction.Extensions.writePredictions(recommender, test_data, prediction_file, user_mapping, item_mapping, prediction_line); System.out.println(" prediction time " + (double)(Calendar.getInstance().getTimeInMillis() - start) / 1000); } } //System.out.println(); } Model.save(recommender, save_model_file); //displayStats(); } static void checkParameters() { if (online_eval && !(recommender instanceof IIncrementalRatingPredictor)) usage("Recommender " + recommender.getClass().getName() + " does not support incremental updates, which are necessary for an online experiment."); if (training_file == null && load_model_file == null) usage("Please provide either --training-file=FILE or --load-model=FILE."); if (cross_validation == 1) usage("--cross-validation=K requires K to be at least 2."); if (show_fold_results && cross_validation == 0) usage("--show-fold-results only works with --cross-validation=K."); if (cross_validation > 1 && test_ratio != 0) usage("--cross-validation=K and --test-ratio=NUM are mutually exclusive."); if (cross_validation > 1 && prediction_file != null) usage("--cross-validation=K and --prediction-file=FILE are mutually exclusive."); if (test_file == null && test_ratio == 0 && cross_validation == 0 && save_model_file == null && chronological_split == null) usage("Please provide either test-file=FILE, --test-ratio=NUM, --cross-validation=K, --chronological-split=NUM|DATETIME, or --save-model=FILE."); if (recommender instanceof IUserAttributeAwareRecommender && user_attributes_file == null) usage("Recommender expects --user-attributes=FILE."); if (recommender instanceof IItemAttributeAwareRecommender && item_attributes_file == null) usage("Recommender expects --item-attributes=FILE."); if (recommender instanceof IUserRelationAwareRecommender && user_relations_file == null) usage("Recommender expects --user-relations=FILE."); if (recommender instanceof IItemRelationAwareRecommender && user_relations_file == null) usage("Recommender expects --item-relations=FILE."); // handling of --chronological-split if (chronological_split != null) { try { chronological_split_ratio = Double.parseDouble(chronological_split); } catch(NumberFormatException e) { usage("Unable to parse chronological_split_ratio " + chronological_split_ratio + " as double"); } if (chronological_split_ratio == -1) try { chronological_split_time = dateFormat.parse(chronological_split); } catch (ParseException e) { usage("Could not interpret argument of --chronological-split as number or date and time: " + chronological_split); } // check for conflicts if (cross_validation > 1) usage("--cross-validation=K and --chronological-split=NUM|DATETIME are mutually exclusive."); if (test_ratio > 1) usage("--test-ratio=NUM and --chronological-split=NUM|DATETIME are mutually exclusive."); } } static void loadData( String data_dir, String user_attributes_file, String item_attributes_file, String user_relation_file, String item_relation_file, boolean static_data) throws Exception { long start = Calendar.getInstance().getTimeInMillis(); // Read training data if ((recommender instanceof TimeAwareRatingPredictor || chronological_split != null) && file_format != RatingFileFormat.MOVIELENS_1M) { training_data = TimedRatingData.read(Utils.combine(data_dir, training_file), user_mapping, item_mapping, false); } else { if (file_format == RatingFileFormat.DEFAULT) training_data = static_data ? StaticRatingData.read(Utils.combine(data_dir, training_file), user_mapping, item_mapping, rating_type, false) : RatingData.read(Utils.combine(data_dir, training_file), user_mapping, item_mapping, false); else if(file_format == RatingFileFormat.IGNORE_FIRST_LINE) training_data = static_data ? StaticRatingData.read(Utils.combine(data_dir, training_file), user_mapping, item_mapping, rating_type, true) : RatingData.read(Utils.combine(data_dir, training_file), user_mapping, item_mapping, true); else if (file_format == RatingFileFormat.MOVIELENS_1M) training_data = MovieLensRatingData.read(Utils.combine(data_dir, training_file), user_mapping, item_mapping); else if (file_format == RatingFileFormat.KDDCUP_2011) training_data = org.mymedialite.io.kddcup2011.Ratings.read(Utils.combine(data_dir, training_file)); } recommender.setRatings(training_data); // User attributes if (user_attributes_file != null) user_attributes = AttributeData.read(Utils.combine(data_dir, user_attributes_file), user_mapping, attribute_mapping); if (recommender instanceof IUserAttributeAwareRecommender) ((IUserAttributeAwareRecommender)recommender).setUserAttributes(user_attributes); // Item attributes if (item_attributes_file != null) item_attributes = AttributeData.read(Utils.combine(data_dir, item_attributes_file), item_mapping, attribute_mapping); if (recommender instanceof IItemAttributeAwareRecommender) ((IItemAttributeAwareRecommender)recommender).setItemAttributes(item_attributes); // User relation if (recommender instanceof IUserRelationAwareRecommender) { ((IUserRelationAwareRecommender)recommender).setUserRelation(RelationData.read(Utils.combine(data_dir, user_relation_file), user_mapping)); System.out.println("relation over " + ((IUserRelationAwareRecommender)recommender).numUsers() + " users"); } // Item relation if (recommender instanceof IItemRelationAwareRecommender) { ((IItemRelationAwareRecommender)recommender).setItemRelation(RelationData.read(Utils.combine(data_dir, item_relation_file), item_mapping)); System.out.println("relation over " + ((IItemRelationAwareRecommender)recommender).getNumItems() + " items"); } // Read test data if (test_file != null) { if (recommender instanceof TimeAwareRatingPredictor && file_format != RatingFileFormat.MOVIELENS_1M) test_data = TimedRatingData.read(Utils.combine(data_dir, test_file), user_mapping, item_mapping, false); else if (file_format == RatingFileFormat.MOVIELENS_1M) test_data = MovieLensRatingData.read(Utils.combine(data_dir, test_file), user_mapping, item_mapping); else if (file_format == RatingFileFormat.KDDCUP_2011) test_data = org.mymedialite.io.kddcup2011.Ratings.read(Utils.combine(data_dir, training_file)); else test_data = StaticRatingData.read(Utils.combine(data_dir, test_file), user_mapping, item_mapping, rating_type, file_format == RatingFileFormat.IGNORE_FIRST_LINE); } System.out.println("Loading time: " + (double)(Calendar.getInstance().getTimeInMillis() - start) / 1000 + " seconds"); System.out.println("Memory: " + Memory.getUsage() + " MB"); } static void displayStats() { if (training_time_stats.size() > 0) // TODO format floating point System.out.println("Iteration time: min=" + Collections.min(training_time_stats) + ", max=" + Collections.max(training_time_stats) + ", avg=" + Utils.average(training_time_stats) + " seconds"); if (eval_time_stats.size() > 0) System.out.println("Evaluation time: min=" + Collections.min(eval_time_stats) + ", max=" + Collections.max(eval_time_stats) + ", avg=" + Utils.average(eval_time_stats) + " seconds"); if (compute_fit && fit_time_stats.size() > 0) System.out.println("fit_time: min=" + Collections.min(fit_time_stats) + ", max=" + Collections.max(fit_time_stats) + ", avg=" + Utils.average(fit_time_stats) + " seconds"); System.out.println("Memory: " + Memory.getUsage() + " MB"); } }