// Copyright (C) 2010 Steffen Rendle, Zeno Gantner // Copyright (C) 2011 Zeno Gantner, Chris Newell // // This file is part of MyMediaLite. // // MyMediaLite is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // MyMediaLite is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.itemrec; import it.unimi.dsi.fastutil.ints.IntList; import org.mymedialite.datatype.IBooleanMatrix; import org.mymedialite.datatype.Matrix; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.impl.DenseDoubleMatrix2D; import cern.colt.matrix.linalg.Algebra; /** * Weighted matrix factorization method proposed by Hu et al. and Pan et al.. * * We use the fast learning method proposed by Hu et al. (alternating least squares), * and we use a global weight to penalize observed/unobserved values. * * Literature: * * Y. Hu, Y. Koren, C. Volinsky: Collaborative filtering for implicit feedback datasets. * ICDM 2008. * http://research.yahoo.net/files/HuKorenVolinsky-ICDM08.pdf * * R. Pan, Y. Zhou, B. Cao, N. N. Liu, R. M. Lukose, M. Scholz, Q. Yang: * One-class collaborative filtering, * ICDM 2008. * http://www.hpl.hp.com/techreports/2008/HPL-2008-48R1.pdf * * This recommender does NOT support incremental updates. * @version 2.03 */ public class WRMF extends MF { /** * C position: the weight/confidence that is put on positive observations. * The alpha value in Hu et al. */ public double cPos = 1; /** * Regularization parameter. */ public double regularization = 0.015; /** * */ public WRMF() { numIter = 15; } /** * */ public void iterate() { // Perform alternating parameter fitting optimize(feedback.userMatrix(), userFactors, itemFactors); optimize(feedback.itemMatrix(), itemFactors, userFactors); } /** * Optimizes the specified data. * @param data data * @param W W * @param H H */ protected void optimize(IBooleanMatrix data, Matrix<Double> W, Matrix<Double> H) { Matrix<Double> HH = new Matrix<Double>(numFactors, numFactors); Matrix<Double> HC_minus_IH = new Matrix<Double>(numFactors, numFactors); double[] HCp = new double[numFactors]; DenseDoubleMatrix2D m = new DenseDoubleMatrix2D(numFactors, numFactors); // Source code comments are in terms of computing the user factors // Works the same with users and items exchanged // (1) Create HH in O(f^2|Items|) // HH is symmetric for (int f_1 = 0; f_1 < numFactors; f_1++) for (int f_2 = 0; f_2 < numFactors; f_2++) { double d = 0; for (int i = 0; i < H.dim1; i++) d += H.get(i, f_1) * H.get(i, f_2); HH.set(f_1, f_2, d); } // (2) Optimize all U // HC_minus_IH is symmetric for (int u = 0; u < W.dim1; u++) { IntList row = data.getEntriesByRow(u); // Create HC_minus_IH in O(f^2|S_u|) for (int f_1 = 0; f_1 < numFactors; f_1++) for (int f_2 = 0; f_2 < numFactors; f_2++) { double d = 0; for (int i : row) //d += H.get(i, f_1) * H.get(i, f_2) * (c_pos - 1); d += H.get(i, f_1) * H.get(i, f_2) * cPos; HC_minus_IH.set(f_1, f_2, d); } // Create HCp in O(f|S_u|) for (int f = 0; f < numFactors; f++) { double d = 0; for (int i : row) //d += H.get(i, f) * c_pos; d += H.get(i, f) * (1 + cPos); HCp[f] = d; } // Create m = HH + HC_minus_IH + reg*I // m is symmetric // The inverse m_inv is symmetric for (int f_1 = 0; f_1 < numFactors; f_1++) for (int f_2 = 0; f_2 < numFactors; f_2++) { double d = HH.get(f_1, f_2) + HC_minus_IH.get(f_1, f_2); if (f_1 == f_2) d += regularization; m.set(f_1, f_2, d); } DoubleMatrix2D m_inv = Algebra.DEFAULT.inverse(m); // Write back optimal W for (int f = 0; f < numFactors; f++) { double d = 0; for (int f_2 = 0; f_2 < numFactors; f_2++) d += m_inv.get(f, f_2) * HCp[f_2]; W.set(u, f, d); } } } /** * */ public double computeLoss() { return -1; } /** * */ public String toString() { return "WRMF numFactors=" + numFactors + " regularization=" + regularization + " cPos=" + cPos + " numIter=" + numIter + " initMean=" + initMean + " initStDev=" + initStDev; } }