package org.wikibrain.sr.phrasesim;
import gnu.trove.list.TIntList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import java.io.Serializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A cosimilarity matrix that is dense, but can be expanded.
 *
 * <p>Callers address entries by arbitrary ("sparse") integer ids. Internally each id is
 * assigned the next free ("dense") row/column index in a square {@code float} matrix. When
 * the matrix is full it is reallocated with extra capacity and the existing contents are
 * copied over.</p>
 *
 * <p>Mutation ({@link #update}) and id resolution in {@link #mostSimilar},
 * {@link #cosimilarity} and {@link #getCompleted} are guarded by this object's monitor.
 * {@link #getVector} and {@link #similarity} take no lock.</p>
 *
 * <p>NOTE(review): the class is {@link Serializable} and declares no
 * {@code serialVersionUID}, so field declarations and public method signatures are kept
 * exactly as they were to avoid invalidating previously serialized instances.</p>
 *
 * @author Shilad Sen
 */
public class CosimilarityMatrix implements Serializable {
    /** Growth factor applied to the current capacity (minimum base of 500) on expansion. */
    private static final double EXPANSION_FRACTION = 1.3;
    private static final Logger LOGGER = LoggerFactory.getLogger(CosimilarityMatrix.class);

    /** Square matrix of scores indexed by dense ids; may be larger than {@link #size()}. */
    private float[][] matrix = new float[0][0];
    /** Maps caller-visible (sparse) ids to dense row/column indexes. */
    private TIntIntMap sparse2Dense = new TIntIntHashMap();
    /** Inverse of {@link #sparse2Dense}; only the first {@link #size()} entries are meaningful. */
    private int[] dense2Sparse = new int[0];
    /** {@code completed[d]} is true once {@link #update} has finished for dense id {@code d}. */
    private boolean [] completed = new boolean[0];

    /**
     * Records the scores between {@code sparseId} and each neighbor that is already known to
     * this matrix, registering {@code sparseId} first if it has not been seen before.
     * Scores are written symmetrically. Neighbors whose ids are not registered are skipped.
     *
     * @param sparseId  id whose row is being updated
     * @param neighbors scored neighbors of {@code sparseId}
     */
    public synchronized void update(int sparseId, SRResultList neighbors) {
        int denseId;
        if (sparse2Dense.containsKey(sparseId)) {
            denseId = sparse2Dense.get(sparseId);
        } else {
            expandIfNecessary();
            denseId = sparse2Dense.size();
            sparse2Dense.put(sparseId, denseId);
            dense2Sparse[denseId] = sparseId;
        }
        for (SRResult r : neighbors) {
            if (!sparse2Dense.containsKey(r.getId())) {
                continue;
            }
            int denseId2 = sparse2Dense.get(r.getId());
            float score = (float) r.getScore();
            matrix[denseId][denseId2] = score;
            matrix[denseId2][denseId] = score;
        }
        completed[denseId] = true;
    }

    /**
     * Returns the internal matrix row for {@code id}, without copying and without locking.
     * The returned array is indexed by dense id and its length is the current capacity,
     * which may exceed {@link #size()}.
     *
     * @param id sparse id to look up
     * @return the live row for {@code id}, or {@code null} if the id is not registered
     */
    public float[] getVector(int id) {
        if (sparse2Dense.containsKey(id)) {
            return matrix[sparse2Dense.get(id)];
        } else {
            return null;
        }
    }

    /**
     * @return the number of ids registered in this matrix
     */
    public synchronized int size() {
        return sparse2Dense.size();
    }

    /**
     * Finds the registered ids with the highest scores in the row of {@code id}.
     *
     * @param id           sparse id whose row is searched
     * @param maxResults   capacity handed to the {@link Leaderboard} collecting results
     * @param candidateIds if non-null, only these sparse ids are considered; candidates that
     *                     are not registered in the matrix are ignored
     * @return the top results with ids translated back to sparse ids; an empty leaderboard
     *         result if {@code id} itself is not registered
     */
    public SRResultList mostSimilar(int id, int maxResults, TIntSet candidateIds) {
        final int n;
        final float[] row;
        final int[] denseToSparse;
        boolean[] denseCandidateIds = null;
        synchronized (this) {
            if (!sparse2Dense.containsKey(id)) {
                // Without this check get() yields trove's no-entry value and the caller
                // would silently receive the neighbors of an unrelated row.
                return new Leaderboard(maxResults).getTop();
            }
            n = sparse2Dense.size();
            // Snapshot the array references while holding the lock so a concurrent
            // expansion cannot swap them underneath the scan below.
            row = matrix[sparse2Dense.get(id)];
            denseToSparse = dense2Sparse;
            if (candidateIds != null) {
                // Indexed by dense id, so it must span every registered id rather than
                // the number of candidates.
                denseCandidateIds = new boolean[n];
                final boolean[] finalDenseCandidateIds = denseCandidateIds;
                candidateIds.forEach(new TIntProcedure() {
                    @Override
                    public boolean execute(int id2) {
                        if (sparse2Dense.containsKey(id2)) {
                            finalDenseCandidateIds[sparse2Dense.get(id2)] = true;
                        }
                        return true;
                    }
                });
            }
        }
        Leaderboard top = new Leaderboard(maxResults);
        for (int i = 0; i < n; i++) {
            if (denseCandidateIds == null || denseCandidateIds[i]) {
                top.tallyScore(i, row[i]);
            }
        }
        SRResultList results = top.getTop();
        for (int i = 0; i < results.numDocs(); i++) {
            results.setId(i, denseToSparse[results.getId(i)]);
        }
        return results;
    }

    /**
     * Extracts the sub-matrix of scores for the given row and column ids.
     *
     * @param rows    sparse ids for the rows of the result
     * @param columns sparse ids for the columns of the result
     * @return a {@code rows.length x columns.length} array; cells whose row or column id is
     *         not registered are left at {@code 0.0}
     */
    public double[][] cosimilarity(int rows[], int columns[]) {
        double[][] cosims = new double[rows.length][columns.length];
        final int[] denseRowIds;
        final int[] denseColIds;
        final float[][] snapshot;
        synchronized (this) {
            denseRowIds = toDenseIds(rows);
            denseColIds = toDenseIds(columns);
            // Capture the matrix reference that matches the ids resolved above.
            snapshot = matrix;
        }
        for (int i = 0; i < denseRowIds.length; i++) {
            if (denseRowIds[i] < 0) {
                continue;
            }
            float[] row = snapshot[denseRowIds[i]];
            for (int j = 0; j < denseColIds.length; j++) {
                if (denseColIds[j] >= 0) {
                    cosims[i][j] = row[denseColIds[j]];
                }
            }
        }
        return cosims;
    }

    /** Translates sparse ids to dense ids, using -1 for unregistered ids. Caller holds the lock. */
    private int[] toDenseIds(int[] sparseIds) {
        int[] denseIds = new int[sparseIds.length];
        for (int i = 0; i < sparseIds.length; i++) {
            int sparseId = sparseIds[i];
            denseIds[i] = sparse2Dense.containsKey(sparseId) ? sparse2Dense.get(sparseId) : -1;
        }
        return denseIds;
    }

    /**
     * Grows the matrix and its bookkeeping arrays when every dense slot is in use.
     * Existing scores, id mappings and completion flags are carried over.
     *
     * @throws IllegalStateException if the internal arrays have inconsistent sizes
     */
    private synchronized void expandIfNecessary() {
        if (sparse2Dense.size() < dense2Sparse.length) {
            return;
        }
        if (dense2Sparse.length != sparse2Dense.size()) {
            throw new IllegalStateException(
                    "id map has " + sparse2Dense.size() + " entries but capacity is " + dense2Sparse.length);
        }
        if (dense2Sparse.length != matrix.length) {
            throw new IllegalStateException(
                    "matrix has " + matrix.length + " rows but capacity is " + dense2Sparse.length);
        }
        if (dense2Sparse.length > 0 && dense2Sparse.length != matrix[0].length) {
            throw new IllegalStateException(
                    "matrix has " + matrix[0].length + " columns but capacity is " + dense2Sparse.length);
        }
        int oldn = sparse2Dense.size();
        final int n = (int) (Math.max(500, oldn) * EXPANSION_FRACTION);
        LOGGER.info("expanding cosimilarity matrix to length {}", n);
        float[][] newCosims = new float[n][n];
        for (int i = 0; i < oldn; i++) {
            System.arraycopy(matrix[i], 0, newCosims[i], 0, oldn);
        }
        int[] newDense2Sparse = new int[n];
        boolean[] newCompleted = new boolean[n];
        System.arraycopy(dense2Sparse, 0, newDense2Sparse, 0, oldn);
        // Copy from the OLD flags; copying newCompleted onto itself would reset them all.
        System.arraycopy(completed, 0, newCompleted, 0, oldn);
        matrix = newCosims;
        dense2Sparse = newDense2Sparse;
        completed = newCompleted;
        LOGGER.info("finished expanding cosimilarity matrix");
    }

    /**
     * Looks up a single score without taking the lock.
     *
     * @param id1 first sparse id
     * @param id2 second sparse id
     * @return the stored score, or {@code 0} if either id is not registered
     */
    public float similarity(int id1, int id2) {
        if (sparse2Dense.containsKey(id1) && sparse2Dense.containsKey(id2)) {
            return matrix[sparse2Dense.get(id1)][sparse2Dense.get(id2)];
        } else {
            return 0f;
        }
    }

    /**
     * Returns all ids whose "update" method has successfully completed.
     *
     * @return a newly allocated set of sparse ids
     */
    public TIntSet getCompleted() {
        TIntSet result = new TIntHashSet();
        // Locked so the flag and id arrays are read as a consistent pair even while an
        // update is expanding them.
        synchronized (this) {
            for (int i = 0; i < completed.length; i++) {
                if (completed[i]) {
                    result.add(dense2Sparse[i]);
                }
            }
        }
        return result;
    }
}