//Copyright (C) 2011 Zeno Gantner, Chris Newell
//
// This file is part of MyMediaLite.
//
// MyMediaLite is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MyMediaLite is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>.
package org.mymedialite.data;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.management.openmbean.InvalidKeyException;
import org.mymedialite.util.Random;
/**
 * Abstract dataset class that implements some common functions.
 *
 * <p>Entries are stored in two parallel lists: position {@code i} of {@link #users} and
 * position {@code i} of {@link #items} together form entry {@code i}. The number of entries
 * reported by {@link #size()} is the length of {@link #users}.
 *
 * <p>Nothing in this class adds entries or updates {@link #maxUserID} / {@link #maxItemID};
 * presumably subclasses maintain them (confirm against the concrete implementations).
 *
 * @version 2.03
 */
public abstract class DataSet implements IDataSet {

  /** User ID of every entry; parallel to {@link #items}. */
  protected IntList users = new IntArrayList();

  /** Item ID of every entry; parallel to {@link #users}. */
  protected IntList items = new IntArrayList();

  /** Value reported by {@link #maxUserID()}; -1 until something assigns it. */
  protected int maxUserID = -1;

  /** Value reported by {@link #maxItemID()}; -1 until something assigns it. */
  protected int maxItemID = -1;

  /** Rating indices organized by user */
  protected List<IntList> byUser;

  /** Rating indices organized by item */
  protected List<IntList> byItem;

  /** Shuffled list of entry indices; created on demand by {@link #buildRandomIndex()}. */
  private IntList randomIndex;

  @Override
  public IntList users() {
    return users;
  }

  @Override
  public IntList items() {
    return items;
  }

  /** @return the number of entries, i.e. the length of {@link #users} */
  @Override
  public int size() {
    return users.size();
  }

  @Override
  public int maxUserID() {
    return maxUserID;
  }

  @Override
  public int maxItemID() {
    return maxItemID;
  }

  /**
   * Returns the per-user index, building it first if it does not exist yet.
   * The index is only built when the field is {@code null}; this method does not
   * detect entries added after the index was built.
   */
  @Override
  public List<IntList> byUser() {
    if (byUser == null) {
      buildUserIndices();
    }
    return byUser;
  }

  /**
   * Returns the per-item index, building it first if it does not exist yet.
   * The index is only built when the field is {@code null}; this method does not
   * detect entries added after the index was built.
   */
  @Override
  public List<IntList> byItem() {
    if (byItem == null) {
      buildItemIndices();
    }
    return byItem;
  }

  /**
   * Returns the shuffled index list, (re)building it when it is missing or when its
   * length no longer matches {@link #size()}.
   */
  @Override
  public IntList randomIndex() {
    if (randomIndex == null || randomIndex.size() != size()) {
      buildRandomIndex();
    }
    return randomIndex;
  }

  /** @return a new list holding each user ID that occurs in {@link #users} exactly once */
  @Override
  public IntList allUsers() {
    return distinctValues(users);
  }

  /** @return a new list holding each item ID that occurs in {@link #items} exactly once */
  @Override
  public IntList allItems() {
    return distinctValues(items);
  }

  /**
   * Rebuilds {@link #byUser}: one list per user ID in {@code 0..maxUserID}, each holding the
   * indices of that user's entries.
   */
  @Override
  public void buildUserIndices() {
    byUser = groupIndicesBy(users, maxUserID);
  }

  /**
   * Rebuilds {@link #byItem}: one list per item ID in {@code 0..maxItemID}, each holding the
   * indices of that item's entries.
   */
  @Override
  public void buildItemIndices() {
    byItem = groupIndicesBy(items, maxItemID);
  }

  /**
   * Shuffles the random index using {@code Random.getInstance()}. If the index is missing or
   * has the wrong length it is first recreated as {@code 0, 1, ..., size() - 1}; otherwise the
   * existing permutation is shuffled again.
   */
  @Override
  public void buildRandomIndex() {
    if (randomIndex == null || randomIndex.size() != size()) {
      randomIndex = new IntArrayList(size());
      for (int index = 0; index < size(); index++) {
        randomIndex.add(index);
      }
    }
    Collections.shuffle(randomIndex, Random.getInstance());
  }

  /**
   * Collects the user IDs of the given entries.
   *
   * @param indices entry indices to look up
   * @return the set of user IDs found at those indices; iteration order is unspecified
   */
  @Override
  public IntSet getUsers(IntList indices) {
    return collectValues(users, indices);
  }

  /**
   * Collects the item IDs of the given entries.
   *
   * @param indices entry indices to look up
   * @return the set of item IDs found at those indices; iteration order is unspecified
   */
  @Override
  public IntSet getItems(IntList indices) {
    return collectValues(items, indices);
  }

  /**
   * Finds the first entry with the given user and item by scanning all entries.
   *
   * @param user_id the user ID to look for
   * @param item_id the item ID to look for
   * @return the index of the first matching entry
   * @throws InvalidKeyException if no entry matches
   */
  @Override
  public int getIndex(int user_id, int item_id) {
    int found = findIndex(user_id, item_id);
    if (found < 0) {
      throw new InvalidKeyException(notFoundMessage(user_id, item_id));
    }
    return found;
  }

  /**
   * Finds the first entry with the given user and item among the given candidate indices.
   *
   * @param user_id the user ID to look for
   * @param item_id the item ID to look for
   * @param indexes the entry indices to search, in iteration order
   * @return the first candidate index whose entry matches
   * @throws InvalidKeyException if no candidate matches
   */
  @Override
  public int getIndex(int user_id, int item_id, IntCollection indexes) {
    int found = findIndex(user_id, item_id, indexes);
    if (found < 0) {
      throw new InvalidKeyException(notFoundMessage(user_id, item_id));
    }
    return found;
  }

  /**
   * Like {@link #getIndex(int, int)}, but reports a missing entry with {@code null}
   * instead of an exception.
   *
   * @param user_id the user ID to look for
   * @param item_id the item ID to look for
   * @return the index of the first matching entry, or {@code null} if there is none
   */
  @Override
  public Integer tryGetIndex(int user_id, int item_id) {
    int found = findIndex(user_id, item_id);
    if (found < 0) {
      return null;
    }
    return Integer.valueOf(found);
  }

  /**
   * Like {@link #getIndex(int, int, IntCollection)}, but reports a missing entry with
   * {@code null} instead of an exception.
   *
   * @param user_id the user ID to look for
   * @param item_id the item ID to look for
   * @param indexes the entry indices to search, in iteration order
   * @return the first candidate index whose entry matches, or {@code null} if there is none
   */
  @Override
  public Integer tryGetIndex(int user_id, int item_id, IntCollection indexes) {
    int found = findIndex(user_id, item_id, indexes);
    if (found < 0) {
      return null;
    }
    return Integer.valueOf(found);
  }

  /** Returns the distinct values of {@code values} as a new list. */
  private static IntList distinctValues(IntList values) {
    IntSet seen = new IntOpenHashSet();
    for (int pos = 0; pos < values.size(); pos++) {
      seen.add(values.getInt(pos));
    }
    return new IntArrayList(seen);
  }

  /**
   * Builds one index list per ID in {@code 0..maxID} and files every entry index under the
   * ID found at that position of {@code ids}.
   */
  private List<IntList> groupIndicesBy(IntList ids, int maxID) {
    List<IntList> grouped = new ArrayList<IntList>();
    for (int id = 0; id <= maxID; id++) {
      grouped.add(new IntArrayList());
    }
    // One pass over the data
    for (int index = 0; index < size(); index++) {
      grouped.get(ids.getInt(index)).add(index);
    }
    return grouped;
  }

  /** Returns the set of values of {@code ids} found at the given entry indices. */
  private static IntSet collectValues(IntList ids, IntList indices) {
    // Hash-based set: an array-backed set scans linearly on every add, which makes
    // collecting over many indices quadratic.
    IntSet collected = new IntOpenHashSet();
    for (int index : indices) {
      collected.add(ids.getInt(index));
    }
    return collected;
  }

  /** Scans all entries; returns the first matching index, or -1 if none matches. */
  private int findIndex(int user_id, int item_id) {
    for (int i = 0; i < size(); i++) {
      if (matches(i, user_id, item_id)) {
        return i;
      }
    }
    return -1;
  }

  /** Scans the candidate indices; returns the first matching one, or -1 if none matches. */
  private int findIndex(int user_id, int item_id, IntCollection indexes) {
    for (int i : indexes) {
      if (matches(i, user_id, item_id)) {
        return i;
      }
    }
    return -1;
  }

  /** Tells whether entry {@code index} has the given user and item. */
  private boolean matches(int index, int user_id, int item_id) {
    return users.getInt(index) == user_id && items.getInt(index) == item_id;
  }

  /** Builds the message used when a (user, item) pair cannot be found. */
  private static String notFoundMessage(int user_id, int item_id) {
    return "index " + user_id + ", " + item_id + " not found.";
  }
}