//Copyright (C) 2011 Zeno Gantner, Chris Newell // //This file is part of MyMediaLite. // //MyMediaLite is free software: you can redistribute it and/or modify //it under the terms of the GNU General Public License as published by //the Free Software Foundation, either version 3 of the License, or //(at your option) any later version. // //MyMediaLite is distributed in the hope that it will be useful, //but WITHOUT ANY WARRANTY; without even the implied warranty of //MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //GNU General Public License for more details. // //You should have received a copy of the GNU General Public License //along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. // package org.mymedialite.ratingprediction; import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Calendar; import java.util.Date; import java.util.List; import org.mymedialite.IIterativeModel; import org.mymedialite.datatype.Matrix; import org.mymedialite.datatype.MatrixExtensions; import org.mymedialite.datatype.SparseMatrix; import org.mymedialite.datatype.SparseMatrixExtensions; import org.mymedialite.datatype.VectorExtensions; import org.mymedialite.eval.Ratings; /** * Time-aware bias model. * * Model described in equation (10) of BellKor Grand Prize documentation for the Netflix Prize (see below). * The optimization problem is described in equation (12). * * The default hyper-parameter values are set to the ones shown in the report. * For datasets other than Netflix, you may want to find better parameters. * * Literature: * Yehuda Koren: The BellKor Solution to the Netflix Grand Prize * * This recommender does currently NOT support incremental updates. * @version 2.03 */ public class TimeAwareBaseline extends TimeAwareRatingPredictor implements IIterativeModel { // Parameters double global_average; List<Double> user_bias; List<Double> item_bias; List<Double> alpha; Matrix<Double> item_bias_by_time_bin; // items : rows, bins : columns SparseMatrix<Double> user_bias_by_day; // users : rows, days : columns List<Double> user_scaling; // c_u SparseMatrix<Double> user_scaling_by_day; // c_ut // Hyperparameters /** * Number of iterations over the dataset to perform. */ public int numIter; /** * Bin size in days for modeling the time-dependent item bias. */ public int binSize; /** * Beta parameter for modeling the drift in the user bias. */ public double beta; // Parameter-specific learn rates /** * Learn rate for the user bias. */ public double userBiasLearnRate; /** * Learn rate for the item bias. */ public double itemBiasLearnRate; /** * Learn rate for the user-wise alphas. */ public double alphaLearnRate; /** * Learn rate for the bin-wise item bias. */ public double itemBiasByTimeBinLearnRate; /** * Learn rate for the day-wise user bias. */ public double userBiasByDayLearnRate; /** * Learn rate for the user-wise scaling factor. */ public double userScalingLearnRate; /** * Learn rate for the day-wise user scaling factor. */ public double userScalingByDayLearnRate; // Parameter-specific regularization constants /** * Regularization for the user bias. */ public double regU; /** * Regularization for the item bias. */ public double regI; /** * Regularization for the user-wise alphas. */ public double regAlpha; /** * Regularization for the bin-wise item bias. */ public double regItemBiasByTimeBin; /** * Regularization for the day-wise user bias. */ public double regUserBiasByDay; /** * Regularization for the user scaling factor. */ public double regUserScaling; /** * Regularization for the day-wise user scaling factor. */ public double regUserScalingByDay; // Helper data structures List<Double> userMeanDay; /** * Default constructor. */ public TimeAwareBaseline() { numIter = 30; binSize = 70; beta = 0.4; userBiasLearnRate = 0.003; itemBiasLearnRate = 0.002; alphaLearnRate = 0.00001; itemBiasByTimeBinLearnRate = 0.000005; userBiasByDayLearnRate = 0.0025; userScalingLearnRate = 0.008; userScalingByDayLearnRate = 0.002; regU = 0.03; regI = 0.03; regAlpha = 50; regItemBiasByTimeBin = 0.1; regUserBiasByDay = 0.005; regUserScaling = 0.01; regUserScalingByDay = 0.005; } @Override public void setNumIter(int numIter) { this.numIter = numIter; } @Override public int getNumIter() { return numIter; } /** */ public void train() { initModel(); global_average = ratings.average(); // Compute mean day of rating by user userMeanDay = new ArrayList<Double>(maxUserID + 1); for(int i = 0; i <= maxUserID; i++) userMeanDay.add(0.0); for (int i = 0; i < timed_ratings.size(); i++) userMeanDay.set(ratings.users().get(i), userMeanDay.get(ratings.users().get(i)) + relativeDay(timed_ratings.times().get(i))); for (int u = 0; u <= maxUserID; u++) if (ratings.countByUser().get(u) != 0) userMeanDay.set(u, userMeanDay.get(u) / ratings.countByUser().get(u)); else // no ratings yet? userMeanDay.set(u, new Double(relativeDay(timed_ratings.latestTime()))); // set to latest day for (int i = 0; i < numIter; i++) iterate(); } /** * Given a Date object, return the day relative to the first rating day in the dataset. * @return the day relative to the first rating day in the dataset * @param date the date/time of the rating event */ protected int relativeDay(Date date) { return (int)((date.getTime() - timed_ratings.earliestTime().getTime()) / (24 * 3600000)); } /** * Initialize the model parameters. */ protected void initModel() { int number_of_days = (int)((timed_ratings.latestTime().getTime() - timed_ratings.earliestTime().getTime()) / (24 * 3600000)); int number_of_bins = number_of_days / binSize + 1; System.out.println(number_of_days + " days, " + number_of_bins + " bins"); // Initialize parameters user_bias = new ArrayList<Double>(maxUserID + 1); item_bias = new ArrayList<Double>(maxItemID + 1); alpha = new ArrayList<Double>(maxUserID + 1); item_bias_by_time_bin = new Matrix<Double>(maxItemID + 1, number_of_bins, 0.0); user_bias_by_day = new SparseMatrix<Double>(maxUserID + 1, number_of_days, 0.0); user_scaling = new ArrayList<Double>(maxUserID + 1); user_scaling_by_day = new SparseMatrix<Double>(maxUserID + 1, number_of_days, 0.0); for (int i = 0; i <= maxUserID; i++) { user_bias.add(0.0); item_bias.add(0.0); alpha.add(0.0); user_scaling.add(0.0); } } /** */ public void iterate() { for (int index : timed_ratings.randomIndex()) { int u = timed_ratings.users().get(index); int i = timed_ratings.items().get(index); int day = relativeDay(timed_ratings.times().get(index)); int bin = day / binSize; // Compute error double err = timed_ratings.get(index) - predict(u, i, day, bin); updateParameters(u, i, day, bin, err); } } /** * Single SGD step: update the parameter values for one user and one item. * <param name='u'>the user ID * <param name='i'>the item ID * <param name='day'>the day of the rating * <param name='bin'>the day bin of the rating * <param name='err'>the current error made for this rating */ protected void updateParameters(int u, int i, int day, int bin, double err) { // Update user biases double dev_u = Math.signum(day - userMeanDay.get(u)) * Math.pow(Math.abs(day - userMeanDay.get(u)), beta); alpha.set(u, alpha.get(u) + 2 * alphaLearnRate * (err * dev_u - regAlpha * alpha.get(u))); user_bias.set(u, user_bias.get(u) + 2 * userBiasLearnRate * (err - regU * user_bias.get(u))); user_bias_by_day.set(u, day, user_bias_by_day.get(u, day) + 2 * userBiasByDayLearnRate * (err - regUserBiasByDay * user_bias_by_day.get(u, day))); // Update item biases and user scalings double b_i = item_bias.get(i); double b_ib = item_bias_by_time_bin.get(i, bin); double c_u = user_scaling.get(u); double c_ud = user_scaling_by_day.get(u, day); item_bias.set(i, item_bias.get(i) + 2 * itemBiasLearnRate * (err * (c_u + c_ud) - regI * b_i)); item_bias_by_time_bin.set(i, bin, item_bias_by_time_bin.get(i, bin) + 2 * itemBiasByTimeBinLearnRate * (err * (c_u + c_ud) - regItemBiasByTimeBin * b_ib)); user_scaling.set(u, user_scaling.get(u) + 2 * userScalingLearnRate * (err * (b_i + b_ib) - regUserScaling * (c_u - 1))); user_scaling_by_day.set(u, day, user_scaling_by_day.get(u, day) + 2 * userScalingByDayLearnRate * (err * (b_i + b_ib) - regUserScalingByDay * c_ud)); } /** */ public double predict(int user_id, int item_id) { double result = global_average; if (user_id <= maxUserID) result += user_bias.get(user_id); if (item_id <= maxItemID) result += item_bias.get(item_id); return result; } /** * Predict the specified user_id, item_id, day and bin. * * Assumes user and item IDs are valid. * * <param name='user_id'>the user ID * <param name='item_id'>the item ID * <param name='day'>the day of the rating * <param name='bin'>the day bin of the rating */ protected double predict(int user_id, int item_id, int day, int bin) { double result = global_average; double dev_u = Math.signum(day - userMeanDay.get(user_id)) * Math.pow(Math.abs(day - userMeanDay.get(user_id)), beta); result += user_bias.get(user_id) + alpha.get(user_id) * dev_u + user_bias_by_day.get(user_id, day); result += (item_bias.get(item_id) + item_bias_by_time_bin.get(item_id, bin)) ; // * (user_scaling.get(user_id) + user_scaling_by_day.get(user_id, day)); return result; } /** * */ @SuppressWarnings("deprecation") public double predict(int user_id, int item_id, Date time) { int day = relativeDay(time); int bin = day / binSize; // Use latest day bin if the rating time is after the training time period if (bin >= item_bias_by_time_bin.numberOfColumns()) bin = item_bias_by_time_bin.numberOfColumns() - 1; double result = global_average; if (user_id <= maxUserID) { double dev_u = Math.signum(day - userMeanDay.get(user_id)) * Math.pow(Math.abs(day - userMeanDay.get(user_id)), beta); result += user_bias.get(user_id) + alpha.get(user_id) * dev_u; if (day <= timed_ratings.latestTime().getDay()) result += user_bias_by_day.get(user_id, day); } if (item_id <= maxItemID && user_id > maxUserID) result += item_bias.get(item_id) + item_bias_by_time_bin.get(item_id, bin); if (item_id <= maxItemID && user_id <= maxUserID && day < user_scaling_by_day.numberOfColumns()) result += (item_bias.get(item_id) + item_bias_by_time_bin.get(item_id, bin)) * (user_scaling.get(user_id) + user_scaling_by_day.get(user_id, day)); return result; } /** * */ public double computeLoss() { double loss = 2 * Ratings.evaluate(this, ratings).get("RMSE") + regU * Math.pow(VectorExtensions.euclideanNorm(user_bias), 2) + regI * Math.pow(VectorExtensions.euclideanNorm(item_bias), 2) + regAlpha * Math.pow(VectorExtensions.euclideanNorm(alpha), 2) + regUserBiasByDay * Math.pow(SparseMatrixExtensions.frobeniusNorm(user_bias_by_day), 2) + regItemBiasByTimeBin * Math.pow(MatrixExtensions.frobeniusNorm(item_bias_by_time_bin) , 2) + regUserScalingByDay * Math.pow(SparseMatrixExtensions.frobeniusNorm(user_scaling_by_day), 2); double user_scaling_reg_term = 0; for (double e : user_scaling) user_scaling_reg_term += Math.pow(1 - e, 2); user_scaling_reg_term = user_scaling_reg_term * regUserScaling; loss += user_scaling_reg_term; return loss; } /** * */ public String toString() { return "TimeAwareBaseline" + " num_iter=" + numIter + " bin_size=" + binSize + " beta=" + beta + " user_bias_learn_rate=" + userBiasLearnRate + " item_bias_learn_rate=" + itemBiasLearnRate + " alpha_learn_rate=" + alphaLearnRate + " item_bias_by_time_bin_learn_rate=" + itemBiasByTimeBinLearnRate + " user_bias_by_day_learn_rate=" + userBiasByDayLearnRate + " user_scaling_learn_rate=" + userScalingLearnRate + " user_scaling_by_day_learn_rate=" + userScalingByDayLearnRate + " reg_u=" + regU + " reg_i=" + regI + " reg_alpha=" + regAlpha + " reg_item_bias_by_time_bin=" + regItemBiasByTimeBin + " reg_user_bias_by_day=" + regUserBiasByDay + " reg_user_scaling=" + regUserScaling + " reg_user_scaling_by_day=" + regUserScalingByDay; } @Override public void saveModel(String filename) throws IOException { throw new UnsupportedOperationException(); } @Override public void saveModel(PrintWriter writer) throws IOException { throw new UnsupportedOperationException(); } @Override public void loadModel(String filename) throws IOException { throw new UnsupportedOperationException(); } @Override public void loadModel(BufferedReader reader) throws IOException { throw new UnsupportedOperationException(); } }