// Copyright (C) 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.ratingprediction;
import java.util.Date;
import org.mymedialite.datatype.MatrixExtensions;
import org.mymedialite.datatype.Pair;
import org.mymedialite.datatype.SparseMatrix;
import org.mymedialite.datatype.SparseMatrixExtensions;
import org.mymedialite.datatype.VectorExtensions;
/**
* Time-aware bias model with frequencies.
*
* Model described in equation (11) of BellKor Grand Prize documentation for the Netflix Prize (see below).
*
* The default hyper-parameter values are set to the ones shown in the report.
* For datasets other than Netflix, you may want to find better parameters.
*
* Literature:
* Yehuda Koren: The BellKor Solution to the Netflix Grand Prize
*
* This recommender does currently NOT support incremental updates.
* @version 2.03
*/
public class TimeAwareBaselineWithFrequencies extends TimeAwareBaseline {
// Additional parameters
SparseMatrix<Double> item_bias_at_frequency;
// Additional hyper-parameters
/**
* logarithmic base for the frequency counts.
*/
public double FrequencyLogBase;
/**
* Regularization constant for b_{i, f_{ui}}.
*/
public double RegItemBiasAtFrequency;
/**
* Learn rate for b_{i, f_{ui}}.
*/
public double ItemBiasAtFrequencyLearnRate;
// Additional helper data structures
SparseMatrix<Integer> log_frequency_by_day;
/**
* Default constructor.
*/
public TimeAwareBaselineWithFrequencies() {
numIter = 40;
FrequencyLogBase = 6.76;
binSize = 70;
beta = 0.4;
userBiasLearnRate = 0.00267;
itemBiasLearnRate = 0.000488;
alphaLearnRate = 0.00000311;
itemBiasByTimeBinLearnRate = 0.00000115;
userBiasByDayLearnRate = 0.000257;
userScalingLearnRate = 0.00564;
userScalingByDayLearnRate = 0.00103;
ItemBiasAtFrequencyLearnRate = 0.00236;
regU = 0.0255;
regI = 0.0255;
regAlpha = 3.95;
regItemBiasByTimeBin = 0.0929;
regUserBiasByDay = 0.00231;
regUserScaling = 0.0476;
regUserScalingByDay = 0.019;
RegItemBiasAtFrequency = 0.000000011;
}
/**
*/
public void train() {
// TODO check for better way to do this.
int number_of_days = (int)((timed_ratings.latestTime().getTime() - timed_ratings.earliestTime().getTime()) / (24 * 3600000));
// Compute log rating frequencies
log_frequency_by_day = new SparseMatrix<Integer>(maxUserID + 1, number_of_days, 0);
// First count the frequencies ...
for (int i = 0; i < timed_ratings.size(); i++) {
int day = relativeDay(timed_ratings.times().get(i));
log_frequency_by_day.set(timed_ratings.users().get(i), day, log_frequency_by_day.get(timed_ratings.users().get(i), day) + 1);
}
// ... then apply (rounded) logarithm
for (Pair<Integer, Integer> index_pair : log_frequency_by_day.nonEmptyEntryIDs())
log_frequency_by_day.set(index_pair.first, index_pair.second,
(int) Math.ceil(Math.log(log_frequency_by_day.get(index_pair.first, index_pair.second)) / Math.log(FrequencyLogBase)));
super.train();
}
/**
*/
protected void initModel() {
super.initModel();
item_bias_at_frequency = new SparseMatrix<Double>(maxItemID + 1, SparseMatrixExtensions.maxInteger(log_frequency_by_day), 0.0);
}
/**
*
*/
protected void updateParameters(int u, int i, int day, int bin, double err) {
super.updateParameters(u, i, day, bin, err);
// Update additional bias
int f = log_frequency_by_day.get(u, day);
double b_i_f_ui = item_bias_at_frequency.get(i, f);
item_bias_at_frequency.set(i, f, item_bias_at_frequency.get(i, f) + 2 * ItemBiasAtFrequencyLearnRate * (err * b_i_f_ui - RegItemBiasAtFrequency * b_i_f_ui));
}
/**
*
*/
protected double predict(int user_id, int item_id, int day, int bin) {
double result = super.predict(user_id, item_id, day, bin);
// TODO should this be relative day?
if (day <= timed_ratings.latestTime().getDay())
result += item_bias_at_frequency.get(item_id, log_frequency_by_day.get(user_id, day));
return result;
}
/**
*
*/
public double predict(int user_id, int item_id, Date time) {
double result = super.predict(user_id, item_id, time);
int day = relativeDay(time);
// TODO should this be relative day?
if (day <= timed_ratings.latestTime().getDay())
result += item_bias_at_frequency.get(item_id, log_frequency_by_day.get(user_id, day));
return result;
}
/**
*
*/
public double computeLoss() {
return super.computeLoss()
+ RegItemBiasAtFrequency * Math.pow(SparseMatrixExtensions.frobeniusNorm(item_bias_at_frequency), 2);
}
/**
*
*/
public String toString() {
return "TimeAwareBaselineWithFrequencies"
+ " num_iter=" + numIter
+ " bin_size=" + binSize
+ " beta=" + beta
+ " user_bias_learn_rate=" + userBiasLearnRate
+ " item_bias_learn_rate=" + itemBiasLearnRate
+ " alpha_learn_rate=" + alphaLearnRate
+ " item_bias_by_time_bin_learn_rate=" + itemBiasByTimeBinLearnRate
+ " user_bias_by_day_learn_rate=" + userBiasByDayLearnRate
+ " user_scaling_learn_rate=" + userScalingLearnRate
+ " user_scaling_by_day_learn_rate=" + userScalingByDayLearnRate
+ " reg_u=" + regU
+ " reg_i=" + regI
+ " reg_alpha=" + regAlpha
+ " reg_item_bias_by_time_bin=" + regItemBiasByTimeBin
+ " reg_user_bias_by_day=" + regUserBiasByDay
+ " reg_user_scaling=" + regUserScaling
+ " reg_user_scaling_by_day=" + regUserScalingByDay
+ " frequencyLogBase" + FrequencyLogBase
+ " itemBiasAtFrequencyLearnRate" + ItemBiasAtFrequencyLearnRate
+ " regItemBiasAtFrequency" + RegItemBiasAtFrequency;
}
}