package org.wikibrain.sr.evaluation;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
import org.junit.Test;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.dataset.DatasetDao;
import org.wikibrain.utils.WpIOUtils;

import java.io.File;
import java.io.IOException;
import java.util.List;

import static org.junit.Assert.*;

/**
 * Tests for {@link SimilarityEvaluator} and {@link MostSimilarEvaluator}: the summary file
 * they write, the aggregate statistics exposed by the evaluation log, and the run-number
 * prefix that appears in the paths of the child files.
 *
 * @author Shilad Sen
 */
public class TestEvaluator {

    /**
     * Evaluates a {@link TestLocalSR.Factory} over two datasets with 7 crossfolds each and
     * checks that:
     * <ul>
     *   <li>{@code local-similarity/summary.tsv} has four lines, contains the metric and
     *       disambiguator markers, and contains no literal {@code "null"};</li>
     *   <li>14 child files, 14 metrics and 14 splits were produced;</li>
     *   <li>the counts, actual/estimated value lists and Pearson/Spearman correlations
     *       reported by the log equal the values accumulated from the individual metrics.</li>
     * </ul>
     */
    @Test
    public void testSimilarity() throws IOException, DaoException, ConfigurationException, WikiBrainException {
        DatasetDao dsDao = new DatasetDao();
        Language simple = Language.getByLangCode("simple");
        File file = WpIOUtils.createTempDirectory("evaluator");

        SimilarityEvaluator evaluator = new SimilarityEvaluator(file);
        evaluator.setWriteToStdout(false);
        evaluator.addCrossfolds(dsDao.get(simple, "radinsky.txt"), 7);
        evaluator.addCrossfolds(dsDao.get(simple, "atlasify240.txt"), 7);

        TestLocalSR.Factory factory = new TestLocalSR.Factory();
        SimilarityEvaluationLog eval = evaluator.evaluate(factory);

        List<String> lines = FileUtils.readLines(FileUtils.getFile(file, "local-similarity", "summary.tsv"));
        assertEquals("unexpected number of lines in summary.tsv", 4, lines.size());

        // Join with an explicit separator: the one-argument StringUtils.join(lines) resolves to
        // the varargs overload and would merely stringify the list itself.
        String summary = StringUtils.join(lines, "\n");
        assertFalse(summary.contains("null"));
        assertTrue(summary.contains("thisIsTheMetric"));
        assertTrue(summary.contains("thisIsTheDisambiguator"));
        assertEquals(14, eval.getChildFiles().size());

        TDoubleList actual = new TDoubleArrayList();
        TDoubleList estimated = new TDoubleArrayList();
        int missing = 0;
        int failed = 0;
        int successful = 0;
        int total = 0;

        // The loop below pairs metric i with split i, so both collections must have the same size.
        assertEquals(14, factory.metrics.size());
        assertEquals(14, evaluator.getSplits().size());
        for (int i = 0; i < factory.metrics.size(); i++) {
            Dataset test = evaluator.getSplits().get(i).getTest();
            TestLocalSR testSr = factory.metrics.get(i);
            actual.addAll(testSr.getActual(test.getData()));
            estimated.addAll(testSr.getEstimated(test.getData()));
            missing += testSr.getMissing();
            failed += testSr.getFailed();
            successful += testSr.getSuccessful();
            total += testSr.getTotal();
        }

        assertEquals("thisIsTheDisambiguator", eval.getSummaryAsMap().get("disambigConfig"));
        assertEquals("thisIsTheMetric", eval.getSummaryAsMap().get("metricConfig"));
        assertEquals(missing, eval.getMissing());
        assertEquals(failed, eval.getFailed());
        assertEquals(total, eval.getTotal());
        assertEquals(successful, eval.getSuccessful());
        assertEquals(actual, eval.getActual());
        assertEquals(estimated, eval.getEstimates());
        assertEquals(
                new PearsonsCorrelation().correlation(actual.toArray(), estimated.toArray()),
                eval.getPearsonsCorrelation(),
                0.00001
        );
        assertEquals(
                new SpearmansCorrelation().correlation(actual.toArray(), estimated.toArray()),
                eval.getSpearmansCorrelation(),
                0.00001
        );
    }

    /**
     * Runs a similarity evaluator and a most-similar evaluator that were both constructed on
     * the same directory, and asserts that the path of the first child file of each successive
     * evaluation contains the next run number followed by a dash ({@code "0-"} through
     * {@code "4-"}), regardless of which evaluator performed the run.
     */
    @Test
    public void testRunNumber() throws IOException, DaoException, ConfigurationException, WikiBrainException {
        DatasetDao dsDao = new DatasetDao();
        Language simple = Language.getByLangCode("simple");
        File file = WpIOUtils.createTempDirectory("evaluator");

        SimilarityEvaluator simEvaluator = new SimilarityEvaluator(file);
        simEvaluator.setWriteToStdout(false);
        simEvaluator.addCrossfolds(dsDao.get(simple, "wordsim353.txt"), 7);
        simEvaluator.addCrossfolds(dsDao.get(simple, "atlasify240.txt"), 7);

        MostSimilarEvaluator mostSimEvaluator = new MostSimilarEvaluator(file);
        mostSimEvaluator.setWriteToStdout(false);
        mostSimEvaluator.addCrossfolds(dsDao.get(simple, "wordsim353.txt"), 7);
        mostSimEvaluator.addCrossfolds(dsDao.get(simple, "atlasify240.txt"), 7);

        TestLocalSR.Factory factory = new TestLocalSR.Factory();

        BaseEvaluationLog eval = simEvaluator.evaluate(factory);
        assertRunPrefix(0, eval.getChildFiles().get(0));

        eval = simEvaluator.evaluate(factory);
        assertRunPrefix(1, eval.getChildFiles().get(0));

        eval = mostSimEvaluator.evaluate(factory);
        assertRunPrefix(2, eval.getChildFiles().get(0));

        eval = mostSimEvaluator.evaluate(factory);
        assertRunPrefix(3, eval.getChildFiles().get(0));

        eval = simEvaluator.evaluate(factory);
        assertRunPrefix(4, eval.getChildFiles().get(0));
    }

    /**
     * Asserts that the string form of {@code childFile} contains {@code expectedRun + "-"}.
     * The path is included in the failure message so a failing run is diagnosable without
     * extra console output.
     *
     * @param expectedRun the run number expected to appear, followed by a dash, in the path
     * @param childFile   the first child file reported by an evaluation log
     */
    private static void assertRunPrefix(int expectedRun, Object childFile) {
        String path = String.valueOf(childFile);
        String marker = expectedRun + "-";
        assertTrue("expected '" + marker + "' in child file path: " + path, path.contains(marker));
    }
}