package org.mymedialite.ratingprediction;
//Copyright (C) 2010 Zeno Gantner, Steffen Rendle
//Copyright (C) 2011 Zeno Gantner, Chris Newell
//
//This file is part of MyMediaLite.
//
//MyMediaLite is free software: you can redistribute it and/or modify
//it under the terms of the GNU General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//MyMediaLite is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU General Public License for more details.
//
//You should have received a copy of the GNU General Public License
//along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import org.mymedialite.io.Model;
import org.mymedialite.io.VectorExtensions;
import org.mymedialite.util.Recommender;
/**
 * Abstract base class for rating predictors that predict the average rating of an entity
 * (which kind of entity the averages are computed over is decided by the subclass).
 *
 * <p>Entities without any training rating, and entity indices beyond the known range, fall back
 * to the global average rating. Subclasses can call {@link #retrain(int, IntList, IntList)} to
 * refresh the average of a single entity after its ratings change.
 *
 * @version 2.03
 */
public abstract class EntityAverage extends IncrementalRatingPredictor {

  private static final String VERSION = "2.03";

  /** The average rating for each entity, indexed by entity ID. */
  protected DoubleList entity_averages = new DoubleArrayList();

  /** The global average rating (default prediction if there is no data about an entity). */
  protected double global_average = 0;

  /**
   * Returns the average rating for a given entity.
   *
   * @param index the entity index
   * @return the entity's average rating, or the global average if {@code index} is beyond the
   *     entities currently stored
   */
  public double get(int index) {
    if (index < entity_averages.size()) {
      return entity_averages.getDouble(index);
    }
    return global_average;
  }

  /**
   * Trains the recommender according to the given entity type.
   *
   * <p>For every rating {@code i}, {@code entity_ids.getInt(i)} names the entity it is attributed
   * to. Entities in {@code [0, max_entity_id]} that have no rating get the global average.
   *
   * @param entity_ids the entity ID of each rating in the training data (one entry per rating)
   * @param max_entity_id the maximum entity ID
   */
  protected void train(IntList entity_ids, int max_entity_id) {
    int num_entities = Math.max(0, max_entity_id + 1);
    int[] rating_counts = new int[num_entities];
    // Holds per-entity rating sums first, then (after the second loop) the averages.
    double[] averages = new double[num_entities];

    for (int i = 0; i < ratings.size(); i++) {
      int entity_id = entity_ids.getInt(i);
      rating_counts[entity_id]++;
      averages[entity_id] += ratings.get(i);
    }

    global_average = ratings.average();
    for (int e = 0; e < num_entities; e++) {
      averages[e] = rating_counts[e] != 0 ? averages[e] / rating_counts[e] : global_average;
    }

    // Publish only once fully computed so a failure above leaves the previous model intact.
    entity_averages = new DoubleArrayList(averages);
  }

  /**
   * Retrains the recommender according to the given entity type.
   *
   * <p>Recomputes the average of a single entity from the ratings at {@code indices}. If the
   * entity lies beyond the currently stored averages (e.g. it was added after training), the list
   * is extended; the new slots get the global average, which is what {@link #get(int)} already
   * reported for them.
   *
   * @param entity_id the ID of the entity to update
   * @param indices list of rating indices to use for retraining
   * @param entity_ids list of all entity IDs in the training data (per rating); not used by this
   *     implementation, kept for interface compatibility
   */
  protected void retrain(int entity_id, IntList indices, IntList entity_ids) {
    while (entity_averages.size() <= entity_id) {
      entity_averages.add(global_average);
    }

    int count = indices.size();
    if (count == 0) {
      entity_averages.set(entity_id, global_average);
      return;
    }

    double sum = 0;
    for (int k = 0; k < count; k++) {
      sum += ratings.get(indices.getInt(k));
    }
    entity_averages.set(entity_id, sum / count);
  }

  /**
   * Saves the model to a file, closing the file afterwards even if writing fails.
   *
   * @param filename the file to write
   * @throws IOException if the file cannot be written
   */
  @Override
  public void saveModel(String filename) throws IOException {
    PrintWriter writer = Model.getWriter(filename, this.getClass(), VERSION);
    try {
      saveModel(writer);
      // PrintWriter swallows I/O errors; report them instead of leaving a truncated model file.
      if (writer.checkError()) {
        throw new IOException("Error writing model to " + filename);
      }
    } finally {
      writer.close();
    }
  }

  /**
   * Writes the global average followed by the per-entity averages to the given writer.
   *
   * <p>The writer is flushed but not closed: it belongs to the caller, who may write further
   * data to it.
   *
   * @param writer the destination
   * @throws IOException if writing fails
   */
  @Override
  public void saveModel(PrintWriter writer) throws IOException {
    writer.println(global_average);
    VectorExtensions.writeVector(writer, entity_averages);
    writer.flush();
  }

  /**
   * Loads the model from a file, closing the file afterwards even if reading fails.
   *
   * @param filename the file to read
   * @throws IOException if the file cannot be read or is incomplete
   */
  @Override
  public void loadModel(String filename) throws IOException {
    BufferedReader reader = Model.getReader(filename, this.getClass());
    try {
      loadModel(reader);
    } finally {
      reader.close();
    }
  }

  /**
   * Reads the global average followed by the per-entity averages from the given reader.
   *
   * <p>The reader is not closed: it belongs to the caller. The model fields are replaced only
   * after both values were read successfully.
   *
   * @param reader the source
   * @throws IOException if reading fails or the data ends before the global average
   */
  @Override
  public void loadModel(BufferedReader reader) throws IOException {
    String line = reader.readLine();
    if (line == null) {
      throw new IOException("Unexpected end of model data: global average is missing");
    }
    double average = Double.parseDouble(line);
    DoubleList averages = VectorExtensions.readVector(reader);

    this.global_average = average;
    this.entity_averages = averages;
  }
}