// Copyright (C) 2010, 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.itemrec;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.mymedialite.util.Random;
import org.mymedialite.IItemAttributeAwareRecommender;
import org.mymedialite.IIterativeModel;
import org.mymedialite.datatype.Matrix;
import org.mymedialite.datatype.SparseBooleanMatrix;
import org.mymedialite.io.IMatrixExtensions;
import org.mymedialite.io.Model;
/**
* Linear model optimized for BPR.
*
* Literature:
* Zeno Gantner, Lucas Drumond, Christoph Freudenthaler, Steffen Rendle, Lars Schmidt-Thieme:
* Learning Attribute-to-Feature Mappings for Cold-Start Recommendations.
* ICDM 2011.
* http://www.ismll.uni-hildesheim.de/pub/pdfs/Gantner_et_al2010Mapping.pdf
*
* This recommender does NOT support incremental updates.
* @version 2.03
*/
public class BPRLinear extends ItemRecommender implements IItemAttributeAwareRecommender, IIterativeModel {
private static final String VERSION = "2.03";
private SparseBooleanMatrix itemAttributes;
public SparseBooleanMatrix getItemAttributes() {
return itemAttributes;
}
public void setItemAttributes(SparseBooleanMatrix itemAttributes) {
this.itemAttributes = itemAttributes;
this.maxItemID = Math.max(maxItemID, itemAttributes.numberOfRows() - 1);
}
/**
*
*/
@Override
public int numItemAttributes() {
return itemAttributes.numberOfColumns();
}
// Item attribute weights
private Matrix<Double> itemAttributeWeightByUser;
/**
* One iteration = iterationLength * number of entries in the training matrix.
*/
protected int iterationLength = 5;
private Random random;
// Fast, but memory-intensive sampling
private boolean fastSampling = false;
/**
* Number of iterations over the training data.
*/
public int numIter = 10;
/**
* Fast sampling memory limit, in MiB.
*/
public int fastSamplingMemoryLimit = 1024;
/**
* mean of the Gaussian distribution used to initialize the features.
*/
public double initMean = 0;
/**
* standard deviation of the normal distribution used to initialize the features.
*/
public double initStdev = 0.1;
/**
* Learning rate alpha.
*/
public double learnRate = 0.05;
/**
* Regularization parameter.
*/
public double regularization = 0.015;
// Support data structure for fast sampling
private IntList[] userPosItems;
// Support data structure for fast sampling
private IntList[] userNegItems;
@Override
public void setNumIter(int numIter) {
this.numIter = numIter;
}
@Override
public int getNumIter() {
return numIter;
}
/**
*
*/
@Override
public void train() {
random = org.mymedialite.util.Random.getInstance();
// Prepare fast sampling, if necessary
int fast_sampling_memory_size = ((maxUserID + 1) * (maxItemID + 1) * 4) / (1024 * 1024);
System.err.println("fast_sampling_memory_size=" + fast_sampling_memory_size);
if (fast_sampling_memory_size <= fastSamplingMemoryLimit) {
fastSampling = true;
this.userPosItems = new IntArrayList[maxUserID + 1];
this.userNegItems = new IntArrayList[maxUserID + 1];
for (int u = 0; u < maxUserID + 1; u++) {
IntList pos_list = new IntArrayList(feedback.userMatrix().get(u));
userPosItems[u] = pos_list;
IntList neg_list = new IntArrayList();
for (int i = 0; i < maxItemID; i++)
if (!feedback.userMatrix().get(u).contains(i) && feedback.itemMatrix().get(i).size() != 0)
neg_list.add(i);
userNegItems[u] = neg_list;
}
}
itemAttributeWeightByUser = new Matrix<Double>(maxUserID + 1, numItemAttributes(), 0.0);
//itemAttributeWeightByUser.init(0.0);
for (int i = 0; i < numIter; i++)
iterate();
}
/**
*
* Perform one iteration of stochastic gradient ascent over the training data.
* One iteration is <see cref="iteration_length"/> * number of entries in the training matrix.
*/
@Override
public void iterate() {
int num_pos_events = feedback.size();
for (int i = 0; i < num_pos_events * iterationLength; i++) {
if (i % 1000000 == 999999)
System.err.print(".");
if (i % 100000000 == 99999999)
System.err.println();
// user_id u, item_id_1 i, item_id_2 j
SampleTriple triple = new SampleTriple();
updateFeatures(triple);
}
}
/**
* Sample a pair of items, given a user.
* @param t a SampleTriple specifying a user ID
*/
protected void sampleItemPair(SampleTriple t) {
if (fastSampling) {
t.i = userPosItems[t.u].get(random.nextInt(userPosItems[t.u].size()));
t.j = userNegItems[t.u].get(random.nextInt(userNegItems[t.u].size()));
} else {
IntList user_items = new IntArrayList(feedback.userMatrix().get(t.u));
t.i = user_items.get(random.nextInt(user_items.size()));
do
t.j = random.nextInt(0, maxItemID + 1);
while (feedback.userMatrix().get(t.u, t.j) || feedback.itemMatrix().get(t.j).size() == 0); // don't sample the item if it never has been viewed (maybe unknown item!)
}
}
/**
* Sample a user that has viewed at least one and not all items.
* @return the user ID
*/
protected int sampleUser() {
while (true) {
int u = random.nextInt(maxUserID + 1);
IntCollection user_items = feedback.userMatrix().get(u);
if (user_items.size() == 0 || user_items.size() == maxItemID + 1)
continue;
return u;
}
}
/**
* Sample a triple for BPR learning.
* @param t the SampleTriple to configure
*/
protected void sampleTriple(SampleTriple t) {
t.u = sampleUser();
sampleItemPair(t);
}
/**
* Modified feature update method that exploits attribute sparsity.
* @param t a SampleTriple specifying the user ID and the first and second item IDs
*/
protected void updateFeatures(SampleTriple t) {
double x_uij = predict(t.u, t.i) - predict(t.u, t.j);
IntCollection attr_i = itemAttributes.get(t.i);
IntCollection attr_j = itemAttributes.get(t.j);
// Assumption: attributes are sparse
IntSet attr_i_over_j = new IntArraySet(attr_i);
attr_i_over_j.removeAll(attr_j);
IntSet attr_j_over_i = new IntArraySet(attr_j);
attr_j_over_i.removeAll(attr_i);
double one_over_one_plus_ex = 1 / (1 + Math.exp(x_uij));
for (int a : attr_i_over_j) {
double w_uf = itemAttributeWeightByUser.get(t.u, a);
double uf_update = one_over_one_plus_ex - regularization * w_uf;
itemAttributeWeightByUser.set(t.u, a, w_uf + learnRate * uf_update);
}
for (int a : attr_j_over_i) {
double w_uf = itemAttributeWeightByUser.get(t.u, a);
double uf_update = -one_over_one_plus_ex - regularization * w_uf;
itemAttributeWeightByUser.set(t.u, a, w_uf + learnRate * uf_update);
}
}
/**
*
*/
@Override
public double predict(int user_id, int item_id) {
if ((user_id < 0) || (user_id >= itemAttributeWeightByUser.dim1))
return Double.MIN_VALUE;
if ((item_id < 0) || (item_id > maxItemID))
return Double.MIN_VALUE;
double result = 0;
for (int a : itemAttributes.get(item_id))
result += itemAttributeWeightByUser.get(user_id, a);
return result;
}
/**
* @throws IOException
*/
@Override
public void saveModel(String filename) throws IOException {
PrintWriter writer = Model.getWriter(filename, this.getClass(), VERSION);
saveModel(writer);
writer.flush();
writer.close();
}
@Override
public void saveModel(PrintWriter writer) {
IMatrixExtensions.writeMatrix(writer, itemAttributeWeightByUser);
}
/**
*
*/
@Override
public void loadModel(String filename) throws IOException {
BufferedReader reader = Model.getReader(filename, this.getClass());
loadModel(reader);
reader.close();
}
/**
*
*/
@Override
public void loadModel(BufferedReader reader) throws IOException {
this.itemAttributeWeightByUser = (Matrix<Double>) IMatrixExtensions.readDoubleMatrix(reader, new Matrix<Double>(0, 0, 0.0));
}
/**
*
*/
@Override
public double computeLoss() {
return -1;
}
/**
*
*/
@Override
public String toString() {
return
this.getClass().getName()
+ " reg=" + regularization
+ " numIter=" + numIter
+ " learnRate=" + learnRate
+ " fastSamplingMemoryLimit=" + fastSamplingMemoryLimit
+ " initMean=" + initMean
+ " initStdev=" + initStdev;
}
private class SampleTriple {
int u; // user_id
int i; // item_id positive item
int j; // item_id negative item
}
}