package org.wikibrain.integration;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.dataset.DatasetDao;
import org.wikibrain.sr.utils.ExplanationFormatter;
import java.io.IOException;
import java.sql.SQLException;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
 * Integration tests that load semantic-relatedness metrics ({@link SRMetric}) configured
 * for the language whose code is {@code "simple"} and exercise them against the
 * {@code wordsim353.txt} dataset.
 *
 * <p>Both test cases are currently annotated with {@link Ignore}, so JUnit skips them.
 *
 * @author Shilad Sen
 */
public class LocalEnsembleSRIT {

    /** Language code used both to look up {@link #SIMPLE} and to configure the metrics. */
    private static final String SIMPLE_LANG_CODE = "simple";

    /** Language every metric and dataset in this class is loaded for. */
    private static final Language SIMPLE = Language.getByLangCode(SIMPLE_LANG_CODE);

    /** Name of the dataset passed to {@link DatasetDao#get}. */
    private static final String DATASET_NAME = "wordsim353.txt";

    /**
     * Set this to {@code true} to make {@link #prepareDump()} return immediately
     * without restoring the Lucene index from the test database; set it to
     * {@code false} to restore the index before the tests run.
     */
    private static final boolean SKIP_RESTORE = true;

    /**
     * Restores the Lucene index of the test database once before any test runs,
     * unless {@link #SKIP_RESTORE} is set.
     *
     * @throws ConfigurationException declared for the test-database lookup / restore
     * @throws IOException            declared for the test-database lookup / restore
     * @throws SQLException           declared for the test-database lookup / restore
     */
    @BeforeClass
    public static void prepareDump() throws ConfigurationException, IOException, SQLException {
        if (SKIP_RESTORE) {
            return;
        }
        TestDB db = TestUtils.getTestDb();
        db.restoreLucene();
    }

    /** Runs the accuracy and explanation checks for the metric named {@code "ESA"}. */
    @Ignore
    @Test
    public void testESAAccuracy() throws Exception {
        testAccuracy("ESA", 0.50, 0.58, 0);
        testExplain("ESA", "President", "Obama");
    }

    /** Runs the accuracy and explanation checks for the metric named {@code "milnewitten"}. */
    @Ignore
    @Test
    public void testMilneWittenAccuracy() throws Exception {
        testAccuracy("milnewitten", 0.35, 0.37, 0);
        testExplain("milnewitten", "President", "Obama");
    }

    /**
     * Loads the named metric and the evaluation dataset.
     *
     * <p>NOTE: the cross-validation that used {@code minPearson}, {@code minSpearman}
     * and {@code maxNoPred} is commented out below, so this method currently only
     * verifies that the metric and the dataset can be loaded without an exception;
     * none of the thresholds are enforced.
     *
     * @param srName      name of the metric to look up in the configurator
     * @param minPearson  minimum Pearson correlation (currently unused, see note)
     * @param minSpearman minimum Spearman correlation (currently unused, see note)
     * @param maxNoPred   maximum number of missing plus failed predictions
     *                    (currently unused, see note)
     * @throws ConfigurationException if the environment or metric cannot be configured
     * @throws DaoException           if the dataset cannot be loaded
     */
    public void testAccuracy(String srName, double minPearson, double minSpearman, int maxNoPred) throws ConfigurationException, DaoException {
        Env env = TestUtils.getEnv();
        // Kept although unused: the lookups themselves are the only checks left
        // while the evaluation below is disabled.
        SRMetric sr = loadMetric(env, srName);
        Dataset ds = loadDataset();
        // TODO: re-enable the cross-validation below so the thresholds are enforced.
        /*
        CrossValidation cv = new CrossValidation();
        List<Dataset> allTrain = new ArrayList<Dataset>();
        List<Dataset> allTest = new ArrayList<Dataset>();
        CrossValidation.makeFolds(ds.split(7), allTrain, allTest);
        for (int i = 0; i < allTrain.size(); i++) {
        sr.trainDefaultSimilarity(allTrain.get(i));
        sr.trainSimilarity(allTrain.get(i));
        cv.evaluate(sr, allTest.get(i));
        }
        System.out.println("results for " + srName);
        System.out.println("\tpearson: " + cv.getPearson());
        System.out.println("\tspearman: " + cv.getSpearman());
        System.out.println("\tmissing: " + cv.getMissing());
        System.out.println("\tfailed: " + cv.getFailed());
        assertTrue(cv.getPearson() >= minPearson);
        assertTrue(cv.getSpearman() >= minSpearman);
        assertTrue(cv.getMissing() + cv.getFailed() <= maxNoPred);
        */
    }

    /**
     * Trains the named metric on the dataset, computes the similarity of the two
     * phrases with explanations requested, and prints every formatted explanation.
     * Fails if the metric returns no result or the result carries no explanations.
     *
     * @param srName  name of the metric to look up in the configurator
     * @param phrase1 first phrase to compare
     * @param phrase2 second phrase to compare
     * @throws ConfigurationException if the environment, metric or formatter cannot be configured
     * @throws DaoException           if the dataset cannot be loaded or the metric fails
     */
    public void testExplain(String srName, String phrase1, String phrase2) throws ConfigurationException, DaoException {
        Env env = TestUtils.getEnv();
        Dataset ds = loadDataset();
        SRMetric sr = loadMetric(env, srName);
        sr.trainSimilarity(ds);
        ExplanationFormatter formatter = env.getConfigurator().get(ExplanationFormatter.class);
        SRResult result = sr.similarity(phrase1, phrase2, true);
        // Fail with a readable message instead of a NullPointerException below.
        assertNotNull(srName + " returned no result for " + phrase1 + ", " + phrase2, result);
        System.out.println(srName + " explanation for " + phrase1 + ", " + phrase2 + " is:");
        assertNotNull(srName + " returned no explanations for " + phrase1 + ", " + phrase2,
                result.getExplanations());
        for (Explanation ex : result.getExplanations()) {
            System.out.println("\t" + formatter.formatExplanation(ex));
        }
    }

    /** Looks up the metric called {@code srName}, configured for {@link #SIMPLE_LANG_CODE}. */
    private static SRMetric loadMetric(Env env, String srName) throws ConfigurationException, DaoException {
        return env.getConfigurator().get(SRMetric.class, srName, "language", SIMPLE_LANG_CODE);
    }

    /** Loads the {@link #DATASET_NAME} dataset for {@link #SIMPLE} through a fresh {@link DatasetDao}. */
    private static Dataset loadDataset() throws ConfigurationException, DaoException {
        DatasetDao datasetDao = new DatasetDao();
        return datasetDao.get(SIMPLE, DATASET_NAME);
    }
}