// Copyright (C) 2010 Zeno Gantner, Steffen Rendle // Copyright (C) 2011 Zeno Gantner, Chris Newell // //This file is part of MyMediaLite. // //MyMediaLite is free software: you can redistribute it and/or modify //it under the terms of the GNU General Public License as published by //the Free Software Foundation, either version 3 of the License, or //(at your option) any later version. // //MyMediaLite is distributed in the hope that it will be useful, //but WITHOUT ANY WARRANTY; without even the implied warranty of //MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //GNU General Public License for more details. // //You should have received a copy of the GNU General Public License //along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.eval; import java.util.*; import org.mymedialite.data.IPosOnlyFeedback; import org.mymedialite.data.ISplit; import org.mymedialite.datatype.IBooleanMatrix; import org.mymedialite.eval.measures.AUC; import org.mymedialite.eval.measures.NDCG; import org.mymedialite.eval.measures.PrecisionAndRecall; import org.mymedialite.eval.measures.ReciprocalRank; import org.mymedialite.IRecommender; import org.mymedialite.itemrec.Extensions; import org.mymedialite.itemrec.ItemRecommender; import org.mymedialite.util.Utils; /** * Evaluation class for item recommendation. * @version 2.03 */ public class Items { // this is a static class, but Java does not allow us to declare that ;-) private Items() {} /** * Get the evaluation measures for item prediction offered by the class. */ static public Collection<String> getMeasures() { String[] array = { "AUC", "prec@5", "prec@10", "MAP", "recall@5", "recall@10", "NDCG", "MRR" }; return Arrays.asList(array); } /** * Evaluation for rankings of item recommenders. * User-item combinations that appear in both sets are ignored for the test set, and thus in the evaluation. * The evaluation measures are listed in the ItemPredictionMeasures property. * Additionally, 'num_users' and 'num_items' report the number of users that were used to compute the results * and the number of items that were taken into account. * * Literature: * C. Manning, P. Raghavan, H. Schütze: Introduction to Information Retrieval, Cambridge University Press, 2008 * * @param recommender item recommender * @param test test cases * @param train training data * @param test_users a collection of integers with all relevant_users * @param candidate_items a collection of integers with all candidate items * @return a dictionary containing the evaluation results */ public static ItemRecommendationEvaluationResults evaluate( IRecommender recommender, IPosOnlyFeedback test, IPosOnlyFeedback train, Collection<Integer> test_users, Collection<Integer> candidate_items) { return evaluate(recommender, test, train, test_users, candidate_items, null, null); } /** * Evaluation for rankings of items. * * User-item combinations that appear in both sets are ignored for the test set, and thus in the evaluation, * except when the boolean argument repeated_events is set. * * The evaluation measures are listed in the ItemPredictionMeasures property. * Additionally, 'num_users' and 'num_items' report the number of users that were used to compute the results * and the number of items that were taken into account. * * Literature: * C. Manning, P. Raghavan, H. Schütze: Introduction to Information Retrieval, Cambridge University Press, 2008 * * @param recommender item recommender * @param test test cases * @param training training data * @param test_users a collection of integers with all relevant users * @param candidate_items a collection of integers with all relevant items * @param candidate_item_mode the mode used to determine the candidate items. The default is CandidateItems.OVERLAP * @param repeated_events allow repeated events in the evaluation (i.e. items accessed by a user before may be in the recommended list). The default is false. * @return a dictionary containing the evaluation results */ public static ItemRecommendationEvaluationResults evaluate( IRecommender recommender, IPosOnlyFeedback test, IPosOnlyFeedback training, Collection<Integer> test_users, Collection<Integer> candidate_items, CandidateItems candidate_item_mode, Boolean repeated_events) { if(candidate_item_mode == null) candidate_item_mode = CandidateItems.OVERLAP; if(repeated_events == null) repeated_events = false; if(candidate_item_mode.equals(CandidateItems.TRAINING)) { candidate_items = training.allItems(); } else if(candidate_item_mode.equals(CandidateItems.TEST)) { candidate_items = test.allItems(); } else if(candidate_item_mode.equals(CandidateItems.OVERLAP)) { candidate_items = Utils.intersect(test.allItems(), training.allItems()); } else if(candidate_item_mode.equals(CandidateItems.UNION)) { candidate_items = Utils.union(test.allItems(), training.allItems()); } else if(candidate_item_mode.equals(CandidateItems.EXPLICIT)) { if (candidate_items == null) throw new IllegalArgumentException("candidate_items == null!"); } if (test_users == null) test_users = test.allUsers(); int num_users = 0; ItemRecommendationEvaluationResults result = new ItemRecommendationEvaluationResults(); IBooleanMatrix training_user_matrix = training.userMatrix(); IBooleanMatrix test_user_matrix = test.userMatrix(); for (Integer user_id : test_users) { // Items viewed by the user in the test set that were also present in the training set. HashSet<Integer> correct_items = new HashSet<Integer>(Utils.intersect(test_user_matrix.get(user_id), candidate_items)); // The number of items that will be used for this user. HashSet<Integer> candidate_items_in_train = new HashSet<Integer> (Utils.intersect(training_user_matrix.get(user_id), candidate_items)); int num_eval_items = candidate_items.size() - (repeated_events ? 0 : candidate_items_in_train.size()); // Skip all users that have 0 or #relevant_items test items. if (correct_items.size() == 0) continue; if (num_eval_items - correct_items.size() == 0) continue; List<Integer> prediction_list = Extensions.predictItems(recommender, user_id, candidate_items); //System.out.println("correct_items: " + correct_items.size()); //System.out.println("num_eval_items: " + num_eval_items); //System.out.println("candidate_items: " + candidate_items.size()); //System.out.println("training items: " + training.getUserMatrix().getRow(user_id).size()); //System.out.println("prediction_list: " + prediction_list.size()); if (prediction_list.size() != candidate_items.size()) throw new RuntimeException("Not all items have been ranked."); Collection<Integer> ignore_items = repeated_events ? new ArrayList<Integer>() : training_user_matrix.get(user_id); double auc = AUC.compute(prediction_list, correct_items, ignore_items); double map = PrecisionAndRecall.AP(prediction_list, correct_items, ignore_items); double ndcg = NDCG.compute(prediction_list, correct_items, ignore_items); double rr = ReciprocalRank.compute(prediction_list, correct_items, ignore_items); int[] positions = new int[] { 5, 10 }; Map<Integer, Double> prec = PrecisionAndRecall.precisionAt(prediction_list, correct_items, ignore_items, positions); Map<Integer, Double> recall = PrecisionAndRecall.recallAt(prediction_list, correct_items, ignore_items, positions); num_users++; result.put("AUC", result.get("AUC") + auc); result.put("MAP", result.get("MAP") + map); result.put("NDCG", result.get("NDCG") + ndcg); result.put("MRR", result.get("MRR") + rr); result.put("prec@5", result.get("prec@5") + prec.get(5)); result.put("prec@10", result.get("prec@10") + prec.get(10)); result.put("recall@5", result.get("recall@5") + recall.get(5)); result.put("recall@10", result.get("recall@10") + recall.get(10)); if (num_users % 1000 == 0) System.out.print("."); if (num_users % 60000 == 0) System.out.println(); } if (num_users > 1000) System.out.println(); for(String measure : getMeasures()) result.put(measure, result.get(measure) / num_users); result.put("num_users", (double)num_users); result.put("num_lists", (double)num_users); result.put("num_items", (double)candidate_items.size()); return result; } /** * Format item prediction results. * @param result the result dictionary * @return a string containing the results */ public static String formatResults(Map<String, Double> result) { String string = "AUC " + result.get("AUC") + " prec@5 " + result.get("prec@5") + " prec@10 " + result.get("prec@10") + " MAP " + result.get("MAP") + " recall@5 " + result.get("recall@5") + " recall@10 " + result.get("recall@10") + " NDCG " + result.get("NDCG") + " MRR " + result.get("MRR") + " num_users " + result.get("num_users") + " num_items " + result.get("num_items") + " num_lists " + result.get("num_lists"); return string; } /** * Display item prediction results. * @param result the result dictionary */ static public void displayResults(HashMap<String, Double> result) { System.out.println("AUC " + result.get("AUC")); System.out.println("prec@5 " + result.get("prec@5")); System.out.println("prec@10 " + result.get("prec@10")); System.out.println("MAP " + result.get("MAP")); System.out.println("recall@5 " + result.get("recall@5")); System.out.println("recall@10 " + result.get("recall@10")); System.out.println("NDCG " + result.get("NDCG")); System.out.println("MRR " + result.get("MRR")); System.out.println("num_users " + result.get("num_users")); System.out.println("num_items " + result.get("num_items")); System.out.println("num_lists " + result.get("num_lists")); } }