// Copyright (C) 2010 Steffen Rendle, Zeno Gantner
// Copyright (C) 2011, 2012 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.ratingprediction;
import java.util.List;
import org.mymedialite.datatype.MatrixExtensions;
import org.mymedialite.datatype.VectorExtensions;
/**
* Matrix factorization with explicit user and item bias, learning is performed by
* stochastic gradient descent, optimized for the log likelihood.
*
* Implements a simple version Menon and Elkan's LFL model:
* Predicts binary labels, no advanced regUlarization, no side information.
*
* Literature:
*
* Aditya Krishna Menon, Charles Elkan:
* A log-linear model with latent features for dyadic prediction.
* ICDM 2010.
* http://cseweb.ucsd.edu/~akmenon/LFL-ICDM10.pdf
*
* This recommender supports incremental updates.
* @version 2.03
*/
public class LogisticRegressionMatrixFactorization extends BiasedMatrixFactorization {
/**
*
*/
protected void iterate(List<Integer> rating_indices, boolean update_user, boolean update_item) {
double rating_range_size = maxRating - minRating;
for (int index : rating_indices) {
int u = ratings.users().get(index);
int i = ratings.items().get(index);
double dot_product = userBias[u] + itemBias[i] + MatrixExtensions.rowScalarProduct(userFactors, u, itemFactors, i);
double sig_dot = 1 / (1 + Math.exp(-dot_product));
double prediction = minRating + sig_dot * rating_range_size;
double gradient_common = ratings.get(index) - prediction;
// Adjust biases
if (update_user)
userBias[u] = userBias[u] + learnRate * (gradient_common - biasReg * userBias[u]);
if (update_item)
itemBias[i] = learnRate * (gradient_common - biasReg * itemBias[i]);
// Adjust latent factors
for (int f = 0; f < numFactors; f++) {
double u_f = userFactors.get(u, f);
double i_f = itemFactors.get(i, f);
if (update_user) {
double delta_u = gradient_common * i_f - regU * u_f;
MatrixExtensions.inc(userFactors, u, f, learnRate * delta_u);
// this is faster (190 vs. 260 seconds per iteration on Netflix w/ k=30) than
// userFactorsu, f) += learn_rate * delta_u;
}
if (update_item) {
double delta_i = gradient_common * u_f - regI * i_f;
MatrixExtensions.inc(itemFactors, i, f, learnRate * delta_i);
}
}
}
}
/**
*
*/
public double computeLoss() {
double rating_range_size = maxRating - minRating;
double loss = 0;
for (int i = 0; i < ratings.size(); i++) {
double prediction = predict(ratings.users().get(i), ratings.items().get(i));
// Map into [0, 1] interval
prediction = (prediction - minRating) / rating_range_size;
if (prediction < 0.0)
prediction = 0.0;
if (prediction > 1.0)
prediction = 1.0;
double actual_rating = (ratings.get(i) - minRating) / rating_range_size;
loss -= (actual_rating) * Math.log(prediction);
loss -= (1 - actual_rating) * Math.log(1 - prediction);
}
double complexity = 0;
for (int u = 0; u <= maxUserID; u++) {
complexity += ratings.countByUser().get(u) * regU * Math.pow(VectorExtensions.euclideanNorm(userFactors.getRow(u)), 2);
complexity += ratings.countByUser().get(u) * biasReg * Math.pow(userBias[u], 2);
}
for (int i = 0; i <= maxItemID; i++) {
complexity += ratings.countByItem().get(i) * regI * Math.pow(VectorExtensions.euclideanNorm(itemFactors.getRow(i)), 2);
complexity += ratings.countByItem().get(i) * biasReg * Math.pow(itemBias[i], 2);
}
return loss + complexity;
}
public String toString() {
return
this.getClass().getName()
+ " numFactors=" + numFactors
+ " biasReg=" + biasReg
+ " regI=" + regI
+ " regU=" + regU
+ " learnRate=" + learnRate
+ " numIter=" + numIter
+ " boldDriver=" + boldDriver
+ " initMean=" + initMean
+ " initStDev=" + initStDev;
}
}