// Copyright (C) 2010, 2011 Zeno Gantner, Chris Newell // // This file is part of MyMediaLite. // // MyMediaLite is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // MyMediaLite is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.eval; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import org.mymedialite.IRecommender; import org.mymedialite.data.IPosOnlyFeedback; import org.mymedialite.datatype.SparseBooleanMatrix; import org.mymedialite.eval.measures.AUC; import org.mymedialite.eval.measures.NDCG; import org.mymedialite.eval.measures.PrecisionAndRecall; import org.mymedialite.eval.measures.ReciprocalRank; /** * Evaluation class for filtered item recommendation. * @version 2.03 */ public class ItemsFiltered { // Prevent instantiation. private ItemsFiltered() {} /** * For a given user and the test dataset, return a dictionary of items filtered by attributes. * @param user_id the user ID * @param test the test dataset * @param item_attributes * @return a dictionary containing a mapping from attribute IDs to collections of item IDs */ public static HashMap<Integer, Collection<Integer>> getFilteredItems(int user_id, IPosOnlyFeedback test, SparseBooleanMatrix item_attributes) { HashMap<Integer, Collection<Integer>> filtered_items = new HashMap<Integer, Collection<Integer>>(); for (int item_id : test.userMatrix().get(user_id)) { for (int attribute_id : item_attributes.get(item_id)) { if (!filtered_items.containsKey(attribute_id)) filtered_items.put(attribute_id, new HashSet<Integer>()); filtered_items.get(attribute_id).add(item_id); } } return filtered_items; } /** * Evaluation for rankings of filtered items. * * * @param recommender item recommender * @param test test cases * @param train training data * @param item_attributes the item attributes to be used for filtering * @param test_users a collection of integers with all test users * @param candidate_items a collection of integers with all candidate items * @param repeated_events allow repeated events in the evaluation (i.e. items accessed by a user before may be in the recommended list) * @return a dictionary containing the evaluation results * @throws Exception */ public static ItemRecommendationEvaluationResults evaluateFiltered( IRecommender recommender, IPosOnlyFeedback test, IPosOnlyFeedback train, SparseBooleanMatrix item_attributes, List<Integer> test_users, List<Integer> candidate_items, boolean repeated_events) throws Exception { SparseBooleanMatrix items_by_attribute = (SparseBooleanMatrix) item_attributes.transpose(); int num_users = 0; int num_lists = 0; ItemRecommendationEvaluationResults result = new ItemRecommendationEvaluationResults(); result.put("AUC", 0.0); result.put("MAP", 0.0); result.put("NDCG", 0.0); result.put("MRR", 0.0); result.put("prec@5", 0.0); result.put("prec@10", 0.0); result.put("recall@5", 0.0); result.put("recall@10", 0.0); // TODO explore parallel processing options OUTER: for (int user_id : test_users) { HashMap<Integer, Collection<Integer>> filtered_items = getFilteredItems(user_id, test, item_attributes); int last_user_id = -1; for (int attribute_id : filtered_items.keySet()) { HashSet<Integer> filtered_candidate_items = new HashSet<Integer>(items_by_attribute.get(attribute_id)); filtered_candidate_items.retainAll(candidate_items); HashSet<Integer> correct_items = new HashSet<Integer>(filtered_items.get(attribute_id)); correct_items.retainAll(filtered_candidate_items); // the number of candidate items for this user HashSet<Integer> candidate_items_in_train = new HashSet<Integer>(train.userMatrix().get(user_id)); candidate_items_in_train.retainAll(filtered_candidate_items); int num_eval_items = filtered_candidate_items.size() - candidate_items_in_train.size(); // skip all users that have 0 or #filtered_candidate_items test items if (correct_items.size() == 0) continue OUTER; if (num_eval_items - correct_items.size() == 0) continue OUTER; // evaluation List<Integer> prediction_list = org.mymedialite.itemrec.Extensions.predictItems(recommender, user_id, filtered_candidate_items); Collection<Integer> ignore_items = repeated_events ? new ArrayList<Integer>() : train.userMatrix().get(user_id); double auc = AUC.compute(prediction_list, correct_items, ignore_items); double map = PrecisionAndRecall.AP(prediction_list, correct_items, ignore_items); double ndcg = NDCG.compute(prediction_list, correct_items, ignore_items); double rr = ReciprocalRank.compute(prediction_list, correct_items, ignore_items); int[] positions = new int[] { 5, 10 }; HashMap<Integer, Double> prec = PrecisionAndRecall.precisionAt(prediction_list, correct_items, ignore_items, positions); HashMap<Integer, Double> recall = PrecisionAndRecall.recallAt(prediction_list, correct_items, ignore_items, positions); // counting stats num_lists++; if (last_user_id != user_id) { last_user_id = user_id; num_users++; } // result bookkeeping result.put("AUC", result.get("AUC") + auc); result.put("MAP", result.get("MAP") + map); result.put("NDCG", result.get("NDCG") + ndcg); result.put("MRR", result.get("MRR") + rr); result.put("prec@5", result.get("prec@5") + prec.get(5)); result.put("prec@10", result.get("prec@10") + prec.get(10)); result.put("recall@5", result.get("recall@5") + recall.get(5)); result.put("recall@10", result.get("recall@10") + recall.get(10)); if (prediction_list.size() != filtered_candidate_items.size()) throw new Exception("Not all items have been ranked."); if (num_lists % 5000 == 0) System.err.print("."); if (num_lists % 300000 == 0) System.err.println(); } } for (String measure : Items.getMeasures()) result.put(measure, result.get(measure) / num_lists); result.put("num_users", (double)num_users); result.put("num_lists", (double)num_lists); result.put("num_items", (double)candidate_items.size()); return result; } // TODO implement online eval }