// Copyright (C) 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.ratingprediction;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import org.mymedialite.IIterativeModel;
import org.mymedialite.data.IRatings;
import org.mymedialite.datatype.Matrix;
import org.mymedialite.datatype.MatrixExtensions;
import org.mymedialite.io.IMatrixExtensions;
import org.mymedialite.io.Model;
/**
* Matrix factorization with factor-wise learning.
*
* Similar to the approach described in Simon Funk's seminal blog post: http://sifter.org/~simon/journal/20061211.html
*
* Literature:
* Robert Bell, Yehuda Koren, Chris Volinsky:
* Modeling Relationships at Multiple Scales to Improve Accuracy of Large Recommender Systems,
* ACM Int. Conference on Knowledge Discovery and Data Mining (KDD'07), 2007.
*
* This recommender does NOT support incremental updates.
*/
public class FactorWiseMatrixFactorization extends RatingPredictor implements IIterativeModel {
private static final String VERSION = "2.03";
UserItemBaseline global_effects = new UserItemBaseline();
int num_learned_factors;
@Override
public void setRatings(IRatings ratings) {
super.setRatings(ratings);
global_effects.setRatings(ratings);
}
/**
* Regularization constant for the user bias of the underlying baseline predictor.
*/
public double getRegU() {
return global_effects.regU;
}
public void setRegU(double regU) {
global_effects.regU = regU;
}
/**
* Regularization constant for the item bias of the underlying baseline predictor.
*/
public double getRegI() {
return global_effects.regI;
}
public void setRegI(double regI) {
global_effects.regI = regI;
}
/**
* Matrix containing the latent user factors.
*/
Matrix<Double> userFactors;
/**
* Matrix containing the latent item factors.
*/
Matrix<Double> itemFactors;
/**
* Number of latent factors.
*/
public int numFactors;
/**
* Number of iterations (in this case: number of latent factors).
*/
public int numIter;
@Override
public void setNumIter(int num_iter) {
this.numIter = num_iter;
}
@Override
public int getNumIter() {
return numIter;
}
/**
* Shrinkage parameter.
*
* alpha in the Bell et al. paper
*/
public double shrinkage;
/**
* Sensibility parameter (stopping criterion for parameter fitting).
*
* epsilon in the Bell et al. paper
*/
public double sensibility;
/**
* Mean of the normal distribution used to initialize the factors.
*/
public double initMean;
/**
* Standard deviation of the normal distribution used to initialize the factors.
*/
public double initStDev;
/**
* Default constructor.
*/
public FactorWiseMatrixFactorization() {
super();
// Set default values
shrinkage = 25;
numFactors = 10;
numIter = 10;
sensibility = 0.00001;
initStDev = 0.1;
}
/**
*/
public void train() {
// Init factor matrices
userFactors = new Matrix<Double>(maxUserID + 1, numFactors, 0.0);
itemFactors = new Matrix<Double>(maxItemID + 1, numFactors, 0.0);
// Init+train global effects model
global_effects.setRatings(ratings);
global_effects.train();
// Learn model parameters
num_learned_factors = 0;
for (int i = 0; i < numIter; i++)
iterate();
}
/**
*/
public void iterate() {
if (num_learned_factors >= numFactors)
return;
// Compute residuals
double[] residuals = new double[ratings.size()];
for (int index = 0; index < ratings.size(); index++) {
int u = ratings.users().getInt(index);
int i = ratings.items().getInt(index);
residuals[index] = ratings.get(index) - predict(u, i);
int n_ui = Math.min(ratings.byUser().get(u).size(), ratings.byItem().get(i).size());
residuals[index] *= n_ui / (n_ui + shrinkage);
}
// Initialize new latent factors
MatrixExtensions.columnInitNormal(userFactors, num_learned_factors, initMean, initStDev);
MatrixExtensions.columnInitNormal(itemFactors, num_learned_factors, initMean, initStDev);
// Compute the next factor by solving many least squares problems with one variable each
double err = Double.MAX_VALUE / 2;
double err_old = Double.MAX_VALUE;
while (err / err_old < 1 - sensibility) {
double[] user_factors_update_numerator = new double[maxUserID + 1];
double[] user_factors_update_denominator = new double[maxUserID + 1];
// Compute updates in one pass over the data
for (int index = 0; index < ratings.size(); index++) {
int u = ratings.users().getInt(index);
int i = ratings.items().getInt(index);
user_factors_update_numerator[u] += residuals[index] * itemFactors.get(i, num_learned_factors);
user_factors_update_denominator[u] += itemFactors.get(i, num_learned_factors) * itemFactors.get(i, num_learned_factors);
}
// Update user factors
for (int u = 0; u <= maxUserID; u++)
if (user_factors_update_numerator[u] != 0)
userFactors.set(u, num_learned_factors, user_factors_update_numerator[u] / user_factors_update_denominator[u]);
double[] item_factors_update_numerator = new double[maxItemID + 1];
double[] item_factors_update_denominator = new double[maxItemID + 1];
// Compute updates in one pass over the data
for (int index = 0; index < ratings.size(); index++) {
int u = ratings.users().getInt(index);
int i = ratings.items().getInt(index);
item_factors_update_numerator[i] += residuals[index] * userFactors.get(u, num_learned_factors);
item_factors_update_denominator[i] += userFactors.get(u, num_learned_factors) * userFactors.get(u, num_learned_factors);
}
// Update item factors
for (int i = 0; i <= maxItemID; i++)
if (item_factors_update_numerator[i] != 0)
itemFactors.set(i, num_learned_factors, item_factors_update_numerator[i] / item_factors_update_denominator[i]);
err_old = err;
err = org.mymedialite.eval.Ratings.computeFit(this);
}
num_learned_factors++;
}
/**
* Predict the rating of a given user for a given item.
*
* If the user or the item are not known to the recommender, the global effects prediction is returned.
* To avoid this behavior for unknown entities, use canPredict() to check before.
*
* @param user_id the user ID
* @param item_id the item ID
* @return the predicted rating
*/
public double predict(int user_id, int item_id) {
if (user_id >= userFactors.dim1 || item_id >= itemFactors.dim1)
return global_effects.predict(user_id, item_id);
double result = global_effects.predict(user_id, item_id) + MatrixExtensions.rowScalarProduct(userFactors, user_id, itemFactors, item_id);
if (result > maxRating)
return maxRating;
if (result < minRating)
return minRating;
return result;
}
/**
* @throws IOException
*/
public void saveModel(String filename) throws IOException {
global_effects.saveModel(filename + "-global-effects");
PrintWriter writer = Model.getWriter(filename, this.getClass(), VERSION);
saveModel(writer);
writer.flush();
writer.close();
}
@Override
public void saveModel(PrintWriter writer) throws IOException {
writer.println(num_learned_factors);
IMatrixExtensions.writeMatrix(writer, userFactors);
IMatrixExtensions.writeMatrix(writer, itemFactors);
writer.flush();
writer.close();
}
/**
* @throws IOException
*/
public void loadModel(String filename) throws IOException {
global_effects.loadModel(filename + "-global-effects");
if (ratings != null)
global_effects.setRatings(ratings);
BufferedReader reader = Model.getReader(filename, this.getClass());
loadModel(reader);
reader.close();
}
public void loadModel(BufferedReader reader) throws IOException {
int num_learned_factors = Integer.parseInt(reader.readLine());
Matrix<Double> user_factors = (Matrix<Double>) IMatrixExtensions.readDoubleMatrix(reader, new Matrix<Double>(0, 0));
Matrix<Double> item_factors = (Matrix<Double>) IMatrixExtensions.readDoubleMatrix(reader, new Matrix<Double>(0, 0));
reader.close();
if (user_factors.numberOfColumns() != item_factors.numberOfColumns())
throw new IOException("Number of user and item factors must match: " + user_factors.numberOfColumns() + " != " + item_factors.numberOfColumns());
this.maxUserID = user_factors.numberOfRows() - 1;
this.maxItemID = item_factors.numberOfRows() - 1;
// Assign new model
this.num_learned_factors = num_learned_factors;
if (this.numFactors != user_factors.numberOfColumns()) {
System.err.println("Set num_factors to " + user_factors.numberOfColumns());
this.numFactors = user_factors.numberOfColumns();
}
this.userFactors = user_factors;
this.itemFactors = item_factors;
}
/**
*/
public double computeLoss() {
return -1;
}
public String toString() {
return
this.getClass().getName()
+ " numFactors=" + numFactors
+ " shrinkage=" + shrinkage
+ " sensibility=" + sensibility
+ " initMean=" + initMean
+ " initStDev=" + initStDev
+ " numIter=" + numIter
+ " regU" + getRegU()
+ " regI" + getRegI();
}
}