package org.wikibrain.sr.vector;
import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.matrix.MatrixRow;
import org.wikibrain.matrix.SparseMatrix;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SimUtils;
import org.wikibrain.utils.WpIOUtils;
import java.io.File;
import java.io.IOException;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link VectorSimilarity} that scores vectors by the cosine of the angle between them.
 *
 * <p>{@link #setMatrices} computes (or loads from cache files in the supplied data directory)
 * the norm of every row of the feature matrix, the row ids of the transpose matrix, and the
 * largest number of columns found in any feature row.</p>
 *
 * @author Shilad Sen
 */
public class CosineSimilarity implements VectorSimilarity {
    private static final Logger LOG = LoggerFactory.getLogger(CosineSimilarity.class);

    /** Norm of each feature row, keyed by row index. */
    private TIntFloatHashMap lengths = new TIntFloatHashMap();

    /** Row ids of the transpose matrix. */
    private TIntSet idsInResults = new TIntHashSet();

    /** Largest number of columns in any feature row; -1 until matrices are set. */
    private int maxResults = -1;

    private SparseMatrix features;
    private SparseMatrix transpose;

    /**
     * Installs the feature matrix and its transpose, then loads or rebuilds the derived
     * information (row norms, transpose row ids, maximum row width).
     *
     * <p>The cached copies are used only when all three cache files exist and the length and
     * id caches are at least as recent as the matrices they were derived from. Otherwise the
     * information is recomputed and the cache files are rewritten.</p>
     *
     * @param features  matrix whose rows are the feature vectors
     * @param transpose transpose of {@code features}, used as an inverted index
     * @param dataDir   directory holding the cache files
     * @throws IOException if reading or writing a cache file fails
     */
    @Override
    public synchronized void setMatrices(SparseMatrix features, SparseMatrix transpose, File dataDir) throws IOException {
        this.features = features;
        this.transpose = transpose;

        File idCacheFile = new File(dataDir, "cosineSimilarity-ids.bin");
        File lengthCacheFile = new File(dataDir, "cosineSimilarity-lengths.bin");
        File maxCacheFile = new File(dataDir, "cosineSimilarity-maxResults.bin");

        // All three files are read below, so all three must exist; a missing max-results
        // file (e.g. from an interrupted write) falls through to a rebuild instead of failing.
        boolean cacheUsable =
                lengthCacheFile.exists() && lengthCacheFile.lastModified() >= features.lastModified()
                && idCacheFile.exists() && idCacheFile.lastModified() >= transpose.lastModified()
                && maxCacheFile.exists();

        if (cacheUsable) {
            LOG.info("reading matrix information from cache");
            lengths = (TIntFloatHashMap) WpIOUtils.readObjectFromFile(lengthCacheFile);
            idsInResults = (TIntSet) WpIOUtils.readObjectFromFile(idCacheFile);
            maxResults = (Integer) WpIOUtils.readObjectFromFile(maxCacheFile);
        } else {
            LOG.info("building cached matrix information");
            lengths.clear();
            idsInResults.clear();
            maxResults = 0;
            for (SparseMatrixRow row : features) {
                lengths.put(row.getRowIndex(), (float) row.getNorm());
                maxResults = Math.max(maxResults, row.getNumCols());
            }
            idsInResults.addAll(transpose.getRowIds());
            WpIOUtils.writeObjectToFile(lengthCacheFile, lengths);
            WpIOUtils.writeObjectToFile(idCacheFile, idsInResults);
            WpIOUtils.writeObjectToFile(maxCacheFile, maxResults);
        }
    }

    /** Returns the cosine similarity of two matrix rows, as computed by {@link SimUtils}. */
    @Override
    public double similarity(MatrixRow a, MatrixRow b) {
        return SimUtils.cosineSimilarity(a, b);
    }

    /** Returns the cosine similarity of two sparse vectors, as computed by {@link SimUtils}. */
    @Override
    public double similarity(TIntFloatMap vector1, TIntFloatMap vector2) {
        return SimUtils.cosineSimilarity(vector1, vector2);
    }

    /**
     * Finds the rows most similar to {@code query}.
     *
     * <p>When a candidate set with fewer than 10000 ids is supplied, each candidate row is
     * scored directly; otherwise the transpose matrix is used as an inverted index.</p>
     *
     * @param query      sparse query vector keyed by column id
     * @param maxResults maximum number of results to return
     * @param validIds   ids eligible to appear in the result, or {@code null} for no restriction
     * @return the top-scoring ids
     * @throws IOException if a matrix row cannot be read
     */
    @Override
    public SRResultList mostSimilar(TIntFloatMap query, int maxResults, TIntSet validIds) throws IOException {
        if (validIds != null && validIds.size() < 10000) {
            return mostSimilarWithRegularIndex(query, maxResults, validIds);
        } else {
            return mostSimilarWithInvertedIndex(query, maxResults, validIds);
        }
    }

    /**
     * Scores every id in {@code validIds} that has a feature row by walking that row and
     * accumulating its dot product with the query.
     */
    private SRResultList mostSimilarWithRegularIndex(TIntFloatMap query, int maxResults, TIntSet validIds) throws IOException {
        final Leaderboard leaderboard = new Leaderboard(maxResults);
        final double queryNorm = norm(query);
        for (int id : validIds.toArray()) {
            MatrixRow row = features.getRow(id);
            if (row == null) {
                continue;
            }
            double dot = 0.0;
            int n = row.getNumCols();
            for (int i = 0; i < n; i++) {
                float queryVal = query.get(row.getColIndex(i));
                // Skip columns that contribute nothing. Negative weights must be kept so this
                // path agrees with the inverted-index path. Assumes the query map's no-entry
                // value is 0 (the Trove default), so absent keys are skipped here too.
                if (queryVal != 0) {
                    // A dot product multiplies matching components.
                    dot += queryVal * row.getColValue(i);
                }
            }
            double rowNorm = lengths.get(id);
            leaderboard.tallyScore(id, dot / (rowNorm * queryNorm));
        }
        return leaderboard.getTop();
    }

    /**
     * Scores candidates by iterating over the query's keys and, for each, the matching row of
     * the transpose matrix, accumulating partial dot products per candidate id.
     */
    private SRResultList mostSimilarWithInvertedIndex(TIntFloatMap query, int maxResults, TIntSet validIds) throws IOException {
        // Note: maxResults here is the method parameter, which shadows the field of the same name.
        TIntDoubleHashMap dots = new TIntDoubleHashMap(Math.max(100000, maxResults * 5));

        // Eschew a for-each loop here for performance reasons.
        int[] keys = query.keys();
        for (int i = 0; i < keys.length; i++) {
            int id = keys[i];
            float val1 = query.get(id);
            MatrixRow row2 = transpose.getRow(id);
            if (row2 != null) {
                int n = row2.getNumCols();
                for (int j = 0; j < n; j++) {
                    int id2 = row2.getColIndex(j);
                    if (validIds == null || validIds.contains(id2)) {
                        float val2 = row2.getColValue(j);
                        dots.adjustOrPutValue(id2, val1 * val2, val1 * val2);
                    }
                }
            }
        }

        final Leaderboard leaderboard = new Leaderboard(maxResults);
        final double queryNorm = norm(query);
        keys = dots.keys();
        for (int i = 0; i < keys.length; i++) {
            int id = keys[i];
            double rowNorm = lengths.get(id);
            double dot = dots.get(id);
            leaderboard.tallyScore(id, dot / (rowNorm * queryNorm));
        }
        return leaderboard.getTop();
    }

    /** Lower bound reported for similarity scores. */
    @Override
    public double getMinValue() {
        return -1.0;
    }

    /** Upper bound reported for similarity scores. */
    @Override
    public double getMaxValue() {
        return 1.0;
    }

    /** Returns the Euclidean (L2) norm of the values in {@code vector}. */
    private double norm(TIntFloatMap vector) {
        double length = 0;
        for (float x : vector.values()) {
            length += x * x;
        }
        return Math.sqrt(length);
    }

    /**
     * Configuration provider that creates a {@link CosineSimilarity} for entries under
     * {@code sr.metric.similarity} whose {@code type} is {@code cosine}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<VectorSimilarity> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return VectorSimilarity.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.similarity";
        }

        /**
         * @return a new {@link CosineSimilarity}, or {@code null} when the configured type is
         *         not {@code cosine}
         */
        @Override
        public VectorSimilarity get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("cosine")) {
                return null;
            }
            return new CosineSimilarity();
        }
    }
}