// Copyright (C) 2010 Zeno Gantner, Christoph Freudenthaler
// Copyright (C) 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite.  If not, see <http://www.gnu.org/licenses/>.

package org.mymedialite.itemrec;

import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.io.*;
import java.util.*;
import org.mymedialite.datatype.*;
import org.mymedialite.io.IMatrixExtensions;
import org.mymedialite.io.Model;
import org.mymedialite.util.*;
import org.mymedialite.util.Random;

/**
 * Matrix factorization model for item prediction (ranking), optimized with BPR.
 *
 * BPR reduces ranking to pairwise classification.
 *
 * Literature:
 * Steffen Rendle, Christoph Freudenthaler, Zeno Gantner, Lars Schmidt-Thieme:
 * BPR: Bayesian Personalized Ranking from Implicit Feedback.
 * UAI 2009.
 * http://www.ismll.uni-hildesheim.de/pub/pdfs/Rendle_et_al2009-Bayesian_Personalized_Ranking.pdf
 *
 * Different sampling strategies can be configured via uniformUserSampling and
 * withReplacement. To get the strategy from the original paper, set
 * uniformUserSampling=false and withReplacement=false. Sampling with
 * replacement usually gives slightly faster convergence; uniform user sampling
 * (the default) approximately optimizes the average AUC over all users.
 *
 * This recommender supports incremental updates.
 * @version 2.03
 */
public class BPRMF extends MF {

  private static final String VERSION = "2.03";

  /** Fast, but memory-intensive sampling */
  protected boolean fastSampling = false;

  /** Item bias terms */
  protected double[] itemBias;

  /**
   * Fast sampling memory limit, in MiB.
   *
   * TODO find out why fast sampling does not improve performance
   */
  //public int fastSamplingMemoryLimit = 1200;
  public int fastSamplingMemoryLimit = 0;

  /** Sample positive observations with (true) or without (false) replacement */
  public boolean withReplacement = false;

  /** Sample uniformly from users */
  public boolean uniformUserSampling;

  /** Regularization parameter for the bias terms */
  public double biasReg;

  /** Learning rate alpha */
  public double learnRate = 0.05;

  /** Regularization parameter for user factors */
  public double regU = 0.0025;

  /** Regularization parameter for positive item factors */
  public double regI = 0.0025;

  /** Regularization parameter for negative item factors */
  public double regJ = 0.00025;

  /** If set (default), update factors for sampled negative items during learning */
  protected boolean updateJ = true;

  /** Support data structure for fast sampling: positive items per user */
  protected ArrayList<int[]> userPosItems;

  /** Support data structure for fast sampling: negative items per user */
  protected ArrayList<int[]> userNegItems;

  /**
   * Use bold driver heuristics for learning rate adaptation.
   * See
   * Rainer Gemulla, Peter J. Haas, Erik Nijkamp, Yannis Sismanis:
   * Large-Scale Matrix Factorization with Distributed Stochastic Gradient Descent,
   * KDD 2011.
   */
  public boolean boldDriver = false;

  /** Loss of the last iteration, used by the bold driver heuristics */
  double lastLoss = Double.NEGATIVE_INFINITY;

  /** User components of the triples used for approximate loss computation */
  int[] lossSampleU;

  /** Positive item components of the triples used for approximate loss computation */
  int[] lossSampleI;

  /** Negative item components of the triples used for approximate loss computation */
  int[] lossSampleJ;

  /** Random number generator */
  protected org.mymedialite.util.Random random = Random.getInstance();

  /** Default constructor */
  public BPRMF() {
    uniformUserSampling = true;
  }

  protected void initModel() {
    super.initModel();
    itemBias = new double[maxItemID + 1];
  }
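
  /*
   * Usage sketch (illustrative only, not from the original source; names
   * declared in this class are real, but setters/fields inherited from the
   * surrounding MyMediaLite API, such as setFeedback() and numIter, are
   * assumptions):
   *
   *   BPRMF recommender = new BPRMF();
   *   recommender.setFeedback(feedback);        // implicit, positive-only feedback (assumed setter)
   *   recommender.numIter = 30;                 // assumed to be accessible from MF
   *   recommender.uniformUserSampling = true;   // (approximately) optimize average per-user AUC
   *   recommender.withReplacement = false;      // sampling as in the original paper
   *   recommender.train();
   *   double score = recommender.predict(user_id, item_id);
   */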

  public void train() {
    initModel();
    checkSampling();
    random = Random.getInstance();

    if (boldDriver) {
      int num_sample_triples = (int) Math.sqrt(maxUserID) * 100;
      System.err.println("loss_num_sample_triples=" + num_sample_triples);

      // Create the sample of triples used to estimate the loss.
      lossSampleU = new int[num_sample_triples];
      lossSampleI = new int[num_sample_triples];
      lossSampleJ = new int[num_sample_triples];
      for (int c = 0; c < num_sample_triples; c++) {
        SampleTriple triple = sampleTriple();
        lossSampleU[c] = triple.u;
        lossSampleI[c] = triple.i;
        lossSampleJ[c] = triple.j;
      }

      lastLoss = computeLoss();
    }

    for (int i = 0; i < numIter; i++)
      iterate();
  }
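
  /*
   * Sampling-strategy overview (a summary of the four cases implemented in
   * iterate() below):
   *
   *   uniformUserSampling=true,  withReplacement=true   -> case 1
   *   uniformUserSampling=true,  withReplacement=false  -> case 2
   *   uniformUserSampling=false, withReplacement=true   -> case 3
   *   uniformUserSampling=false, withReplacement=false  -> case 4 (original BPR paper)
   */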

  /**
   * Perform one iteration of stochastic gradient ascent over the training data.
   * One iteration performs one update step per positive entry in the training matrix.
   */
  public void iterate() {
    int num_pos_events = feedback.size();
    int user_id, pos_item_id, neg_item_id;

    if (uniformUserSampling) {
      if (withReplacement) {
        // Case 1: uniform user sampling, with replacement
        IBooleanMatrix user_matrix = feedback.getUserMatrixCopy();

        for (int i = 0; i < num_pos_events; i++) {
          user_id = sampleUser();
          IntCollection user_items = user_matrix.get(user_id);

          // Reset the user if already exhausted.
          if (user_items.size() == 0)
            for (int item_id : feedback.userMatrix().get(user_id))
              user_matrix.set(user_id, item_id, true);

          pos_item_id = user_items.toIntArray()[random.nextInt(user_items.size())];
          user_matrix.set(user_id, pos_item_id, false); // temporarily forget the positive observation
          do
            neg_item_id = random.nextInt(maxItemID + 1);
          while (feedback.userMatrix().get(user_id).contains(neg_item_id));

          SampleTriple triple = new SampleTriple(user_id, pos_item_id, neg_item_id);
          updateFactors(triple, true, true, updateJ);
        }
      } else {
        // Case 2: uniform user sampling, without replacement
        for (int i = 0; i < num_pos_events; i++) {
          SampleTriple triple = sampleTriple();
          updateFactors(triple, true, true, updateJ);
        }
      }
    } else {
      if (withReplacement) {
        // Case 3: uniform pair sampling, with replacement
        for (int i = 0; i < num_pos_events; i++) {
          int index = random.nextInt(num_pos_events);
          user_id = feedback.users().get(index);
          pos_item_id = feedback.items().get(index);
          neg_item_id = -1;
          SampleTriple triple = new SampleTriple(user_id, pos_item_id, neg_item_id);
          sampleOtherItem(triple);
          updateFactors(triple, true, true, updateJ);
        }
      } else {
        // Case 4: uniform pair sampling, without replacement
        for (int index : feedback.randomIndex()) {
          user_id = feedback.users().get(index);
          pos_item_id = feedback.items().get(index);
          neg_item_id = -1;
          SampleTriple triple = new SampleTriple(user_id, pos_item_id, neg_item_id);
          sampleOtherItem(triple);
          updateFactors(triple, true, true, updateJ);
        }
      }
    }

    if (boldDriver) {
      double loss = computeLoss();

      if (loss > lastLoss)
        learnRate *= 0.5;
      else if (loss < lastLoss)
        learnRate *= 1.1;

      lastLoss = loss;
      System.err.println("loss: " + loss + " learnRate: " + learnRate);
    }
  }

  /**
   * Sample another item, given the first one and the user.
   * @param triple a SampleTriple consisting of a user ID and two item IDs
   * @return true if the given item was already seen by the user
   */
  protected boolean sampleOtherItem(SampleTriple triple) {
    boolean itemIsPositive = feedback.userMatrix().get(triple.u, triple.i);

    if (fastSampling) {
      if (itemIsPositive) {
        int rindex = random.nextInt(userNegItems.get(triple.u).length);
        triple.j = userNegItems.get(triple.u)[rindex];
      } else {
        int rindex = random.nextInt(userPosItems.get(triple.u).length);
        triple.j = userPosItems.get(triple.u)[rindex];
      }
    } else {
      // Keep sampling while j falls into the same class (positive/negative) as i.
      do
        triple.j = random.nextInt(maxItemID + 1);
      while (feedback.userMatrix().get(triple.u, triple.j) == itemIsPositive);
    }
    return itemIsPositive;
  }

  /**
   * Sample a pair of items, given a user.
   * @param triple a SampleTriple consisting of a user ID and two item IDs
   */
  protected void sampleItemPair(SampleTriple triple) {
    if (fastSampling) {
      int rindex = random.nextInt(userPosItems.get(triple.u).length);
      triple.i = userPosItems.get(triple.u)[rindex];
      rindex = random.nextInt(userNegItems.get(triple.u).length);
      triple.j = userNegItems.get(triple.u)[rindex];
    } else {
      IntCollection user_items = feedback.userMatrix().get(triple.u);
      triple.i = user_items.toIntArray()[random.nextInt(user_items.size())];
      do
        triple.j = random.nextInt(maxItemID + 1);
      while (user_items.contains(triple.j));
    }
  }

  /**
   * Sample a user that has viewed at least one and not all items.
   * @return the user ID
   */
  protected int sampleUser() {
    while (true) {
      int u = random.nextInt(maxUserID + 1);
      IntCollection user_items = feedback.userMatrix().get(u);
      if (user_items.size() == 0 || user_items.size() == maxItemID + 1)
        continue;
      return u;
    }
  }

  /**
   * Sample a triple for BPR learning.
   * @return a SampleTriple consisting of a user ID and two item IDs
   */
  protected SampleTriple sampleTriple() {
    SampleTriple triple = new SampleTriple();
    triple.u = sampleUser();
    sampleItemPair(triple);
    return triple;
  }
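
  /*
   * Worked update rule (a sketch following Rendle et al. 2009, matching what
   * updateFactors() below implements): for a sampled triple (u, i, j), BPR
   * maximizes ln sigma(x_uij) minus regularization, where
   *
   *   x_uij = b_i - b_j + <w_u, h_i - h_j>
   *
   * and sigma is the logistic function. Each SGD step shares the common
   * multiplier sigma(-x_uij) = 1 / (1 + e^{x_uij}):
   *
   *   b_i  += learnRate * (  sigma(-x_uij)                - biasReg * b_i  )
   *   b_j  += learnRate * ( -sigma(-x_uij)                - biasReg * b_j  )
   *   w_uf += learnRate * ( (h_if - h_jf) * sigma(-x_uij) - regU * w_uf )
   *   h_if += learnRate * (  w_uf * sigma(-x_uij)         - regI * h_if )
   *   h_jf += learnRate * ( -w_uf * sigma(-x_uij)         - regJ * h_jf )
   */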

  /**
   * Update the factors according to the stochastic gradient descent update rule.
   * @param t a SampleTriple specifying the user ID and the first and second item IDs
   * @param updateU if true, update the user factors
   * @param updateI if true, update the factors of the first item
   * @param updateJ if true, update the factors of the second item
   */
  protected void updateFactors(SampleTriple t, boolean updateU, boolean updateI, boolean updateJ) {
    double x_uij = itemBias[t.i] - itemBias[t.j]
        + MatrixExtensions.rowScalarProductWithRowDifference(userFactors, t.u, itemFactors, t.i, itemFactors, t.j);
    double one_over_one_plus_ex = 1 / (1 + Math.exp(x_uij));

    // Adjust the bias terms.
    if (updateI) {
      double bias_update = one_over_one_plus_ex - biasReg * itemBias[t.i];
      itemBias[t.i] += learnRate * bias_update;
    }
    if (updateJ) {
      double bias_update = -one_over_one_plus_ex - biasReg * itemBias[t.j];
      itemBias[t.j] += learnRate * bias_update;
    }

    // Adjust the latent factors.
    for (int f = 0; f < numFactors; f++) {
      double w_uf = userFactors.get(t.u, f);
      double h_if = itemFactors.get(t.i, f);
      double h_jf = itemFactors.get(t.j, f);

      if (updateU) {
        double uf_update = (h_if - h_jf) * one_over_one_plus_ex - regU * w_uf;
        userFactors.set(t.u, f, w_uf + learnRate * uf_update);
      }
      if (updateI) {
        double if_update = w_uf * one_over_one_plus_ex - regI * h_if;
        itemFactors.set(t.i, f, h_if + learnRate * if_update);
      }
      if (updateJ) {
        double jf_update = -w_uf * one_over_one_plus_ex - regJ * h_jf;
        itemFactors.set(t.j, f, h_jf + learnRate * jf_update);
      }
    }
  }

  @Override
  public void addFeedback(int user_id, int item_id) {
    super.addFeedback(user_id, item_id);
    if (fastSampling)
      createFastSamplingData(user_id);
    // retrain
    retrainUser(user_id);
    // TODO uncomment below or implement isUpdateUser flag.
    //retrainItem(item_id);
  }

  @Override
  public void addFeedback(int user_id, List<Integer> item_ids) {
    super.addFeedback(user_id, item_ids);
    if (fastSampling)
      createFastSamplingData(user_id);
    // retrain
    retrainUser(user_id);
    // TODO uncomment below or implement isUpdateUser flag.
    //retrainItem(item_id);
  }

  public void removeFeedback(int user_id, int item_id) {
    super.removeFeedback(user_id, item_id);
    if (fastSampling)
      createFastSamplingData(user_id);
    // retrain
    retrainUser(user_id);
    // TODO uncomment below or implement isUpdateUser flag.
    //retrainItem(item_id);
  }

  public void addUser(int user_id) {
    super.addUser(user_id);
    userFactors.addRows(user_id + 1);
    MatrixExtensions.rowInitNormal(userFactors, user_id, initMean, initStDev);
  }

  public void addItem(int item_id) {
    super.addItem(item_id);
    itemFactors.addRows(item_id + 1);
    MatrixExtensions.rowInitNormal(itemFactors, item_id, initMean, initStDev);
    // Grow the item bias array.
    this.itemBias = Arrays.copyOf(this.itemBias, item_id + 1);
  }

  public void removeUser(int user_id) {
    super.removeUser(user_id);
    if (fastSampling) {
      userPosItems.set(user_id, null);
      userNegItems.set(user_id, null);
    }
    // Set the user's latent factors to zero.
    userFactors.setRowToOneValue(user_id, 0.0);
  }

  public void removeItem(int item_id) {
    super.removeItem(item_id);
    // TODO remove from fast sampling data structures
    // (however, this is not needed if all feedback events have been removed properly before)
    // Set the item's latent factors to zero.
    itemFactors.setRowToOneValue(item_id, 0.0);
  }

  /**
   * Retrain the latent factors of a given user.
   * @param user_id the user ID
   */
  protected void retrainUser(int user_id) {
    MatrixExtensions.rowInitNormal(userFactors, user_id, initMean, initStDev);
    IntCollection user_items = feedback.userMatrix().get(user_id);

    // Debug output: report null entries in the factor matrices.
    for (int i = 0; i < userFactors.data.length; i++)
      if (userFactors.data[i] == null)
        System.err.println("uf: " + i);
    for (int i = 0; i < itemFactors.data.length; i++)
      if (itemFactors.data[i] == null) {
        System.err.print("if: " + i);
        System.err.print(" x " + (i / itemFactors.dim1));
        System.err.println(" y " + (i % itemFactors.dim1));
      }
    System.err.flush();

    for (int i = 0; i < user_items.size(); i++) {
      SampleTriple triple = new SampleTriple();
      triple.u = user_id;
      sampleItemPair(triple);
      updateFactors(triple, true, false, false);
    }
  }

  /**
   * Retrain the latent factors of a given item.
   * @param item_id the item ID
   */
  protected void retrainItem(int item_id) {
    MatrixExtensions.rowInitNormal(itemFactors, item_id, initMean, initStDev);

    int num_pos_events = feedback.userMatrix().numberOfEntries();
    int num_item_iterations = num_pos_events / (maxItemID + 1);
    for (int i = 0; i < num_item_iterations; i++) {
      // remark: the item may be updated more or less frequently than in normal from-scratch training
      SampleTriple triple = new SampleTriple();
      triple.u = sampleUser();
      triple.i = item_id;
      boolean item_is_positive = sampleOtherItem(triple);

      if (item_is_positive) {
        // The item is positive for this user, so it keeps the i role.
        updateFactors(triple, false, true, false);
      } else {
        // The item is negative: swap so that it takes the j role, then update j.
        int j = triple.j;
        triple.j = triple.i;
        triple.i = j;
        updateFactors(triple, false, false, true);
      }
    }
  }
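
  /*
   * Note on the approximate loss used by the bold driver (describing
   * computeLoss() below): the ranking part sums sigma(-x_uij), the estimated
   * probability of misranking each sampled triple, rather than the exact
   * negated BPR-Opt term -ln sigma(x_uij). Both quantities shrink as the
   * ranking improves, which is all the bold driver heuristic needs.
   */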

  /**
   * Compute the approximate loss.
   * @return the approximate loss
   */
  public double computeLoss() {
    double ranking_loss = 0;
    for (int c = 0; c < lossSampleU.length; c++) {
      double x_uij = predict(lossSampleU[c], lossSampleI[c]) - predict(lossSampleU[c], lossSampleJ[c]);
      ranking_loss += 1 / (1 + Math.exp(x_uij));
    }

    double complexity = 0;
    for (int c = 0; c < lossSampleU.length; c++) {
      complexity += regU * Math.pow(VectorExtensions.euclideanNorm(userFactors.getRow(lossSampleU[c])), 2);
      complexity += regI * Math.pow(VectorExtensions.euclideanNorm(itemFactors.getRow(lossSampleI[c])), 2);
      complexity += regJ * Math.pow(VectorExtensions.euclideanNorm(itemFactors.getRow(lossSampleJ[c])), 2);
      complexity += biasReg * Math.pow(itemBias[lossSampleI[c]], 2);
      complexity += biasReg * Math.pow(itemBias[lossSampleJ[c]], 2);
    }

    return ranking_loss + 0.5 * complexity;
  }

  /**
   * Compute the fit (AUC on the training data).
   * @return the fit
   */
  public double computeFit() {
    double sum_auc = 0;
    int num_user = 0;

    for (int user_id = 0; user_id < maxUserID + 1; user_id++) {
      int num_test_items = feedback.userMatrix().get(user_id).size();
      if (num_test_items == 0)
        continue;

      List<Integer> prediction = Extensions.predictItems(this, user_id, maxItemID);

      int num_eval_items = maxItemID + 1;
      int num_eval_pairs = (num_eval_items - num_test_items) * num_test_items;
      int num_correct_pairs = 0;
      int num_pos_above = 0;

      // Start with the highest-ranked item...
      for (int i = 0; i < prediction.size(); i++) {
        int item_id = prediction.get(i);
        if (feedback.userMatrix().get(user_id, item_id))
          num_pos_above++;
        else
          num_correct_pairs += num_pos_above;
      }
      double user_auc = ((double) num_correct_pairs) / num_eval_pairs;
      sum_auc += user_auc;
      num_user++;
    }

    return sum_auc / num_user;
  }

  protected void createFastSamplingData(int u) {
    while (u >= userPosItems.size())
      userPosItems.add(null);
    while (u >= userNegItems.size())
      userNegItems.add(null);

    userPosItems.set(u, feedback.userMatrix().get(u).toIntArray());

    IntSet neg_list = new IntArraySet();
    for (int i = 0; i <= maxItemID; i++)   // include the highest item ID
      if (!feedback.userMatrix().get(u).contains(i))
        neg_list.add(i);
    userNegItems.set(u, neg_list.toIntArray());
  }

  protected void checkSampling() {
    try {
      // Use long arithmetic so the estimate does not overflow for large datasets.
      long fast_sampling_memory_size = ((long) (maxUserID + 1) * (maxItemID + 1) * 4) / (1024 * 1024);
      System.out.println("fast_sampling_memory_size=" + fast_sampling_memory_size);

      if (fast_sampling_memory_size <= fastSamplingMemoryLimit) {
        fastSampling = true;
        this.userPosItems = new ArrayList<int[]>(maxUserID + 1);
        this.userNegItems = new ArrayList<int[]>(maxUserID + 1);
        for (int u = 0; u < maxUserID + 1; u++)
          createFastSamplingData(u);
      }
    } catch (Exception e) {
      System.out.println("fast_sampling_memory_size=TOO_MUCH");
      // Do nothing - don't use fast sampling.
    }
  }

  public double predict(int user_id, int item_id) {
    if (item_id >= itemBias.length)
      throw new IllegalArgumentException("item_id is too big: " + item_id + ", itemBias.length: " + itemBias.length);
    return itemBias[item_id] + MatrixExtensions.rowScalarProduct(userFactors, user_id, itemFactors, item_id);
  }
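
  /*
   * Round-trip sketch (illustrative): saveModel(PrintWriter) below writes the
   * userFactors matrix, then the itemBias vector, then the itemFactors matrix,
   * presumably after a class/version header emitted by Model.getWriter()
   * (an assumption about that helper). A save/load round trip:
   *
   *   recommender.saveModel("bprmf.model");
   *   BPRMF loaded = new BPRMF();
   *   loaded.loadModel("bprmf.model");
   */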

  public void saveModel(String filename) throws IOException {
    PrintWriter writer = Model.getWriter(filename, this.getClass(), VERSION);
    saveModel(writer);
    writer.flush();
    writer.close();
  }

  public void saveModel(PrintWriter writer) {
    IMatrixExtensions.writeMatrix(writer, userFactors);
    org.mymedialite.io.VectorExtensions.writeVectorArray(writer, itemBias);
    IMatrixExtensions.writeMatrix(writer, itemFactors);
  }

  @Override
  public void loadModel(String filename) throws IOException {
    BufferedReader reader = Model.getReader(filename, this.getClass());
    loadModel(reader);
    reader.close();
  }

  @Override
  public void loadModel(BufferedReader reader) throws IOException {
    Matrix<Double> user_factors = (Matrix<Double>) IMatrixExtensions.readDoubleMatrix(reader, new Matrix<Double>(0, 0, null));
    double[] item_bias = org.mymedialite.io.VectorExtensions.readVectorArray(reader);
    Matrix<Double> item_factors = (Matrix<Double>) IMatrixExtensions.readDoubleMatrix(reader, new Matrix<Double>(0, 0, null));

    if (user_factors.numberOfColumns() != item_factors.numberOfColumns())
      throw new IOException("Number of user and item factors must match: "
          + user_factors.numberOfColumns() + " != " + item_factors.numberOfColumns());
    if (item_bias.length != item_factors.dim1)
      throw new IOException("Number of items must be the same for biases and factors: "
          + item_bias.length + " != " + item_factors.dim1);

    this.maxUserID = user_factors.numberOfRows() - 1;
    this.maxItemID = item_factors.numberOfRows() - 1;
    System.out.println("Read BPR model users: " + user_factors.numberOfRows()
        + " items: " + item_factors.numberOfRows()
        + " factors: " + user_factors.numberOfColumns()
        + " item_biases: " + item_bias.length);

    // Assign the new model.
    if (this.numFactors != user_factors.numberOfColumns()) {
      System.err.println("Set num_factors to " + user_factors.numberOfColumns());
      this.numFactors = user_factors.numberOfColumns();
    }
    this.userFactors = user_factors;
    this.itemBias = item_bias;
    this.itemFactors = item_factors;
    random = Random.getInstance();
  }

  public String toString() {
    return this.getClass().getName()
        + " numFactors=" + numFactors
        + " biasReg=" + biasReg
        + " regU=" + regU
        + " regI=" + regI
        + " regJ=" + regJ
        + " numIter=" + numIter
        + " learnRate=" + learnRate
        + " boldDriver=" + boldDriver
        + " fastSamplingMemoryLimit=" + fastSamplingMemoryLimit
        + " initMean=" + initMean
        + " initStDev=" + initStDev;
  }

  /** A (user, positive item, negative item) triple. */
  private class SampleTriple {
    int u;  // user ID
    int i;  // ID of the positive item
    int j;  // ID of the negative item

    SampleTriple() {}

    SampleTriple(int u, int i, int j) {
      this.u = u;
      this.i = i;
      this.j = j;
    }
  }
}