// Copyright (C) 2010, 2011 Zeno Gantner, Chris Newell // // This file is part of MyMediaLite. // // MyMediaLite is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // MyMediaLite is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.correlation; import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntSet; import java.util.List; import java.util.Set; import org.mymedialite.data.IRatings; import org.mymedialite.datatype.IMatrix; import org.mymedialite.datatype.Pair; import org.mymedialite.datatype.SparseMatrix; import org.mymedialite.datatype.SymmetricMatrix; import org.mymedialite.taxonomy.EntityType; /** * Correlation class for Pearson correlation. * http://en.wikipedia.org/wiki/Pearson_correlation * @version 2.03 */ public class Pearson extends RatingCorrelationMatrix { /** * Shrinkage parameter. */ public float shrinkage = 10; /** * Constructor. Create a Pearson correlation matrix. * @param numEntities the number of entities */ public Pearson(int numEntities) { super(numEntities); } /** * Create a Pearson correlation matrix from given data. * @param ratings the ratings data * @param entityType the entity type, either USER or ITEM * @param shrinkage a shrinkage parameter * @return the complete Pearson correlation matrix */ public static CorrelationMatrix create(IRatings ratings, EntityType entityType, float shrinkage) { Pearson cm; int numEntities = 0; if (entityType == EntityType.USER) { numEntities = ratings.maxUserID() + 1; } else if (entityType == EntityType.ITEM) { numEntities = ratings.maxItemID() + 1; } else { throw new IllegalArgumentException("Unknown entity type: " + entityType); } try { cm = new Pearson(numEntities); } catch (OutOfMemoryError e) { System.err.println("Too many entities: " + numEntities); throw e; } cm.shrinkage = shrinkage; cm.computeCorrelations(ratings, entityType); return cm; } /** * Compute correlations between two entities for given ratings. * @param ratings the rating data * @param entityType the entity type, either USER or ITEM * @param i the ID of first entity * @param j the ID of second entity * @param shrinkage the shrinkage parameter */ public static float computeCorrelation(IRatings ratings, EntityType entityType, int i, int j, float shrinkage) { if (i == j) return 1; IntList ratings1 = (entityType == EntityType.USER) ? ratings.byUser().get(i) : ratings.byItem().get(i); IntList ratings2 = (entityType == EntityType.USER) ? ratings.byUser().get(j) : ratings.byItem().get(j); // get common ratings for the two entities IntSet e1 = (entityType == EntityType.USER) ? ratings.getItems(ratings1) : ratings.getUsers(ratings1); IntSet e2 = (entityType == EntityType.USER) ? ratings.getItems(ratings2) : ratings.getUsers(ratings2); e1.retainAll(e2); int n = e1.size(); if (n < 2) return 0; // Single-pass variant double i_sum = 0; double j_sum = 0; double ij_sum = 0; double ii_sum = 0; double jj_sum = 0; for (int other_entity_id : e1) { // Get ratings double r1 = 0; double r2 = 0; if (entityType == EntityType.USER) { r1 = ratings.get(i, other_entity_id, ratings1); r2 = ratings.get(j, other_entity_id, ratings2); } else { r1 = ratings.get(other_entity_id, i, ratings1); r2 = ratings.get(other_entity_id, j, ratings2); } // Update sums i_sum += r1; j_sum += r2; ij_sum += r1 * r2; ii_sum += r1 * r1; jj_sum += r2 * r2; } double denominator = Math.sqrt( (n * ii_sum - i_sum * i_sum) * (n * jj_sum - j_sum * j_sum) ); if (denominator == 0) return 0; double pmcc = (n * ij_sum - i_sum * j_sum) / denominator; return (float) pmcc * (n / (n + shrinkage)); } /** * Compute correlations for given ratings. * @param ratings the rating data * @param entityType the entity type, either USER or ITEM */ public void computeCorrelations(IRatings ratings, EntityType entityType) { if (entityType != EntityType.USER && entityType != EntityType.ITEM) throw new IllegalArgumentException("entity type must be either USER or ITEM, not " + entityType); List<IntList> ratings_by_other_entity = (entityType == EntityType.USER) ? ratings.byItem() : ratings.byUser(); IMatrix<Integer> freqs = new SymmetricMatrix<Integer>(numEntities, 0); IMatrix<Float> i_sums = new SymmetricMatrix<Float>(numEntities, 0.0F); IMatrix<Float> j_sums = new SymmetricMatrix<Float>(numEntities, 0.0F); IMatrix<Float> ij_sums = new SymmetricMatrix<Float>(numEntities, 0.0F); IMatrix<Float> ii_sums = new SymmetricMatrix<Float>(numEntities, 0.0F); IMatrix<Float> jj_sums = new SymmetricMatrix<Float>(numEntities, 0.0F); for (List<Integer> other_entity_ratings : ratings_by_other_entity) { for (int i = 0; i < other_entity_ratings.size(); i++) { int index1 = other_entity_ratings.get(i); int x = (entityType == EntityType.USER) ? ratings.users().get(index1) : ratings.items().get(index1); // Update pairwise scalar product and frequency for (int j = i + 1; j < other_entity_ratings.size(); j++) { int index2 = other_entity_ratings.get(j); int y = (entityType == EntityType.USER) ? ratings.users().get(index2) : ratings.items().get(index2); double rating1 = ratings.get(index1); double rating2 = ratings.get(index2); // Update sums freqs.set(x, y, freqs.get(x, y) + 1); i_sums.set(x, y, i_sums.get(x, y) + new Float(rating1)); j_sums.set(x, y, j_sums.get(x, y) + new Float(rating2)); ij_sums.set(x, y, ij_sums.get(x, y) + new Float(rating1 * rating2)); ii_sums.set(x, y, ii_sums.get(x, y) + new Float(rating1 * rating1)); jj_sums.set(x, y, jj_sums.get(x, y) + new Float(rating2 * rating2)); } } } // The diagonal of the correlation matrix for (int i = 0; i < numEntities; i++) { set(i, i, 1.0F); } for (int i = 0; i < numEntities; i++) { for (int j = i + 1; j < numEntities; j++) { int n = freqs.get(i, j); if (n < 2) { set(i, j, 0.0F); continue; } double numerator = ij_sums.get(i, j) * n - i_sums.get(i, j) * j_sums.get(i, j); double denominator = Math.sqrt( (n * ii_sums.get(i, j) - i_sums.get(i, j) * i_sums.get(i, j)) * (n * jj_sums.get(i, j) - j_sums.get(i, j) * j_sums.get(i, j))); if (denominator == 0) { this.set(i, j, 0.0F); continue; } double pmcc = numerator / denominator; this.set(i, j, (float) (pmcc * ((n - 1) / (n -1 + shrinkage)))); } } } }