// Copyright (C) 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.

package org.mymedialite.eval;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import org.mymedialite.IIterativeModel;
import org.mymedialite.IRecommender;
import org.mymedialite.data.IPosOnlyFeedback;
import org.mymedialite.data.ISplit;
import org.mymedialite.data.PosOnlyFeedback;
import org.mymedialite.data.PosOnlyFeedbackCrossValidationSplit;
import org.mymedialite.datatype.SparseBooleanMatrix;
import org.mymedialite.itemrec.ItemRecommender;

/**
 * Cross-validation for item recommendation.
 * @version 2.03
 */
public class ItemsCrossValidation {

  // Prevent instantiation.
  private ItemsCrossValidation() {}

  /**
   * Evaluate on the folds of a dataset split.
   * The split is built from the feedback currently attached to the recommender.
   * @param recommender an item recommender; must be an {@link ItemRecommender}
   * @param num_folds the number of folds
   * @param test_users a collection of integers with all test users
   * @param candidate_items a collection of integers with all candidate items
   * @param candidate_item_mode the mode used to determine the candidate items; null means {@code CandidateItems.OVERLAP}
   * @param show_results set to true to print the per-fold results to standard output
   * @return a dictionary containing the average results over the different folds of the split
   * @throws IllegalArgumentException if recommender is not an ItemRecommender
   * @throws Exception if building the split, training or evaluation fails
   */
  public static ItemRecommendationEvaluationResults doCrossValidation(
      IRecommender recommender,
      int num_folds,
      List<Integer> test_users,
      List<Integer> candidate_items,
      CandidateItems candidate_item_mode,
      boolean show_results) throws Exception {

    if (!(recommender instanceof ItemRecommender)) {
      throw new IllegalArgumentException("recommender must be of type ItemRecommender");
    }
    if (candidate_item_mode == null) {
      candidate_item_mode = CandidateItems.OVERLAP;
    }

    PosOnlyFeedbackCrossValidationSplit<PosOnlyFeedback<SparseBooleanMatrix>> split =
        createSplit(recommender, num_folds);
    return doCrossValidation(recommender, split, test_users, candidate_items, candidate_item_mode, show_results);
  }

  /**
   * Evaluate on the folds of a dataset split.
   * For every fold a clone of the recommender is trained on the fold's training data and evaluated;
   * the per-fold results are summed and finally divided by the number of folds.
   * @param recommender an item recommender; must be an {@link ItemRecommender}
   * @param split a dataset split
   * @param test_users a collection of integers with all test users
   * @param candidate_items a collection of integers with all candidate items
   * @param candidate_item_mode the mode used to determine the candidate items; null means {@code CandidateItems.OVERLAP}
   * @param show_results set to true to print the per-fold results to standard output
   * @return a dictionary containing the average results over the different folds of the split
   * @throws IllegalArgumentException if recommender is not an ItemRecommender
   * @throws Exception any failure during training or evaluation; it is reported on standard error and rethrown
   */
  public static ItemRecommendationEvaluationResults doCrossValidation(
      IRecommender recommender,
      ISplit<IPosOnlyFeedback> split,
      List<Integer> test_users,
      List<Integer> candidate_items,
      CandidateItems candidate_item_mode,
      boolean show_results) throws Exception {

    if (!(recommender instanceof ItemRecommender)) {
      throw new IllegalArgumentException("recommender must be of type ItemRecommender");
    }
    if (candidate_item_mode == null) {
      candidate_item_mode = CandidateItems.OVERLAP;
    }

    int fold_count = split.numberOfFolds();
    ItemRecommendationEvaluationResults avg_results = new ItemRecommendationEvaluationResults();

    for (int fold = 0; fold < fold_count; fold++) {
      try {
        // Work on a clone so that the caller's recommender is left untouched.
        ItemRecommender split_recommender = ((ItemRecommender) recommender).clone();
        split_recommender.setFeedback(split.train().get(fold));
        split_recommender.train();

        // Test data first, then training data: the same argument order as the
        // Items.evaluate calls in doIterativeCrossValidation.
        ItemRecommendationEvaluationResults fold_results = Items.evaluate(
            split_recommender,
            split.test().get(fold),
            split.train().get(fold),
            test_users,
            candidate_items,
            candidate_item_mode,
            false);

        // Accumulate the sums; they are turned into averages after the loop.
        for (String key : fold_results.keySet()) {
          if (avg_results.containsKey(key)) {
            avg_results.put(key, avg_results.get(key) + fold_results.get(key));
          } else {
            avg_results.put(key, fold_results.get(key));
          }
        }

        if (show_results) {
          System.out.println("fold " + fold + " " + fold_results);
        }
      } catch (Exception e) {
        reportError(e);
        throw e;
      }
    }

    // Turn the accumulated sums into averages.
    for (String key : Items.getMeasures()) {
      avg_results.put(key, avg_results.get(key) / fold_count);
    }
    avg_results.put("num_users", avg_results.get("num_users") / fold_count);
    avg_results.put("num_items", avg_results.get("num_items") / fold_count);

    return avg_results;
  }

  /**
   * Evaluate an iterative recommender on the folds of a dataset split, display results on STDOUT.
   * The split is built from the feedback currently attached to the recommender.
   * @param recommender an item recommender; must be an {@link ItemRecommender} and an {@link IIterativeModel}
   * @param num_folds the number of folds
   * @param test_users a collection of integers with all test users
   * @param candidate_items a collection of integers with all candidate items
   * @param candidate_item_mode the mode used to determine the candidate items
   * @param repeated_events allow repeated events in the evaluation (i.e. items accessed by a user before may be in the recommended list)
   * @param max_iter the maximum number of iterations
   * @param find_iter the report interval; must not be zero
   * @throws IllegalArgumentException if recommender does not have the required types
   * @throws Exception if building the split, training or evaluation fails
   */
  public static void doIterativeCrossValidation(
      IRecommender recommender,
      int num_folds,
      List<Integer> test_users,
      List<Integer> candidate_items,
      CandidateItems candidate_item_mode,
      boolean repeated_events,
      int max_iter,
      int find_iter) throws Exception {  // find_iter = 1

    if (!(recommender instanceof ItemRecommender)) {
      throw new IllegalArgumentException("recommender must be of type ItemRecommender");
    }

    PosOnlyFeedbackCrossValidationSplit<PosOnlyFeedback<SparseBooleanMatrix>> split =
        createSplit(recommender, num_folds);
    doIterativeCrossValidation(recommender, split, test_users, candidate_items, candidate_item_mode, repeated_events, max_iter, find_iter);
  }

  /**
   * Evaluate an iterative recommender on the folds of a dataset split, display results on STDOUT.
   * Each fold gets its own clone of the recommender, which is trained and evaluated once, and then
   * iterated further up to max_iter; results are printed whenever the iteration number is a
   * multiple of find_iter.
   * @param recommender an item recommender; must be an {@link ItemRecommender} and an {@link IIterativeModel}
   * @param split a positive-only feedback dataset split
   * @param test_users a collection of integers with all test users
   * @param candidate_items a collection of integers with all candidate items
   * @param candidate_item_mode the mode used to determine the candidate items
   * @param repeated_events allow repeated events in the evaluation (i.e. items accessed by a user before may be in the recommended list)
   * @param max_iter the maximum number of iterations
   * @param find_iter the report interval; must not be zero
   * @throws IllegalArgumentException if recommender does not have the required types
   * @throws Exception any failure during training or evaluation; it is reported on standard error and rethrown
   */
  public static void doIterativeCrossValidation(
      IRecommender recommender,
      ISplit<IPosOnlyFeedback> split,
      List<Integer> test_users,
      List<Integer> candidate_items,
      CandidateItems candidate_item_mode,
      boolean repeated_events,
      int max_iter,
      int find_iter) throws Exception {

    if (!(recommender instanceof IIterativeModel)) {
      throw new IllegalArgumentException("recommender must be of type IIterativeModel");
    }
    if (!(recommender instanceof ItemRecommender)) {
      throw new IllegalArgumentException("recommender must be of type ItemRecommender");
    }

    int fold_count = split.numberOfFolds();
    ItemRecommender[] split_recommenders = new ItemRecommender[fold_count];
    IIterativeModel[] iterative_recommenders = new IIterativeModel[fold_count];

    // initial training and evaluation
    for (int i = 0; i < fold_count; i++) {
      try {
        // Work on a clone so that the caller's recommender is left untouched.
        split_recommenders[i] = ((ItemRecommender) recommender).clone();
        split_recommenders[i].setFeedback(split.train().get(i));
        split_recommenders[i].train();
        // Same object, viewed through its iterative interface.
        iterative_recommenders[i] = (IIterativeModel) split_recommenders[i];

        ItemRecommendationEvaluationResults fold_results = Items.evaluate(
            split_recommenders[i],
            split.test().get(i),
            split.train().get(i),
            test_users,
            candidate_items,
            candidate_item_mode,
            repeated_events);
        System.out.println("fold " + i + " " + fold_results + " iteration " + iterative_recommenders[i].getNumIter());
      } catch (Exception e) {
        reportError(e);
        throw e;
      }
    }

    // iterative training and evaluation; the start iteration is taken from the first fold's model
    for (int it = iterative_recommenders[0].getNumIter() + 1; it <= max_iter; it++) {
      for (int i = 0; i < fold_count; i++) {
        try {
          iterative_recommenders[i].iterate();

          if (it % find_iter == 0) {
            ItemRecommendationEvaluationResults fold_results = Items.evaluate(
                split_recommenders[i],
                split.test().get(i),
                split.train().get(i),
                test_users,
                candidate_items,
                candidate_item_mode,
                repeated_events);
            System.out.println("fold " + i + " " + fold_results + " iteration " + it);
          }
        } catch (Exception e) {
          reportError(e);
          throw e;
        }
      }
    }
  }

  /**
   * Build a cross-validation split with num_folds folds from the recommender's current feedback.
   * The caller must already have verified that recommender is an ItemRecommender.
   */
  private static PosOnlyFeedbackCrossValidationSplit<PosOnlyFeedback<SparseBooleanMatrix>> createSplit(
      IRecommender recommender, int num_folds) throws Exception {

    List<IPosOnlyFeedback> train = new ArrayList<IPosOnlyFeedback>(num_folds);
    List<IPosOnlyFeedback> test = new ArrayList<IPosOnlyFeedback>(num_folds);
    for (int i = 0; i < num_folds; i++) {
      train.add(new PosOnlyFeedback<SparseBooleanMatrix>(SparseBooleanMatrix.class));
      test.add(new PosOnlyFeedback<SparseBooleanMatrix>(SparseBooleanMatrix.class));
    }

    return new PosOnlyFeedbackCrossValidationSplit<PosOnlyFeedback<SparseBooleanMatrix>>(
        ((ItemRecommender) recommender).getFeedback(), num_folds, train, test);
  }

  /**
   * Report a failure on standard error: the message line followed by the full stack trace.
   * (Concatenating getStackTrace() into a string would only print the array's identity.)
   */
  private static void reportError(Exception e) {
    System.err.println("===> ERROR: " + e.getMessage());
    e.printStackTrace(System.err);
  }
}