package org.wikibrain.sr.evaluation;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
import org.junit.Test;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.dataset.DatasetDao;
import org.wikibrain.sr.utils.KnownSim;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertTrue;
/**
* @author Shilad Sen
*/
public class TestSimilarityEvaluation {
@Test
public void testSimple() throws IOException, DaoException {
TDoubleList actual = new TDoubleArrayList();
TDoubleList estimated = new TDoubleArrayList();
Random rand = new Random();
Language simple = Language.getByLangCode("simple");
Dataset ds = new DatasetDao().get(simple, "wordsim353.txt");
File log = File.createTempFile("evaluation", "log");
log.deleteOnExit();
SimilarityEvaluationLog se = new SimilarityEvaluationLog(log);
for (int i = 0; i < ds.getData().size(); i++) {
KnownSim ks = ds.getData().get(i);
if (i % 20 == 0) {
se.recordFailed(ks);
} else if (i % 20 == 1) {
se.record(ks, new SRResult(Double.NaN));
} else if (i % 20 == 2) {
se.record(ks, new SRResult(Double.POSITIVE_INFINITY));
} else {
double v = rand.nextDouble();
se.record(ks, new SRResult(v));
actual.add(ks.similarity);
estimated.add(v);
}
}
assertEquals(353, se.getTotal());
assertEquals(18, se.getFailed());
assertEquals(36, se.getMissing());
assertEquals(353-18-36, se.getSuccessful());
assertEquals(se.getPearsonsCorrelation(), new PearsonsCorrelation().correlation(actual.toArray(), estimated.toArray()), 0.000001);
assertEquals(se.getSpearmansCorrelation(), new SpearmansCorrelation().correlation(actual.toArray(), estimated.toArray()), 0.000001);
IOUtils.closeQuietly(se);
List<String> logLines = FileUtils.readLines(log);
assertTrue(logLines.size() > 300);
}
}