import it.unimi.dsi.fastutil.ints.IntArraySet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.mymedialite.IItemAttributeAwareRecommender;
import org.mymedialite.IItemRelationAwareRecommender;
import org.mymedialite.IIterativeModel;
import org.mymedialite.IRecommender;
import org.mymedialite.IUserAttributeAwareRecommender;
import org.mymedialite.IUserRelationAwareRecommender;
import org.mymedialite.data.EntityMapping;
import org.mymedialite.data.IEntityMapping;
import org.mymedialite.data.IPosOnlyFeedback;
import org.mymedialite.data.PosOnlyFeedback;
import org.mymedialite.data.PosOnlyFeedbackSimpleSplit;
import org.mymedialite.datatype.SparseBooleanMatrix;
import org.mymedialite.eval.CandidateItems;
import org.mymedialite.eval.Groups;
import org.mymedialite.eval.ItemRecommendationEvaluationResults;
import org.mymedialite.eval.Items;
import org.mymedialite.eval.ItemsCrossValidation;
import org.mymedialite.eval.ItemsFiltered;
import org.mymedialite.eval.ItemsOnline;
import org.mymedialite.grouprec.Average;
import org.mymedialite.grouprec.GroupRecommender;
import org.mymedialite.grouprec.Maximum;
import org.mymedialite.grouprec.Minimum;
import org.mymedialite.io.AttributeData;
import org.mymedialite.io.ItemData;
import org.mymedialite.io.ItemDataFileFormat;
import org.mymedialite.io.ItemDataRatingThreshold;
import org.mymedialite.io.Model;
import org.mymedialite.io.NumberFile;
import org.mymedialite.io.RelationData;
import org.mymedialite.itemrec.Extensions;
import org.mymedialite.itemrec.IIncrementalItemRecommender;
import org.mymedialite.itemrec.ItemRecommender;
import org.mymedialite.util.Handlers;
import org.mymedialite.util.Memory;
import org.mymedialite.util.Random;
import org.mymedialite.util.Recommender;
import org.mymedialite.util.Utils;
/**
 * Item prediction program, see usage() method for more information.
 *
 * <p>Reads positive-only feedback, sets up (or loads) an item recommender and then, depending on
 * the command-line options, evaluates it, cross-validates it, writes ranked predictions, or
 * searches for a good number of training iterations.
 *
 * <p>All program state lives in static fields without synchronization; the class is meant to be
 * used as a single-threaded command-line entry point.
 *
 * @version 2.03
 */
public class ItemRecommendation {
  static final String VERSION = "2.03";

  // Data
  static IPosOnlyFeedback training_data;
  static IPosOnlyFeedback test_data;
  static List<Integer> test_users;
  static List<Integer> candidate_items;
  static SparseBooleanMatrix group_to_user; // rows: groups, columns: users
  static Collection<Integer> user_groups;
  static CandidateItems eval_item_mode = CandidateItems.UNION;

  // Recommenders
  static IRecommender recommender = null;

  // ID mapping objects
  static IEntityMapping user_mapping = new EntityMapping();
  static IEntityMapping item_mapping = new EntityMapping();
  static IEntityMapping attribute_mapping = new EntityMapping();

  // User and item attributes
  static SparseBooleanMatrix user_attributes;
  static SparseBooleanMatrix item_attributes;

  // Command-line parameters (data)
  static String training_file;
  static String test_file;
  static ItemDataFileFormat file_format = ItemDataFileFormat.DEFAULT;
  static String data_dir = "";
  static String test_users_file;
  static String candidate_items_file;
  static String user_attributes_file;
  static String item_attributes_file;
  static String user_relations_file;
  static String item_relations_file;
  static String save_model_file = null;
  static String load_model_file = null;
  static String user_groups_file;
  static String prediction_file;

  // Command-line parameters (other)
  static boolean compute_fit = false;
  static int cross_validation;
  static boolean show_fold_results = false;
  static double test_ratio;
  static double rating_threshold = Double.NaN;
  static int num_test_users;
  static int predict_items_number = -1;
  static boolean online_eval = false;
  static boolean filtered_eval = false;
  static boolean repeat_eval = false;
  static String group_method;
  static boolean overlap_items = false;
  static boolean in_training_items = false;
  static boolean in_test_items = false;
  static boolean all_items = false;
  static boolean user_prediction = false;
  static int random_seed = -1;
  static int find_iter = 0;

  // Time statistics (seconds per iteration / evaluation / fit computation)
  static List<Double> training_time_stats = new ArrayList<Double>();
  static List<Double> fit_time_stats = new ArrayList<Double>();
  static List<Double> eval_time_stats = new ArrayList<Double>();

  /** Prints version and copyright information, then terminates the JVM with exit code 0. */
  static void showVersion() {
    System.out.println("MyMediaLite Item Prediction from Implicit Feedback " + VERSION);
    System.out.println("Copyright (C) 2010 Zeno Gantner, Steffen Rendle, Christoph Freudenthaler");
    System.out.println("Copyright (C) 2011 Zeno Gantner, Chris Newell");
    System.out.println("This is free software; see the source for copying conditions. There is NO");
    System.out.println("warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.");
    System.exit(0);
  }

  /** Reports recommender configuration errors by printing the message plus usage and exiting. */
  private static class ErrorHandler implements Recommender.ErrorHandler {
    public void reportError(String message) {
      usage(message);
    }
  }

  /**
   * Prints an error message followed by the usage text, then exits with code -1.
   *
   * @param message the error to report
   */
  static void usage(String message) {
    System.out.println(message);
    System.out.println();
    usage(-1);
  }

  /**
   * Prints the usage text and terminates the JVM.
   *
   * @param exit_code the process exit code
   */
  static void usage(int exit_code) {
    System.out.println("MyMediaLite item recommendation from positive-only feedback " + VERSION);
    System.out.println(" Usage: item_recommendation --training-file=FILE --recommender=METHOD [OPTIONS]");
    System.out.println(" method ARGUMENTS have the form name=value");
    System.out.println();
    System.out.println(" General OPTIONS:\n" +
        " --recommender=METHOD use METHOD for recommendations (default: MostPopular)\n" +
        " --group-recommender=METHOD use METHOD to combine the predictions for several users\n" +
        " --recommender-options=OPTIONS use OPTIONS as recommender options\n" +
        " --help display this usage information and exit\n" +
        " --version display version information and exit\n" +
        " --random-seed=N initialize the random number generator with N\n" +
        "\n" +
        " Files:\n" +
        " --training-file=FILE read training data from FILE\n" +
        " --test-file=FILE read test data from FILE\n" +
        " --file-format=ignore_first_line|default\n" +
        " --data-dir=DIR load all files from DIR\n" +
        " --user-attributes=FILE file containing user attribute information, 1 tuple per line\n" +
        " --item-attributes=FILE file containing item attribute information, 1 tuple per line\n" +
        " --user-relations=FILE file containing user relation information, 1 tuple per line\n" +
        " --item-relations=FILE file containing item relation information, 1 tuple per line\n" +
        " --user-groups=FILE file containing group-to-user mappings, 1 tuple per line\n" +
        " --save-model=FILE save computed model to FILE\n" +
        " --load-model=FILE load model from FILE\n" +
        "\n" +
        " Data interpretation:\n" +
        " --user-prediction transpose the user-item matrix and perform user prediction instead of item prediction\n" +
        " --rating-threshold=NUM (for rating datasets) interpret rating >= NUM as positive feedback\n" +
        "\n" +
        " Choosing the items for evaluation/prediction (mutually exclusive):\n" +
        " --candidate-items=FILE use the items in FILE (one per line) as candidate items in the evaluation\n" +
        " --overlap-items use only the items that are both in the training and the test set as candidate items in the evaluation\n" +
        " --in-training-items use only the items in the training set as candidate items in the evaluation\n" +
        " --in-test-items use only the items in the test set as candidate items in the evaluation\n" +
        " --all-items use all known items as candidate items in the evaluation\n" +
        " Choosing the users for evaluation/prediction\n" +
        " --test-users=FILE predict items for users specified in FILE (one user per line)\n" +
        "\n" +
        " Prediction options:\n" +
        " --prediction-file=FILE write ranked predictions to FILE, one user per line\n" +
        " --predict-items-number=N predict N items per user (needs --prediction-file=FILE)\n" +
        "\n" +
        " Evaluation options:\n" +
        " --cross-validation=K perform k-fold cross-validation on the training data\n" +
        " --show-fold-results show results for individual folds in cross-validation\n" +
        " --test-ratio=NUM evaluate by splitting of a NUM part of the feedback\n" +
        " --num-test-users=N evaluate on only N randomly picked users (to save time)\n" +
        " --online-evaluation perform online evaluation (use every tested user-item combination for incremental training)\n" +
        " --filtered-evaluation perform evaluation filtered by item attribute (expects --item-attributes=FILE)\n" +
        " --repeat-evaluation assume that items can be accessed repeatedly - items can occur both in the training and the test data for one user\n" +
        " --compute-fit display fit on training data\n" +
        "\n" +
        " Finding the right number of iterations (iterative methods)\n" +
        " --find-iter=N give out statistics every N iterations\n" +
        " --max-iter=N perform at most N iterations\n" +
        " --auc-cutoff=NUM abort if AUC is below NUM\n" +
        " --prec5-cutoff=NUM abort if prec@5 is below NUM\n");
    System.exit(exit_code);
  }

  /**
   * Program entry point: parses the command line, sets up the recommender, loads the data and
   * runs the requested experiment. Fatal setup/runtime errors terminate with a non-zero exit code.
   *
   * @param args command-line options, see {@link #usage(int)}
   */
  public static void main(String[] args) {
    // Handlers for uncaught exceptions and interrupts
    Thread.setDefaultUncaughtExceptionHandler(new Handlers());
    Runtime.getRuntime().addShutdownHook(new Thread() {
      @Override
      public void run() {
        // Display stats if forced to exit
        displayStats();
      }
    });

    // Recommender arguments
    String method = null;
    String recommender_options = "";

    // Help/version
    boolean show_help = false;
    boolean show_version = false;

    // Variables for iteration search
    int max_iter = 500;
    double auc_cutoff = 0;
    double prec5_cutoff = 0;
    compute_fit = false;

    // Other parameters
    test_ratio = 0;
    num_test_users = -1;
    repeat_eval = false;

    for (String arg : args) {
      // Valued options are matched by their prefix including the trailing '='.
      int div = arg.indexOf("=") + 1;
      String name;
      String value;
      if (div > 0) {
        name = arg.substring(0, div);
        value = arg.substring(div);
      } else {
        name = arg;
        value = null;
      }

      try {
        // String-valued options
        if (name.equals("--training-file=")) training_file = value;
        else if (name.equals("--test-file=")) test_file = value;
        else if (name.equals("--recommender=")) method = value;
        else if (name.equals("--group-recommender=")) group_method = value;
        else if (name.equals("--recommender-options=")) recommender_options += " " + value;
        else if (name.equals("--data-dir=")) data_dir = value;
        else if (name.equals("--user-attributes=")) user_attributes_file = value;
        else if (name.equals("--item-attributes=")) item_attributes_file = value;
        else if (name.equals("--user-relations=")) user_relations_file = value;
        else if (name.equals("--item-relations=")) item_relations_file = value;
        else if (name.equals("--save-model=")) save_model_file = value;
        else if (name.equals("--load-model=")) load_model_file = value;
        else if (name.equals("--prediction-file=")) prediction_file = value;
        else if (name.equals("--test-users=")) test_users_file = value;
        else if (name.equals("--candidate-items=")) candidate_items_file = value;
        else if (name.equals("--user-groups=")) user_groups_file = value;
        // Integer-valued options
        else if (name.equals("--find-iter=")) find_iter = Integer.parseInt(value);
        else if (name.equals("--max-iter=")) max_iter = Integer.parseInt(value);
        else if (name.equals("--random-seed=")) random_seed = Integer.parseInt(value);
        else if (name.equals("--predict-items-number=")) predict_items_number = Integer.parseInt(value);
        else if (name.equals("--num-test-users=")) num_test_users = Integer.parseInt(value);
        else if (name.equals("--cross-validation=")) cross_validation = Integer.parseInt(value);
        // Double-valued options
        else if (name.equals("--auc-cutoff=")) auc_cutoff = Double.parseDouble(value);
        else if (name.equals("--prec5-cutoff=")) prec5_cutoff = Double.parseDouble(value);
        else if (name.equals("--test-ratio=")) test_ratio = Double.parseDouble(value);
        else if (name.equals("--rating-threshold=")) rating_threshold = Double.parseDouble(value);
        // Enum options: usage() documents lower-case values, enum constants are upper case.
        // Locale.ROOT avoids locale-dependent case mapping (e.g. Turkish dotted I).
        else if (name.equals("--file-format=")) file_format = ItemDataFileFormat.valueOf(value.toUpperCase(java.util.Locale.ROOT));
        // Boolean options
        else if (name.equals("--user-prediction")) user_prediction = true;
        else if (name.equals("--compute-fit")) compute_fit = true;
        else if (name.equals("--online-evaluation")) online_eval = true;
        else if (name.equals("--filtered-evaluation")) filtered_eval = true;
        else if (name.equals("--repeat-evaluation")) repeat_eval = true;
        else if (name.equals("--show-fold-results")) show_fold_results = true;
        else if (name.equals("--overlap-items")) overlap_items = true;
        else if (name.equals("--all-items")) all_items = true;
        else if (name.equals("--in-training-items")) in_training_items = true;
        else if (name.equals("--in-test-items")) in_test_items = true;
        else if (name.equals("--help")) show_help = true;
        else if (name.equals("--version")) show_version = true;
        else usage("Did not understand " + name);
      } catch (IllegalArgumentException e) {
        // Covers NumberFormatException (numeric options) and unknown enum constants.
        usage("Could not parse value '" + value + "' of option " + name);
      }
    }

    boolean no_eval = true;
    if (test_ratio > 0 || test_file != null) no_eval = false;

    if (show_version) showVersion();
    if (show_help) usage(0);

    if (random_seed != -1) org.mymedialite.util.Random.initInstance(random_seed);

    // Set up recommender
    if (load_model_file != null) {
      try {
        recommender = Model.load(load_model_file);
      } catch (IOException e) {
        System.err.println("Unable to load model file: " + load_model_file);
        System.exit(1);
      }
    } else if (method != null) {
      recommender = Recommender.createItemRecommender(method);
    } else {
      recommender = Recommender.createItemRecommender("MostPopular");
    }

    // In case something went wrong ...
    if (recommender == null && method != null)
      usage("Unknown recommendation method: " + method);
    if (recommender == null && load_model_file != null)
      usage("Could not load model from file " + load_model_file);

    checkParameters();

    try {
      recommender = Recommender.configure(recommender, recommender_options, new ErrorHandler());
    } catch (IllegalAccessException e) {
      System.err.println("Unable to instantiate recommender: " + recommender.toString());
      System.exit(1);
    }

    // Load all the data; nothing below can work without it, so abort on failure.
    try {
      loadData();
    } catch (Exception e) {
      System.err.println("Unable to load data: " + e.getMessage());
      e.printStackTrace();
      System.exit(1);
    }

    Utils.displayDataStats(training_data, test_data, user_attributes, item_attributes);

    // Evaluation
    try {
      if (find_iter != 0)
        findIterations(max_iter, auc_cutoff, prec5_cutoff);
      else
        trainAndEvaluate(no_eval);

      try {
        Model.save(recommender, save_model_file);
      } catch (IOException e) {
        System.err.println("Unable to save model file: " + save_model_file);
      }
    } catch (Exception e) {
      System.err.println("Exception: " + e.getMessage());
      e.printStackTrace();
      System.exit(1);
    }
  }

  /**
   * Runs the --find-iter experiment: trains an iterative recommender one iteration at a time and,
   * every {@code find_iter} iterations, reports fit/evaluation results, saves the model and writes
   * predictions. Stops after {@code max_iter} iterations or when AUC or prec@5 drop below the cutoffs.
   *
   * @param max_iter maximum number of iterations
   * @param auc_cutoff abort when AUC falls below this value
   * @param prec5_cutoff abort when prec@5 falls below this value
   * @throws Exception if training or evaluation fails
   */
  private static void findIterations(int max_iter, double auc_cutoff, double prec5_cutoff) throws Exception {
    if (!(recommender instanceof IIterativeModel))
      usage("Only iterative recommenders (interface IIterativeModel) support --find-iter=N.");
    IIterativeModel iterative_recommender = (IIterativeModel) recommender;
    System.out.println("Recommender: " + recommender);

    if (cross_validation > 1) {
      ItemsCrossValidation.doIterativeCrossValidation(recommender, cross_validation, test_users, candidate_items, eval_item_mode, repeat_eval, max_iter, find_iter);
      return;
    }

    if (load_model_file == null)
      recommender.train();
    if (compute_fit)
      System.out.println("Fit: " + computeFit() + " iteration " + iterative_recommender.getNumIter());

    ItemRecommendationEvaluationResults results = evaluate();
    System.out.println(results + " iteration " + iterative_recommender.getNumIter());

    for (int it = iterative_recommender.getNumIter() + 1; it <= max_iter; it++) {
      long start = System.currentTimeMillis();
      iterative_recommender.iterate();
      training_time_stats.add(secondsSince(start));

      if (it % find_iter != 0)
        continue;

      if (compute_fit) {
        start = System.currentTimeMillis();
        try {
          System.out.println("Fit: " + computeFit() + " iteration " + it);
        } catch (Exception e) {
          // Best effort: a failed fit computation should not abort the iteration search.
          System.err.println("Unable to compute fit at iteration " + it + ": " + e.getMessage());
        }
        fit_time_stats.add(secondsSince(start));
      }

      start = System.currentTimeMillis();
      results = evaluate();
      eval_time_stats.add(secondsSince(start));
      System.out.println(results + " iteration " + it);

      try {
        Model.save(recommender, save_model_file, it);
      } catch (IOException e) {
        System.err.println("Unable to save model file: " + save_model_file);
      }
      predict(prediction_file, test_users_file, it);

      if (results.get("AUC") < auc_cutoff || results.get("prec@5") < prec5_cutoff) {
        System.out.println("Reached cutoff after " + it + " iterations");
        System.out.println("DONE");
        break;
      }
    }
  }

  /**
   * Runs the standard (non --find-iter) experiment: trains or cross-validates the recommender
   * (unless a model was loaded), then writes predictions or evaluates, as requested.
   *
   * @param no_eval true if no held-out evaluation data was requested
   * @throws Exception if training, evaluation, or prediction fails
   */
  private static void trainAndEvaluate(boolean no_eval) throws Exception {
    boolean skip_eval = no_eval;
    System.out.println("Recommender: " + recommender);

    if (load_model_file == null) {
      if (cross_validation > 1) {
        ItemRecommendationEvaluationResults results = ItemsCrossValidation.doCrossValidation(recommender, cross_validation, test_users, candidate_items, eval_item_mode, show_fold_results);
        System.out.println(results);
        skip_eval = true; // cross-validation already reported its results
      } else {
        long start = System.currentTimeMillis();
        recommender.train();
        System.out.println("Training time " + secondsSince(start) + " seconds");
      }
    }

    if (compute_fit)
      System.out.println("Fit: " + computeFit());

    if (prediction_file != null) {
      predict(prediction_file, test_users_file);
    } else if (!skip_eval) {
      if (online_eval) {
        HashMap<String, Double> results = ItemsOnline.evaluate(recommender, test_data, training_data, test_users, candidate_items, eval_item_mode);
        System.out.println(results);
      } else if (group_method != null) {
        evaluateGroups();
      } else {
        long start = System.currentTimeMillis();
        System.out.println(evaluate());
        System.out.println("Testing time " + secondsSince(start) + " seconds");
      }
    }
  }

  /**
   * Evaluates group recommendation using the strategy named by {@code group_method}
   * (Average, Minimum or Maximum); unknown strategies terminate via usage().
   *
   * @throws Exception if the evaluation fails
   */
  private static void evaluateGroups() throws Exception {
    System.out.println("Group recommendation strategy: " + group_method);
    // TODO GroupUtils.CreateGroupRecommender(group_method, recommender);
    // Strings must be compared with equals(): group_method is built via substring(),
    // so reference comparison with a literal never matches.
    GroupRecommender group_recommender = null;
    if ("Average".equals(group_method))
      group_recommender = new Average(recommender);
    else if ("Minimum".equals(group_method))
      group_recommender = new Minimum(recommender);
    else if ("Maximum".equals(group_method))
      group_recommender = new Maximum(recommender);
    else
      usage("Unknown method : --group-recommender=METHOD");
    ItemRecommendationEvaluationResults result = Groups.evaluate(group_recommender, test_data, training_data, group_to_user, candidate_items, true);
    System.out.println(result);
  }

  /**
   * Validates combinations of command-line options and recommender capabilities, terminating via
   * usage() on the first inconsistency.
   */
  static void checkParameters() {
    // TODO block group vs. filter/online, etc.
    if (training_file == null)
      usage("Parameter --training-file=FILE is missing.");
    if (online_eval && filtered_eval)
      usage("Combination of --online-evaluation and --filtered-evaluation is not (yet) supported.");
    if (online_eval && !(recommender instanceof IIncrementalItemRecommender))
      usage("Recommender " + recommender.getClass().getName() + " does not support incremental updates, which are necessary for an online experiment.");
    if (cross_validation == 1)
      usage("--cross-validation=K requires K to be at least 2.");
    if (show_fold_results && cross_validation == 0)
      usage("--show-fold-results only works with --cross-validation=K.");
    if (cross_validation > 1 && test_ratio != 0)
      usage("--cross-validation=K and --test-ratio=NUM are mutually exclusive.");
    if (cross_validation > 1 && prediction_file != null)
      usage("--cross-validation=K and --prediction-file=FILE are mutually exclusive.");
    if (test_file == null && test_ratio == 0 && cross_validation == 0 && save_model_file == null && test_users_file == null)
      usage("Please provide either --test-file=FILE, --test-ratio=NUM, --cross-validation=K, --save-model=FILE, or --test-users=FILE.");
    if ((candidate_items_file != null ? 1 : 0) + (all_items ? 1 : 0) + (in_training_items ? 1 : 0) + (in_test_items ? 1 : 0) + (overlap_items ? 1 : 0) > 1)
      usage("--candidate-items=FILE, --all-items, --in-training-items, --in-test-items, and --overlap-items are mutually exclusive.");
    if (test_file == null && test_ratio == 0 && cross_validation == 0 && overlap_items)
      usage("--overlap-items only makes sense with either --test-file=FILE, --test-ratio=NUM, or --cross-validation=K.");
    if (test_file == null && test_ratio == 0 && cross_validation == 0 && in_test_items)
      usage("--in-test-items only makes sense with either --test-file=FILE, --test-ratio=NUM, or --cross-validation=K.");
    if (test_file == null && test_ratio == 0 && cross_validation == 0 && in_training_items)
      usage("--in-training-items only makes sense with either --test-file=FILE, --test-ratio=NUM, or --cross-validation=K.");
    if (group_method != null && user_groups_file == null)
      usage("--group-recommender needs --user-groups=FILE.");

    if (user_prediction) {
      if (recommender instanceof IUserAttributeAwareRecommender || recommender instanceof IItemAttributeAwareRecommender ||
          recommender instanceof IUserRelationAwareRecommender || recommender instanceof IItemRelationAwareRecommender)
        usage("--user-prediction is not (yet) supported : combination with attribute- or relation-aware recommenders.");
      if (filtered_eval)
        usage("--user-prediction is not (yet) supported : combination with --filtered-evaluation.");
      if (user_groups_file != null)
        usage("--user-prediction is not (yet) supported : combination with --user-groups=FILE.");
    }

    if (recommender instanceof IUserAttributeAwareRecommender && user_attributes_file == null)
      usage("Recommender expects --user-attributes=FILE.");
    if (recommender instanceof IItemAttributeAwareRecommender && item_attributes_file == null)
      usage("Recommender expects --item-attributes=FILE.");
    if (filtered_eval && item_attributes_file == null)
      usage("--filtered-evaluation expects --item-attributes=FILE.");
    if (recommender instanceof IUserRelationAwareRecommender && user_relations_file == null)
      usage("Recommender expects --user-relations=FILE.");
    if (recommender instanceof IItemRelationAwareRecommender && item_relations_file == null)
      usage("Recommender expects --item-relations=FILE.");
  }

  /**
   * Reads training/test data, attributes, relations and user groups; optionally splits, transposes
   * and samples the data; and determines test users, candidate items and the candidate-item mode.
   *
   * @throws Exception if any input file cannot be read
   */
  static void loadData() throws Exception {
    long start = System.currentTimeMillis();
    boolean ignore_first_line = file_format == ItemDataFileFormat.IGNORE_FIRST_LINE;

    // training data
    training_file = Utils.combine(data_dir, training_file);
    training_data = Double.isNaN(rating_threshold)
        ? ItemData.read(training_file, user_mapping, item_mapping, ignore_first_line)
        : ItemDataRatingThreshold.read(training_file, rating_threshold, user_mapping, item_mapping, ignore_first_line);

    // User attributes
    if (user_attributes_file != null)
      user_attributes = AttributeData.read(Utils.combine(data_dir, user_attributes_file), user_mapping, attribute_mapping);
    if (recommender instanceof IUserAttributeAwareRecommender)
      ((IUserAttributeAwareRecommender) recommender).setUserAttributes(user_attributes);

    // Item attributes
    if (item_attributes_file != null)
      item_attributes = AttributeData.read(Utils.combine(data_dir, item_attributes_file), item_mapping, attribute_mapping);
    if (recommender instanceof IItemAttributeAwareRecommender)
      ((IItemAttributeAwareRecommender) recommender).setItemAttributes(item_attributes);

    // User relation
    if (recommender instanceof IUserRelationAwareRecommender) {
      ((IUserRelationAwareRecommender) recommender).setUserRelation(RelationData.read(Utils.combine(data_dir, user_relations_file), user_mapping));
      System.out.println("Relation over " + ((IUserRelationAwareRecommender) recommender).numUsers() + " users");
    }

    // Item relation
    if (recommender instanceof IItemRelationAwareRecommender) {
      ((IItemRelationAwareRecommender) recommender).setItemRelation(RelationData.read(Utils.combine(data_dir, item_relations_file), item_mapping));
      System.out.println("Relation over " + ((IItemRelationAwareRecommender) recommender).getNumItems() + " items");
    }

    // User groups
    if (user_groups_file != null) {
      group_to_user = RelationData.read(Utils.combine(data_dir, user_groups_file), user_mapping); // assumption: user and user group IDs are disjoint
      user_groups = group_to_user.nonEmptyRowIDs();
      System.out.println(user_groups.size() + " user groups");
    }

    // Test data
    if (test_ratio == 0) {
      if (test_file != null) {
        test_file = Utils.combine(data_dir, test_file);
        test_data = Double.isNaN(rating_threshold)
            ? ItemData.read(test_file, user_mapping, item_mapping, ignore_first_line)
            : ItemDataRatingThreshold.read(test_file, rating_threshold, user_mapping, item_mapping, ignore_first_line);
      }
    } else {
      // Ensure reproducible splitting
      if (random_seed != -1)
        org.mymedialite.util.Random.initInstance(random_seed);
      PosOnlyFeedback<SparseBooleanMatrix> train = new PosOnlyFeedback<SparseBooleanMatrix>(SparseBooleanMatrix.class);
      PosOnlyFeedback<SparseBooleanMatrix> test = new PosOnlyFeedback<SparseBooleanMatrix>(SparseBooleanMatrix.class);
      PosOnlyFeedbackSimpleSplit<PosOnlyFeedback<SparseBooleanMatrix>> split = new PosOnlyFeedbackSimpleSplit<PosOnlyFeedback<SparseBooleanMatrix>>(training_data, test_ratio, train, test);
      training_data = split.train().get(0);
      test_data = split.test().get(0);
    }

    // equals() is required: group_method comes from substring(), so == never matched.
    if ("GroupsAsUsers".equals(group_method)) {
      System.out.println("Group recommendation strategy: " + group_method);
      // TODO verify what is going on here
      // Transform groups to users: each group receives the training feedback of its members.
      for (int group_id : group_to_user.nonEmptyRowIDs())
        for (int user_id : group_to_user.get(group_id))
          for (int item_id : training_data.userMatrix().getEntriesByRow(user_id))
            training_data.add(group_id, item_id);

      // Transform groups to users in the test data, if there is any test data.
      if (test_data != null) {
        PosOnlyFeedback<SparseBooleanMatrix> test_data_group = new PosOnlyFeedback<SparseBooleanMatrix>(SparseBooleanMatrix.class);
        for (int group_id : group_to_user.nonEmptyRowIDs())
          for (int user_id : group_to_user.get(group_id))
            for (int item_id : test_data.userMatrix().getEntriesByRow(user_id))
              test_data_group.add(group_id, item_id);
        test_data = test_data_group;
      }
      group_method = null; // deactivate s.t. the normal eval routines are used
    }

    if (user_prediction) {
      // swap file names for test users and candidate items
      String ruf = test_users_file;
      String rif = candidate_items_file;
      test_users_file = rif;
      candidate_items_file = ruf;

      // Swap user and item mappings
      IEntityMapping um = user_mapping;
      IEntityMapping im = item_mapping;
      user_mapping = im;
      item_mapping = um;

      // Transpose training and test data
      training_data = training_data.transpose();
      if (test_data != null) test_data = test_data.transpose();
    }

    if (recommender instanceof org.mymedialite.itemrec.ItemRecommender)
      ((ItemRecommender) recommender).setFeedback(training_data);

    // Test users
    if (test_users_file != null)
      test_users = new ArrayList<Integer>(user_mapping.toInternalID(NumberFile.readStrings(Utils.combine(data_dir, test_users_file))));
    else
      test_users = test_data != null ? test_data.allUsers() : training_data.allUsers();

    // If necessary, perform user sampling (without replacement)
    if (num_test_users > 0 && num_test_users < test_users.size()) {
      // Ensure reproducible splitting
      if (random_seed != -1)
        org.mymedialite.util.Random.initInstance(random_seed);
      // IntArraySet removes duplicate IDs while keeping their order; the list allows O(1) indexing.
      List<Integer> pool = new ArrayList<Integer>(new IntArraySet(test_users));
      List<Integer> new_test_users = new ArrayList<Integer>(num_test_users);
      while (new_test_users.size() < num_test_users && !pool.isEmpty()) {
        // nextInt(n) is assumed to follow java.util.Random (upper bound exclusive), so every
        // remaining user can be drawn; the former size()-1 bound excluded the last one.
        int random_index = org.mymedialite.util.Random.getInstance().nextInt(pool.size());
        new_test_users.add(pool.remove(random_index));
      }
      test_users = new_test_users;
    }

    // Candidate items
    if (candidate_items_file != null) {
      candidate_items = new ArrayList<Integer>(item_mapping.toInternalID(NumberFile.readStrings(Utils.combine(data_dir, candidate_items_file))));
    } else if (all_items) {
      candidate_items = new ArrayList<Integer>();
      for (int id : item_mapping.internalIDs()) candidate_items.add(id);
    }

    if (candidate_items != null)
      eval_item_mode = CandidateItems.EXPLICIT;
    else if (in_training_items)
      eval_item_mode = CandidateItems.TRAINING;
    else if (in_test_items)
      eval_item_mode = CandidateItems.TEST;
    else if (overlap_items)
      eval_item_mode = CandidateItems.OVERLAP;
    else
      eval_item_mode = CandidateItems.UNION;

    System.out.println("Loading time: " + secondsSince(start) + " seconds");
    System.out.println("Memory usage: " + Memory.getUsage() + " MB");
  }

  /**
   * Evaluates the recommender on its own training data (training data used as test set).
   *
   * @return the evaluation results
   * @throws Exception if the evaluation fails
   */
  static ItemRecommendationEvaluationResults computeFit() throws Exception {
    if (filtered_eval)
      return ItemsFiltered.evaluateFiltered(recommender, training_data, training_data, item_attributes, test_users, candidate_items, true);
    else
      return Items.evaluate(recommender, training_data, training_data, test_users, candidate_items, eval_item_mode, true);
  }

  /**
   * Evaluates the recommender on the test data, filtered by item attributes if requested.
   *
   * @return the evaluation results
   * @throws Exception if the evaluation fails
   */
  static ItemRecommendationEvaluationResults evaluate() throws Exception {
    if (filtered_eval)
      return ItemsFiltered.evaluateFiltered(recommender, test_data, training_data, item_attributes, test_users, candidate_items, repeat_eval);
    else
      return Items.evaluate(recommender, test_data, training_data, test_users, candidate_items, eval_item_mode, repeat_eval);
  }

  /**
   * Writes predictions to {@code prediction_file + "-it-" + iteration}; does nothing when no
   * prediction file was requested.
   *
   * @param prediction_file base name of the prediction file, may be null
   * @param predict_for_users_file file listing the users to predict for, may be null
   * @param iteration current iteration number, appended to the file name
   * @throws IOException if reading the user file or writing predictions fails
   */
  static void predict(String prediction_file, String predict_for_users_file, int iteration) throws IOException {
    if (prediction_file == null) return;
    predict(prediction_file + "-it-" + iteration, predict_for_users_file);
  }

  /**
   * Writes ranked predictions for the given users (or, if no user file is given, for whatever
   * default user set Extensions.writePredictions uses when passed null) to a file.
   *
   * @param prediction_file the output file
   * @param predict_for_users_file file listing the users to predict for, may be null
   * @throws IOException if reading the user file or writing predictions fails
   */
  static void predict(String prediction_file, String predict_for_users_file) throws IOException {
    // Use a local fallback instead of overwriting the global candidate_items, so that writing
    // predictions does not change what later evaluations (e.g. in --find-iter) receive.
    List<Integer> items = candidate_items;
    if (items == null)
      items = training_data.allItems();

    List<Integer> user_list = null;
    if (predict_for_users_file != null)
      user_list = user_mapping.toInternalID(NumberFile.readStrings(predict_for_users_file));

    long start = System.currentTimeMillis();
    Extensions.writePredictions(
        recommender,
        training_data,
        items, predict_items_number,
        prediction_file, user_list,
        user_mapping, item_mapping);
    System.err.println("Wrote predictions to " + prediction_file);
    System.out.println("Prediction_time " + (System.currentTimeMillis() - start) + " milliseconds");
  }

  /** Prints min/max/average timing statistics collected so far, plus current memory usage. */
  static void displayStats() {
    printTimeStats("Iteration", training_time_stats);
    printTimeStats("Evaluation", eval_time_stats);
    if (compute_fit)
      printTimeStats("Fit", fit_time_stats);
    System.out.println("Memory usage: " + Memory.getUsage() + " MB");
  }

  /**
   * Prints "LABEL time: min=…, max=…, avg=… seconds" for a non-empty list of timings.
   *
   * @param label prefix of the output line
   * @param stats timings in seconds; nothing is printed when empty
   */
  private static void printTimeStats(String label, List<Double> stats) {
    if (stats.isEmpty())
      return;
    double max = Collections.max(stats);
    double min = Collections.min(stats);
    double avg = Utils.average(stats);
    System.out.println(label + " time: min=" + min + ", max=" + max + ", avg=" + avg + " seconds");
  }

  /**
   * Returns the wall-clock time elapsed since a {@link System#currentTimeMillis()} timestamp.
   *
   * @param start_millis the start timestamp in milliseconds
   * @return the elapsed time in seconds
   */
  private static double secondsSince(long start_millis) {
    return (System.currentTimeMillis() - start_millis) / 1000.0;
  }
}