package org.wikibrain.sr.evaluation;
import org.junit.Before;
import org.junit.Test;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.KnownSim;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
 * Unit tests for {@link MostSimilarGuess}.
 *
 * <p>The fixture pairs a small set of known similarity judgements for "apple" with a seven-entry
 * result list. Four of the returned ids (2, 39, 19, 3) have known similarities. The other three
 * (911, 13, 93) do not, so exactly four observations are expected.
 *
 * @author Shilad Sen
 */
public class TestMostSimilarGuess {
    /** Tolerance used for every floating-point comparison in this class. */
    private static final double DELTA = 0.001;

    private MostSimilarGuess guess;

    /**
     * Builds the shared fixture: known similarities for "apple" and a seven-slot result list.
     *
     * <p>The "rank" in the inline comments is the item's position when the known similarities are
     * sorted in descending order. The "sim" is its known similarity.
     */
    @Before
    public void createGuess() {
        Language en = Language.getByLangCode("en");
        KnownMostSim sim = new KnownMostSim(Arrays.asList(
                new KnownSim("apple", "tart", 34, 99, 0.6, en),
                // Two judgements for the same pair (intentional). The expectations below treat
                // their combined similarity as 0.91, the mean of 0.9 and 0.92.
                new KnownSim("apple", "orange", 34, 3, 0.9, en),
                new KnownSim("apple", "orange", 34, 3, 0.92, en),
                new KnownSim("apple", "black", 34, 188, 0.5, en),
                new KnownSim("apple", "shoe", 34, 39, 0.0, en),
                new KnownSim("apple", "honeycrisp", 34, 19, 0.95, en),
                new KnownSim("apple", "mac", 17, 2, 0.8, en)
        ));
        SRResultList list = new SRResultList(7);
        list.set(0, 2, 0.83);   // mac, rank 3, sim 0.8
        list.set(1, 39, 0.73);  // shoe, rank 6, sim 0.0
        list.set(2, 911, 0.70); // unknown
        list.set(3, 19, 0.68);  // honeycrisp, rank 1, sim 0.95
        list.set(4, 13, 0.66);  // unknown
        list.set(5, 93, 0.62);  // unknown
        list.set(6, 3, 0.60);   // orange, rank 2, sim 0.91
        guess = new MostSimilarGuess(sim, list);
    }

    /**
     * Asserts that {@code g} matches the guess built in {@link #createGuess()}.
     *
     * <p>Observation ranks are 1-based positions in the result list. The first observation sits
     * at list index 0 (rank 1) and the last at list index 6 (rank 7).
     */
    private static void assertMatchesFixture(MostSimilarGuess g) {
        List<MostSimilarGuess.Observation> obs = g.getObservations();
        assertEquals(7, g.getLength());
        assertEquals(4, obs.size());

        MostSimilarGuess.Observation first = obs.get(0);
        assertEquals(1, first.rank);
        assertEquals(2, first.id);
        assertEquals(0.83, first.estimate, DELTA);

        MostSimilarGuess.Observation last = obs.get(3);
        assertEquals(7, last.rank);
        assertEquals(3, last.id);
        assertEquals(0.60, last.estimate, DELTA);
    }

    /** Only result ids with a known similarity become observations. */
    @Test
    public void testCreate() {
        assertMatchesFixture(guess);
    }

    /** A guess survives a {@code toString()} round trip, and a hand-written string parses. */
    @Test
    public void testSerialize() {
        String serialized = guess.toString();
        MostSimilarGuess guess2 = new MostSimilarGuess(guess.getKnown(), serialized);
        assertMatchesFixture(guess2);

        // The leading field is read back as the list length. The meaning of the remaining
        // fields is defined by MostSimilarGuess. This string is expected to carry no
        // observations.
        MostSimilarGuess guess3 = new MostSimilarGuess(guess2.getKnown(), "3435|0.9|0.5");
        assertEquals(0, guess3.getObservations().size());
        assertEquals(3435, guess3.getLength());
    }

    /**
     * NDCG over the observed items only.
     *
     * <p>The observed known similarities sit at 1-based ranks 1, 2, 4 and 7: mac 0.80, shoe
     * 0.00, honeycrisp 0.95 and orange 0.91. Rank 1 is not discounted; rank r > 1 is divided
     * by ln(r + 1). The ideal ordering places the same similarities, sorted descending, at
     * those same ranks.
     */
    @Test
    public void testNdgc() {
        double ndgc = (
                (0.80 + 0.00 / Math.log(2+1) + 0.95 / Math.log(4+1) + 0.91 / Math.log(7+1))
                / (0.95 + 0.91 / Math.log(2+1) + 0.80 / Math.log(4+1) + 0.00 / Math.log(7+1)));
        assertEquals(ndgc, guess.getNDGC(), DELTA);
    }

    /**
     * NDCG that also accounts for known items missing from the result list.
     *
     * <p>Two known items were not retrieved: tart 0.6 and black 0.5. The expected value places
     * them at rank {@code 3 * length}, each with similarity 0.60 / 2. The 0.60 is presumably
     * derived from the lowest estimate in the list; NOTE(review): confirm against
     * MostSimilarGuess. The ideal ordering lists all six distinct known similarities in
     * descending order, using the same rank slots.
     */
    @Test
    public void testPenalizedNdgc() {
        int unobservedRank = guess.getLength() * 3;
        int unobservedCount = 2;
        double unobservedSim = 0.60 / 2;
        double s = (
                0.80 +
                0.00 / Math.log(2+1) +
                0.95 / Math.log(4+1) +
                0.91 / Math.log(7+1) +
                unobservedCount * unobservedSim / Math.log(unobservedRank + 1)
        );
        double t = (
                0.95 +
                0.91 / Math.log(2+1) +
                0.80 / Math.log(4+1) +
                0.60 / Math.log(7+1) +
                0.50 / Math.log(unobservedRank + 1) +
                0.0 / Math.log(unobservedRank + 1)
        );
        assertEquals(s / t, guess.getPenalizedNDGC(), DELTA);
    }

    /**
     * Precision and recall at several cutoffs with a relevance threshold of 0.7.
     *
     * <p>Three known items meet the threshold: honeycrisp 0.95, orange 0.91 and mac 0.8. The
     * expected precision values are consistent with counting only retrieved items whose
     * similarity is known. For example, the top 5 contain mac, shoe and honeycrisp, giving a
     * precision of 2/3.
     */
    @Test
    public void testPrecisionRecall() {
        PrecisionRecallAccumulator pr = guess.getPrecisionRecall(1, 0.7);
        // JUnit's contract is assertEquals(expected, actual); the original had these swapped.
        assertEquals(1, pr.getN());
        assertEquals(1.0, pr.getPrecision(), DELTA);
        assertEquals(1.0 / 3, pr.getRecall(), DELTA);

        pr = guess.getPrecisionRecall(2, 0.7);
        assertEquals(0.5, pr.getPrecision(), DELTA);
        assertEquals(1.0 / 3, pr.getRecall(), DELTA);

        pr = guess.getPrecisionRecall(5, 0.7);
        assertEquals(2.0 / 3, pr.getPrecision(), DELTA);
        assertEquals(2.0 / 3, pr.getRecall(), DELTA);
    }
}