// Copyright (C) 2010 Steffen Rendle, Zeno Gantner // Copyright (C) 2011 Zeno Gantner, Chris Newell // // This file is part of MyMediaLite. // // MyMediaLite is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // MyMediaLite is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.itemrec; import it.unimi.dsi.fastutil.ints.IntCollection; import it.unimi.dsi.fastutil.ints.IntSet; import java.io.IOException; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; import org.mymedialite.IRecommender; import org.mymedialite.data.IEntityMapping; import org.mymedialite.data.IPosOnlyFeedback; import org.mymedialite.data.IdentityMapping; import org.mymedialite.data.WeightedItem; /** * Class that contains static methods for item prediction. * @version 2.03 */ public class Extensions { // Prevent instantiation. private Extensions() {} /** * Write item predictions (scores) to a file. * @param recommender the IRecommender to use for making the predictions * @param train a user-wise IPosOnlyFeedback containing the items already observed * @param candidate_items the list of candidate items * @param num_predictions the number of items to return per user, -1 if there should be no limit * @param filename the name of the file to write to * @param users a list of users to make recommendations for * @param user_mapping an IEntityMapping object for the user IDs * @param item_mapping an IEntityMapping object for the item IDs * @throws IOException */ public static void writePredictions( IRecommender recommender, IPosOnlyFeedback train, Collection<Integer> candidate_items, int num_predictions, String filename, List<Integer> users, IEntityMapping user_mapping, IEntityMapping item_mapping) throws IOException { PrintWriter writer = new PrintWriter(filename); writePredictions(recommender, train, candidate_items, num_predictions, writer, users, user_mapping, item_mapping); } /** * Write item predictions (scores) to a TextWriter object. * @param recommender the IRecommender to use for making the predictions * @param train a user-wise IPosOnlyFeedback containing the items already observed * @param candidate_items the list of candidate items * @param num_predictions the number of items to return per user, -1 if there should be no limit * @param writer the TextWriter to write to * @param users a list of users to make recommendations for; if null, all users in train will be provided with recommendations * @param user_mapping an IEntityMapping object for the user IDs * @param item_mapping an IEntityMapping object for the item IDs */ public static void writePredictions( IRecommender recommender, IPosOnlyFeedback train, Collection<Integer> candidate_items, int num_predictions, PrintWriter writer, List<Integer> users, IEntityMapping user_mapping, IEntityMapping item_mapping) { // TODO check why the supplied users are ignored. Should there be a null check? users = new ArrayList<Integer>(train.allUsers()); for (int user_id : users) { IntCollection ignore_items = train.userMatrix().get(user_id); writePredictions(recommender, user_id, candidate_items, ignore_items, num_predictions, writer, user_mapping, item_mapping); } } /** * Write item predictions (scores) to a TextWriter object. * @param recommender the <see cref="IRecommender"/> to use for making the predictions * @param user_id the ID of the user to make recommendations for * @param candidate_items the list of candidate items * @param ignore_items a list of items for which no predictions should be made * @param num_predictions the number of items to return per user, -1 if there should be no limit * @param writer the <see cref="TextWriter"/> to write to * @param user_mapping an <see cref="IEntityMapping"/> object for the user IDs * @param item_mapping an <see cref="IEntityMapping"/> object for the item IDs */ public static void writePredictions( IRecommender recommender, int user_id, Collection<Integer> candidate_items, Collection<Integer> ignore_items, int num_predictions, PrintWriter writer, IEntityMapping user_mapping, IEntityMapping item_mapping) { if (user_mapping == null) user_mapping = new IdentityMapping(); if (item_mapping == null) item_mapping = new IdentityMapping(); List<WeightedItem> score_list = new ArrayList<WeightedItem>(); for (int item_id : candidate_items) score_list.add( new WeightedItem(item_id, recommender.predict(user_id, item_id))); Collections.sort(score_list, Collections.reverseOrder()); int prediction_count = 0; writer.print(user_mapping.toOriginalID(user_id) + "\t["); for (WeightedItem wi : score_list) { if (!ignore_items.contains(wi.item_id) && wi.weight > Double.MIN_VALUE) { if (prediction_count == 0) writer.print(item_mapping.toOriginalID(wi.item_id) + ":" + wi.weight.toString()); else writer.print("," + item_mapping.toOriginalID(wi.item_id) + ":" + wi.weight.toString()); prediction_count++; } if (prediction_count == num_predictions) break; } writer.println("]"); } /** * predict items for a specific users. * @param recommender the <see cref="IRecommender"/> object to use for the predictions * @param user_id the user ID * @param max_item_id the maximum item ID * @return a list sorted list of item IDs */ public static List<Integer> predictItems(IRecommender recommender, int user_id, int max_item_id) { List<Integer> items = new ArrayList<Integer>(max_item_id + 1); for (int i = 0; i < max_item_id; i++) items.add(i); return predictItems(recommender, user_id, items); } /** * Predict items for a given user. * @param recommender the recommender to use * @param user_id the numerical ID of the user * @param candidate_items a collection of numerical IDs of candidate items * @return an ordered list of items, the most likely item first */ public static List<Integer> predictItems(IRecommender recommender, int user_id, Collection<Integer> candidate_items) { ArrayList<WeightedItem> result = new ArrayList<WeightedItem>(candidate_items.size()); for (int item_id : candidate_items) { double predict = recommender.predict(user_id, item_id); result.add(new WeightedItem(item_id, predict)); } Collections.sort(result, Collections.reverseOrder()); List<Integer> return_array = new ArrayList<Integer>(result.size()); for (int i = 0; i < result.size(); i++) return_array.add(i, result.get(i).item_id); return return_array; } // /** // * Predict items for a specific user // * <param name="recommender">the <see cref="IRecommender"/> object to use for the predictions</param> // * <param name="user_id">the user ID</param> // * <param name="max_item_id">the maximum item ID</param> // * <returns>a list sorted list of item IDs</returns> // */ // public static int[] predictItems(IRecommender recommender, int user_id, int max_item_id) { // ArrayList<WeightedItem> result = new ArrayList<WeightedItem>(); // for (int item_id = 0; item_id <= max_item_id; item_id++) { // result.add( new WeightedItem(item_id, recommender.predict(user_id, item_id))); // } // Collections.sort(result, Collections.reverseOrder()); // // int[] return_array = new int[max_item_id + 1]; // for (int i=0; i<return_array.length; i++) { // return_array[i] = result.get(i).item_id; // } // return return_array; // } // // /** // * Predict items for a given user. // * @param recommender the recommender to use // * @param user_id the numerical ID of the user // * @param relevant_items a collection of numerical IDs of relevant items // * @return an ordered list of items, the most likely item first // */ // public static int[] predictItems(IRecommender recommender, int user_id, Collection<Integer> relevant_items) { // ArrayList<WeightedItem> result = new ArrayList<WeightedItem>(); // for (int item_id : relevant_items) { // result.add( new WeightedItem(item_id, recommender.predict(user_id, item_id))); // } // Collections.sort(result, Collections.reverseOrder()); // // int[] return_array = new int[result.size()]; // for (int i=0; i<return_array.length; i++) { // return_array[i] = result.get(i).item_id; // } // return return_array; // } }