// Copyright (C) 2010 Zeno Gantner, Steffen Rendle
// Copyright (C) 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.eval;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.mymedialite.data.IPosOnlyFeedback;
import org.mymedialite.datatype.SparseBooleanMatrix;
import org.mymedialite.eval.measures.AUC;
import org.mymedialite.eval.measures.NDCG;
import org.mymedialite.eval.measures.PrecisionAndRecall;
import org.mymedialite.eval.measures.ReciprocalRank;
import org.mymedialite.grouprec.GroupRecommender;
/**
 * Evaluation class for group recommendation.
 *
 * This is a static utility class; it cannot be instantiated.
 * @version 2.03
 */
public class Groups {

  /** Keys of the per-group measures that are summed inside the evaluation loop and averaged afterwards. */
  private static final String[] MEASURES = {
    "AUC", "MAP", "NDCG", "MRR", "prec@5", "prec@10", "recall@5", "recall@10"
  };

  // Prevent instantiation.
  private Groups() {}

  /**
   * Evaluation for rankings of items recommended to groups.
   *
   * For every group (non-empty row of <code>group_to_user</code>) the test items of all group members are
   * pooled and restricted to the candidate items; the same is done for the members' training items.
   * A group is skipped if it has no pooled test item, or if the number of items to evaluate equals the
   * number of pooled test items. For each remaining group the recommender ranks all candidate items and
   * AUC, average precision, NDCG, reciprocal rank, precision and recall at 5 and 10 are computed.
   *
   * The measures stored under "AUC", "MAP", "NDCG", "MRR", "prec@5", "prec@10", "recall@5" and "recall@10"
   * are averages over the evaluated groups. If no group was evaluated, these entries are left exactly as
   * ItemRecommendationEvaluationResults initialised them.
   *
   * This method is not synchronised; progress dots are written to System.err.
   *
   * @param recommender group recommender
   * @param test test cases
   * @param train training data
   * @param group_to_user group to user relation
   * @param candidate_items a collection of integers with all candidate items
   * @param ignore_overlap if true, ignore items that appear for a group in the training set when evaluating for that group
   * @return the evaluation results: the averaged measures plus "num_groups", "num_lists" (both the number of
   *         evaluated groups) and "num_items" (the number of candidate items)
   * @throws Exception if the recommender returns a ranking whose size differs from the number of candidate items
   */
  public static ItemRecommendationEvaluationResults evaluate(
      GroupRecommender recommender,
      IPosOnlyFeedback test,
      IPosOnlyFeedback train,
      SparseBooleanMatrix group_to_user,
      Collection<Integer> candidate_items,
      boolean ignore_overlap) throws Exception {

    ItemRecommendationEvaluationResults result = new ItemRecommendationEvaluationResults();

    // retainAll() performs one contains() call per element; answer those from a hash set so that a
    // list-typed candidate collection does not make every group cost O(#items * #candidates).
    Collection<Integer> candidate_lookup =
        (candidate_items instanceof HashSet) ? candidate_items : new HashSet<Integer>(candidate_items);

    // Cut-off positions for precision and recall; identical for every group.
    int[] positions = new int[] { 5, 10 };

    int num_groups = 0;

    for (int group_id : group_to_user.nonEmptyRowIDs()) {
      List<Integer> users = group_to_user.getEntriesByRow(group_id);

      // Pool the test items and the training items of all group members.
      HashSet<Integer> correct_items = new HashSet<Integer>();
      HashSet<Integer> candidate_items_in_train = new HashSet<Integer>();
      for (int user_id : users) {
        correct_items.addAll(test.userMatrix().get(user_id));
        candidate_items_in_train.addAll(train.userMatrix().get(user_id));
      }
      correct_items.retainAll(candidate_lookup);
      candidate_items_in_train.retainAll(candidate_lookup);

      int num_eval_items = candidate_items.size() - (ignore_overlap ? candidate_items_in_train.size() : 0);

      // skip all groups that have 0 or #candidate_items test items
      if (correct_items.isEmpty()) continue;
      if (num_eval_items == correct_items.size()) continue;

      List<Integer> prediction_list = recommender.rankItems(users, candidate_items);
      if (prediction_list.size() != candidate_items.size())
        throw new Exception("Not all items have been ranked.");

      HashSet<Integer> ignore_items = ignore_overlap ? candidate_items_in_train : new HashSet<Integer>();

      double auc  = AUC.compute(prediction_list, correct_items, ignore_items);
      double map  = PrecisionAndRecall.AP(prediction_list, correct_items, ignore_items);
      double ndcg = NDCG.compute(prediction_list, correct_items, ignore_items);
      double rr   = ReciprocalRank.compute(prediction_list, correct_items, ignore_items);

      HashMap<Integer, Double> prec   = PrecisionAndRecall.precisionAt(prediction_list, correct_items, ignore_items, positions);
      HashMap<Integer, Double> recall = PrecisionAndRecall.recallAt(prediction_list, correct_items, ignore_items, positions);

      // Plain (unsynchronised) accumulation: groups are processed sequentially by this loop.
      num_groups++;
      accumulate(result, "AUC",       auc);
      accumulate(result, "MAP",       map);
      accumulate(result, "NDCG",      ndcg);
      accumulate(result, "MRR",       rr);
      accumulate(result, "prec@5",    prec.get(5));
      accumulate(result, "prec@10",   prec.get(10));
      accumulate(result, "recall@5",  recall.get(5));
      accumulate(result, "recall@10", recall.get(10));

      // Progress indicator: one dot per 1000 evaluated groups, a line break every 60000.
      if (num_groups % 1000 == 0)
        System.err.print(".");
      if (num_groups % 60000 == 0)
        System.err.println();
    }

    // Turn the per-group sums into means; without this step "MAP", "MRR" etc. would grow with the
    // number of groups. Skipped when nothing was evaluated to avoid producing NaN.
    if (num_groups > 0) {
      for (String measure : MEASURES)
        result.put(measure, result.get(measure) / num_groups);
    }

    result.put("num_groups", (double) num_groups);
    result.put("num_lists",  (double) num_groups);
    result.put("num_items",  (double) candidate_items.size());

    return result;
  }

  /**
   * Adds a value to the entry currently stored for a measure.
   * NOTE(review): reads the entry before writing it, so this relies on ItemRecommendationEvaluationResults
   * pre-populating the measure keys (presumably with 0) - confirm against that class.
   */
  private static void accumulate(ItemRecommendationEvaluationResults result, String measure, double value) {
    result.put(measure, result.get(measure) + value);
  }
}