// Copyright (C) 2010 Zeno Gantner, Andreas Hoffmann // Copyright (C) 2011 Zeno Gantner, Chris Newell // // This file is part of MyMediaLite. // // MyMediaLite is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // MyMediaLite is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.ratingprediction; import it.unimi.dsi.fastutil.doubles.DoubleList; import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; import java.util.Arrays; import org.mymedialite.IIterativeModel; import org.mymedialite.io.VectorExtensions; import org.mymedialite.eval.Ratings; import org.mymedialite.io.Model; /** * Baseline method for rating prediction * * Uses the average rating value, plus a regularized user and item bias for prediction. * * The method is described in section 2.1 of * Yehuda Koren: Factor in the Neighbors: Scalable and Accurate Collaborative Filtering, * Transactions on Knowledge Discovery from Data (TKDD), 2009. * * One difference is that we support several iterations of alternating optimization, instead of just one. * * This recommender supports incremental updates. * * @author Zeno Gantner, Andreas Hoffmann * @version 2.03 */ public class UserItemBaseline extends IncrementalRatingPredictor implements IIterativeModel { private static final String VERSION = "2.03"; /** Regularization parameter for the user biases */ public double regU; /** Regularization parameter for the item biases */ public double regI; /** The number of iterations */ public int numIter; /** The global rating average */ protected double globalAverage; /** The user biases */ protected double userBiases[]; /** The item biases */ protected double itemBiases[]; /** Default constructor */ public UserItemBaseline() { super(); regU = 15; regI = 10; numIter = 10; } /** * @return The number of iterations */ public int getNumIter() { return numIter; } /** * @param numIter The number of iterations */ public void setNumIter(int numIter) { this.numIter = numIter; } public void iterate() { optimizeItemBiases(); optimizeUserBiases(); } void optimizeUserBiases() { int[] userRatingsCount = new int[maxUserID + 1]; for (int u = 0; u <= maxUserID; u++) userBiases[u] = 0; for (int index = 0; index < ratings.size(); index++) { userBiases[ratings.users().get(index)] += ratings.get(index) - globalAverage - itemBiases[ratings.items().get(index)]; userRatingsCount[ratings.users().get(index)]++; } for (int u = 0; u < userBiases.length; u++) if (userRatingsCount[u] != 0) userBiases[u] = userBiases[u] / (regU + userRatingsCount[u]); } void optimizeItemBiases() { int[] item_ratings_count = new int[maxItemID + 1]; for (int i = 0; i <= maxItemID; i++) itemBiases[i] = 0; for (int index = 0; index < ratings.size(); index++) { itemBiases[ratings.items().get(index)] += ratings.get(index) - globalAverage - userBiases[ratings.users().get(index)]; item_ratings_count[ratings.items().get(index)]++; } for (int i = 0; i < itemBiases.length; i++) if (item_ratings_count[i] != 0) itemBiases[i] = itemBiases[i] / (regI + item_ratings_count[i]); } @Override public double predict(int userID, int itemID) { double user_bias = (userID <= maxUserID && userID >= 0) ? userBiases[userID] : 0; double item_bias = (itemID <= maxItemID && itemID >= 0) ? itemBiases[itemID] : 0; double result = globalAverage + user_bias + item_bias; if (result > maxRating) result = maxRating; if (result < minRating) result = minRating; return result; } public void train() { userBiases = new double[maxUserID + 1]; itemBiases = new double[maxItemID + 1]; globalAverage = ratings.average(); for (int i = 0; i < numIter; i++) iterate(); } protected void retrainUser(int userID) { if (getUpdateUsers()) { for (int index : ratings.byUser().get(userID)) userBiases[userID] += ratings.get(index) - globalAverage - itemBiases[ratings.items().get(index)]; if (ratings.byUser().get(userID).size() != 0) userBiases[userID] = userBiases[userID] / (regU + ratings.byUser().get(userID).size()); } } protected void retrainItem(int itemID) { if (getUpdateItems()) { for (int index : ratings.byItem().get(itemID)) itemBiases[itemID] += ratings.get(index) - globalAverage; if (ratings.byItem().get(itemID).size() != 0) itemBiases[itemID] = itemBiases[itemID] / (regI + ratings.byItem().get(itemID).size()); } } @Override public void addRating(int userID, int itemID, double rating) { super.addRating(userID, itemID, rating); this.retrainItem(itemID); this.retrainUser(userID); } @Override public void updateRating(int userID, int itemID, double rating) { super.updateRating(userID, itemID, rating); this.retrainItem(itemID); this.retrainUser(userID); } @Override public void removeRating(int userID, int itemID) { super.removeRating(userID, itemID); this.retrainItem(itemID); this.retrainUser(userID); } @Override public void addUser(int userID) { super.addUser(userID); double[] userBiases = new double[this.maxUserID + 1]; userBiases = Arrays.copyOf(this.userBiases, this.userBiases.length); this.userBiases = userBiases; } @Override public void addItem(int itemID) { super.addItem(itemID); double[] itemBiases = new double[this.maxItemID + 1]; itemBiases = Arrays.copyOf(this.itemBiases, this.itemBiases.length); this.itemBiases = itemBiases; } @Override public void saveModel(String filename) throws IOException { PrintWriter writer = Model.getWriter(filename, this.getClass(), VERSION); saveModel(writer); writer.flush(); writer.close(); } @Override public void saveModel(PrintWriter writer) throws IOException { writer.println(globalAverage); VectorExtensions.writeVectorArray(writer, userBiases); VectorExtensions.writeVectorArray(writer, itemBiases); } @Override public void loadModel(String filename) throws IOException { BufferedReader reader = Model.getReader(filename, this.getClass()); loadModel(reader); reader.close(); } @Override public void loadModel(BufferedReader reader) throws IOException { double globalAverage = Double.parseDouble(reader.readLine()); double[] userBiases = VectorExtensions.readVectorArray(reader); double[] itemBiases = VectorExtensions.readVectorArray(reader); reader.close(); this.globalAverage = globalAverage; this.userBiases = userBiases; this.itemBiases = itemBiases; } @Override public double computeLoss() { return Ratings.evaluate(this, ratings).get("RMSE") + regU * Math.pow(org.mymedialite.datatype.VectorExtensions.euclideanNorm(userBiases), 2) + regI * Math.pow(org.mymedialite.datatype.VectorExtensions.euclideanNorm(itemBiases), 2); } @Override public String toString() { return "user-item-baseline regU=" + regU + " regI=" + regI+ " numIter=" + numIter; } }