package org.wikibrain.sr.vector;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import org.junit.Test;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.matrix.ValueConf;
import org.wikibrain.sr.utils.SimUtils;
import static org.junit.Assert.assertEquals;
/**
 * Compares the project's Google similarity implementations against a small
 * reference formula defined in this test ({@link #googleSimilarity}).
 *
 * Two fixed sparse vectors are used throughout: row 1 has six ids, row 2 has
 * five, and they have three ids in common (7, 5 and 2).
 *
 * @author Shilad Sen
 */
public class TestGoogleSimilarity {
    /** Value passed as the page count to every implementation under test. */
    private static final int NUM_PAGES = 100;

    /** Absolute tolerance used for all floating point comparisons below. */
    private static final double TOLERANCE = 0.0001;

    // Fixture vectors: ids and their parallel values.
    static final int [] ROW1_IDS = { 7, 9, 12, 5, 6, 2 };
    static final float [] ROW1_VALS = { 0.3f, 0.5f, 0.2f, 0.7f, 0.8f, 0.1f };
    static final int [] ROW2_IDS = { 5, 3, 2, 4, 7 };
    static final float [] ROW2_VALS = { 0.8f, 0.1f, 0.2f, 0.4f, 0.5f };

    /**
     * Checks SimUtils.googleSimilarity against the reference formula.
     */
    @Test
    public void testUtils() {
        TIntFloatMap first = getMap(ROW1_IDS, ROW1_VALS);
        TIntFloatMap second = getMap(ROW2_IDS, ROW2_VALS);
        double reference = googleSimilarity(first, second);
        // 6 and 5 are the fixture sizes, 3 is the number of ids they share.
        // NOTE(review): presumably the arguments are (sizeA, sizeB, intersection, numPages) - confirm against SimUtils.
        double fromUtils = SimUtils.googleSimilarity(6, 5, 3, NUM_PAGES);
        assertEquals(reference, fromUtils, TOLERANCE);
    }

    /**
     * Checks the map-based GoogleSimilarity.similarity overload with the
     * arguments in both orders; each order must match the reference value.
     */
    @Test
    public void testMap() {
        TIntFloatMap first = getMap(ROW1_IDS, ROW1_VALS);
        TIntFloatMap second = getMap(ROW2_IDS, ROW2_VALS);
        double reference = googleSimilarity(first, second);

        double forward = new GoogleSimilarity(NUM_PAGES).similarity(first, second);
        assertEquals(reference, forward, TOLERANCE);

        double backward = new GoogleSimilarity(NUM_PAGES).similarity(second, first);
        assertEquals(reference, backward, TOLERANCE);
    }

    /**
     * Checks the SparseMatrixRow-based GoogleSimilarity.similarity overload
     * with the arguments in both orders; the reference value is computed from
     * map versions of the same fixtures.
     */
    @Test
    public void testRows() {
        TIntFloatMap firstMap = getMap(ROW1_IDS, ROW1_VALS);
        TIntFloatMap secondMap = getMap(ROW2_IDS, ROW2_VALS);
        SparseMatrixRow firstRow = getRow(ROW1_IDS, ROW1_VALS);
        SparseMatrixRow secondRow = getRow(ROW2_IDS, ROW2_VALS);
        double reference = googleSimilarity(firstMap, secondMap);

        double forward = new GoogleSimilarity(NUM_PAGES).similarity(firstRow, secondRow);
        assertEquals(reference, forward, TOLERANCE);

        double backward = new GoogleSimilarity(NUM_PAGES).similarity(secondRow, firstRow);
        assertEquals(reference, backward, TOLERANCE);
    }

    /**
     * Reference implementation used as the expected value in every test.
     *
     * Computes
     * 1 - (log(max(na, nb)) - log(overlap)) / (log(NUM_PAGES) - log(min(na, nb)))
     * where na and nb are the entry counts of the two maps and overlap is the
     * number of keys present in both. Stored values are ignored; only key
     * membership matters.
     *
     * @param row1 first sparse vector
     * @param row2 second sparse vector
     * @return the reference similarity score
     */
    private double googleSimilarity(TIntFloatMap row1, TIntFloatMap row2) {
        int na = row1.size();
        int nb = row2.size();

        // Count the keys of row1 that also occur in row2.
        int[] keys = row1.keys();
        int overlap = 0;
        for (int k = 0; k < keys.length; k++) {
            overlap += row2.containsKey(keys[k]) ? 1 : 0;
        }

        double numerator = Math.log(Math.max(na, nb)) - Math.log(overlap);
        double denominator = Math.log(NUM_PAGES) - Math.log(Math.min(na, nb));
        return 1.0 - numerator / denominator;
    }

    /**
     * Wraps parallel id/value arrays in a SparseMatrixRow, after asserting
     * that the two arrays have the same length.
     *
     * @param ids  column ids
     * @param vals values, parallel to ids
     * @return a new row built with a default ValueConf
     */
    private SparseMatrixRow getRow(int [] ids, float [] vals) {
        assertEquals(ids.length, vals.length);
        // NOTE(review): 34 is presumably an arbitrary row id - confirm against SparseMatrixRow.
        ValueConf conf = new ValueConf();
        return new SparseMatrixRow(conf, 34, ids, vals);
    }

    /**
     * Builds an id-to-value map from parallel arrays, after asserting that
     * the two arrays have the same length.
     *
     * @param ids  keys
     * @param vals values, parallel to ids
     * @return a new map holding one entry per array position
     */
    private TIntFloatMap getMap(int []ids, float [] vals) {
        assertEquals(ids.length, vals.length);
        TIntFloatHashMap result = new TIntFloatHashMap();
        int pos = 0;
        while (pos < ids.length) {
            result.put(ids[pos], vals[pos]);
            pos++;
        }
        return result;
    }
}