//Copyright (C) 2010, 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.correlation;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import org.mymedialite.datatype.Matrix;
import org.mymedialite.datatype.SymmetricMatrix;
/**
 * Class for computing and storing correlations and similarities.
 * @version 2.03
 */
public class CorrelationMatrix extends SymmetricMatrix<Float> {

  /** Field separators accepted by {@link #readCorrelationMatrix}: one or more tabs, spaces or commas. */
  private static final Pattern FIELD_SEPARATOR = Pattern.compile("[\t ,]+");

  /**
   * Number of entities, e.g. users or items.
   */
  protected int numEntities;

  /**
   * @return true if the matrix is symmetric, which is generally the case for similarity matrices
   */
  public boolean isSymmetric() {
    return true;
  }

  /**
   * Creates a CorrelationMatrix object for a given number of entities.
   * All entries are initialized to 0.
   * @param numEntities number of entities
   */
  public CorrelationMatrix(int numEntities) {
    super(numEntities, 0.0F);
    this.numEntities = numEntities;
  }

  /**
   * Creates a correlation matrix.
   * Prints a helpful message to standard error if there is not enough memory, then rethrows the error.
   * @param numEntities the number of entities
   * @return the correlation matrix
   */
  public static CorrelationMatrix create(int numEntities) {
    try {
      return new CorrelationMatrix(numEntities);
    } catch (OutOfMemoryError e) {
      System.err.println("Too many entities: " + numEntities);
      throw e;
    }
  }

  /**
   * Creates a CorrelationMatrix from the lines of a BufferedReader.
   * The first line must contain the number of entities.
   * All the other lines have the format
   * <pre>
   * EntityID1 EntityID2 Correlation
   * </pre>
   * where EntityID1 and EntityID2 are non-negative integers smaller than the number of entities,
   * and Correlation is a floating point number. Fields may be separated by tabs, spaces or commas.
   * Blank lines are ignored. Diagonal entries are set to 1.
   * @param reader the reader to read from
   * @return the correlation matrix read from the input
   * @throws IOException if reading fails, the first line is missing, or a line is malformed or
   *         refers to an entity ID outside the valid range
   */
  public static CorrelationMatrix readCorrelationMatrix(BufferedReader reader) throws IOException {
    String header = reader.readLine();
    if (header == null)
      throw new IOException("Missing number of entities in first line");
    int numEntities = Integer.parseInt(header.trim());
    if (numEntities < 0)
      throw new IOException("Number of entities must not be negative: " + numEntities);

    CorrelationMatrix cm = create(numEntities);

    // Each entity is perfectly correlated with itself.
    for (int i = 0; i < numEntities; i++) cm.set(i, i, 1.0F);

    // TODO count number of non-zero entries and record in file?
    String line;
    int lineNumber = 1;
    while ((line = reader.readLine()) != null) {
      lineNumber++;
      String trimmed = line.trim();
      // Tolerate blank lines, e.g. an extra empty line at the end of the file.
      if (trimmed.isEmpty()) continue;

      String[] numbers = FIELD_SEPARATOR.split(trimmed);
      if (numbers.length < 3)
        throw new IOException("Expected 'EntityID1 EntityID2 Correlation' in line " + lineNumber + ": " + line);

      int i = Integer.parseInt(numbers[0]);
      int j = Integer.parseInt(numbers[1]);
      float c = Float.parseFloat(numbers[2]);

      if (i < 0) throw new IOException("Entity ID must not be negative: i = " + i);
      if (j < 0) throw new IOException("Entity ID must not be negative: j = " + j);
      if (i >= numEntities) throw new IOException("Entity ID is too big: i = " + i);
      if (j >= numEntities) throw new IOException("Entity ID is too big: j = " + j);

      cm.set(i, j, c);
    }
    return cm;
  }

  /**
   * Writes out the correlations in the format understood by {@link #readCorrelationMatrix}.
   * The first line is the number of entities. It is followed by one "i j value" line for each
   * non-zero entry above the diagonal (i &lt; j). Diagonal entries are not written;
   * readCorrelationMatrix restores them as 1.
   * @param writer the writer to write to (not flushed or closed here)
   */
  public void write(PrintWriter writer) {
    writer.println(numEntities);
    for (int i = 0; i < numEntities; i++) {
      for (int j = i + 1; j < numEntities; j++) {
        float val = get(i, j);
        if (val != 0f)
          writer.println(i + " " + j + " " + val);
      }
    }
  }

  /**
   * Adds an entity to the CorrelationMatrix by growing it to the required size.
   * Note that you still have to compute and set the entity's correlation values.
   * @param entity_id the numerical ID of the entity
   */
  public void addEntity(int entity_id) {
    this.grow(entity_id + 1, entity_id + 1);
    // Keep the logical size in sync with the grown matrix; every other method only
    // considers entities below numEntities.
    numEntities = Math.max(numEntities, entity_id + 1);
  }

  /**
   * Sums up the correlations between a given entity and the entities in a collection.
   * IDs in the collection that are outside [0, numEntities) are ignored.
   * @param entity_id the numerical ID of the entity
   * @param entities a collection containing the numerical IDs of the entities to compare to
   * @return the correlation sum
   * @throws IllegalArgumentException if entity_id is outside [0, numEntities)
   */
  public double sumUp(int entity_id, Collection<Integer> entities) {
    checkEntityId(entity_id);
    double result = 0;
    for (int entity_id2 : entities) {
      if (entity_id2 >= 0 && entity_id2 < numEntities)
        result += get(entity_id, entity_id2);
    }
    return result;
  }

  /**
   * Gets all entities that are positively correlated to an entity, sorted by correlation.
   * The most strongly correlated entity comes first. The entity itself is never included.
   * @param entity_id the entity ID
   * @return a list of all entities positively correlated to entity_id, in descending order of correlation
   * @throws IllegalArgumentException if entity_id is outside [0, numEntities)
   */
  public IntList getPositivelyCorrelatedEntities(int entity_id) {
    checkEntityId(entity_id);
    List<Neighbor> neighbors = neighborsByDescendingCorrelation(entity_id, true);
    IntList ids = new IntArrayList(neighbors.size());
    for (Neighbor neighbor : neighbors)
      ids.add(neighbor.id);
    return ids;
  }

  /**
   * Gets the k nearest neighbors of a given entity.
   * @param entity_id the numerical ID of the entity
   * @param k the neighborhood size
   * @return an array containing the numerical IDs of the (at most) k nearest neighbors,
   *         most strongly correlated first
   * @throws IllegalArgumentException if entity_id is outside [0, numEntities) or k is negative
   */
  public int[] getNearestNeighbors(int entity_id, int k) {
    checkEntityId(entity_id);
    if (k < 0)
      throw new IllegalArgumentException("Neighborhood size must not be negative: " + k);
    List<Neighbor> neighbors = neighborsByDescendingCorrelation(entity_id, false);
    int[] ids = new int[Math.min(k, neighbors.size())];
    for (int i = 0; i < ids.length; i++)
      ids[i] = neighbors.get(i).id;
    return ids;
  }

  /** Throws IllegalArgumentException unless 0 &lt;= entity_id &lt; numEntities. */
  private void checkEntityId(int entity_id) {
    if (entity_id < 0 || entity_id >= numEntities)
      throw new IllegalArgumentException("Invalid entity ID: " + entity_id);
  }

  /**
   * Collects all other entities together with their correlation to entity_id,
   * sorted by descending correlation.
   * @param entity_id the reference entity
   * @param positiveOnly if true, only entities with a correlation strictly greater than 0 are kept
   */
  private List<Neighbor> neighborsByDescendingCorrelation(int entity_id, boolean positiveOnly) {
    List<Neighbor> neighbors = new ArrayList<Neighbor>();
    for (int i = 0; i < numEntities; i++) {
      if (i == entity_id) continue;
      float value = get(i, entity_id);
      // !(value > 0) also drops NaN correlations.
      if (positiveOnly && !(value > 0)) continue;
      neighbors.add(new Neighbor(i, value));
    }
    // Stable ascending sort followed by a reversal: highest correlation first;
    // among equal correlations the higher ID comes first.
    Collections.sort(neighbors);
    Collections.reverse(neighbors);
    return neighbors;
  }

  /** An entity ID paired with its correlation value; ordered by correlation. */
  static final class Neighbor implements Comparable<Neighbor> {
    final int id;
    final float value;

    Neighbor(int id, float value) {
      this.id = id;
      this.value = value;
    }

    @Override
    public int compareTo(Neighbor o) {
      return Float.compare(value, o.value);
    }
  }
}