package org.wikibrain.sr.evaluation;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.utils.KnownSim;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * An {@link Evaluator} whose per-split results are {@link SimilarityEvaluationLog}s.
 *
 * For every split, an {@link SRMetric} is created and trained on the split's
 * training data, and then asked for the similarity of each {@link KnownSim}
 * in the split's test data.
 *
 * @see Evaluator
 * @author Shilad Sen
 */
public class SimilarityEvaluator extends Evaluator<SimilarityEvaluationLog> {
    private static final Logger LOG = LoggerFactory.getLogger(SimilarityEvaluator.class);

    /**
     * Creates an evaluator rooted at the given directory.
     *
     * @param outputDir directory handed to the {@link Evaluator} base class,
     *                  together with the fixed name {@code "local-similarity"}
     */
    public SimilarityEvaluator(File outputDir) {
        super(outputDir, "local-similarity");
    }

    /**
     * Adds a crossfold validation of a particular dataset.
     * The dataset is split into folds; each fold is used once as the test set
     * while all remaining folds are combined into the training set.
     * Each split is named {@code <dataset name>-fold-<index>} and the group of
     * the split is set to the name of the dataset.
     *
     * @param ds       the dataset to split into folds
     * @param numFolds the number of folds requested from {@link Dataset#split(int)}
     */
    @Override
    public void addCrossfolds(Dataset ds, int numFolds) {
        String datasetName = ds.getName();
        List<Dataset> folds = ds.split(numFolds);
        for (int i = 0; i < folds.size(); i++) {
            Dataset test = folds.get(i);
            // Copy the fold list so that removing the test fold does not
            // disturb the list being iterated over.
            List<Dataset> trains = new ArrayList<Dataset>(folds);
            trains.remove(i);
            addSplit(new Split(datasetName + "-fold-" + i, datasetName, new Dataset(trains), test));
        }
    }

    /**
     * Creates a {@link SimilarityEvaluationLog} for the given path.
     * NOTE(review): presumably this reads back a previously written log file;
     * confirm against {@code SimilarityEvaluationLog(File)}.
     *
     * @param path the file backing the log
     * @return a new log object constructed from {@code path}
     * @throws IOException if the log cannot be constructed from the file
     */
    @Override
    public SimilarityEvaluationLog createResults(File path) throws IOException {
        return new SimilarityEvaluationLog(path);
    }

    /**
     * Lists the names of the summary fields for this evaluator, in order.
     * How the fields are consumed is determined by the {@link Evaluator} base class.
     *
     * @return a fixed-size list of field names
     */
    @Override
    public List<String> getSummaryFields() {
        return Arrays.asList(
                "date",
                "runNumber",
                "lang",
                "metricName",
                "dataset",
                "successful",
                "missing",
                "failed",
                "pearsons",
                "spearmans",
                "resolvePhrases",
                "metricConfig",
                "disambigConfig"
        );
    }

    /**
     * Trains a freshly created metric on the split's training data and evaluates
     * it against every {@link KnownSim} in the split's test data.
     *
     * Any exception raised while computing or recording a single similarity is
     * treated as a failure of that one item: it is logged, recorded as failed
     * in the evaluation log, and its details are appended to {@code err}.
     * Both the evaluation log and the error file are always closed before this
     * method exits, including when it exits with an exception.
     *
     * @param factory creates the metric to evaluate
     * @param split   supplies the training and test datasets
     * @param log     file the evaluation log is written to
     * @param err     file that receives the details of failed items
     * @param config  configuration passed through to the evaluation log
     * @return the evaluation log for this split; it has already been closed
     * @throws DaoException if creating or training the metric fails
     * @throws IOException  if the log or error file cannot be opened or written
     */
    @Override
    protected SimilarityEvaluationLog evaluateSplit(MonolingualSRFactory factory, Split split, File log, File err, Map<String, String> config) throws DaoException, IOException {
        SRMetric metric = factory.create();
        metric.trainSimilarity(split.getTrain());
        SimilarityEvaluationLog splitEval = new SimilarityEvaluationLog(config, log);
        BufferedWriter errFile = null;
        try {
            errFile = new BufferedWriter(new FileWriter(err));
            for (KnownSim ks : split.getTest().getData()) {
                try {
                    // NOTE(review): the trailing "false" presumably disables
                    // explanations; confirm against SRMetric.similarity.
                    SRResult result;
                    if (shouldResolvePhrases()) {
                        result = metric.similarity(ks.wpId1, ks.wpId2, false);
                    } else {
                        result = metric.similarity(ks.phrase1, ks.phrase2, false);
                    }
                    splitEval.record(ks, result);
                } catch (Exception e) {
                    // Best effort: one failing pair must not abort the whole split.
                    LOG.warn("Similarity of {} failed. Logging error to {}", ks, err);
                    splitEval.recordFailed(ks);
                    writeFailureDetails(errFile, ks, e);
                }
            }
        } finally {
            // Close in a finally block so that an IOException thrown while
            // recording a failure (or while opening the error file) does not
            // leak either resource. closeQuietly tolerates a null errFile.
            IOUtils.closeQuietly(splitEval);
            IOUtils.closeQuietly(errFile);
        }
        return splitEval;
    }

    /**
     * Appends a description of a failed item, including its message and stack
     * frames, to the error file and flushes it.
     */
    private static void writeFailureDetails(BufferedWriter errFile, KnownSim ks, Exception e) throws IOException {
        errFile.write("KnownSim failed: " + ks + "\n");
        errFile.write("\t" + e.getMessage() + "\n");
        for (String frame : ExceptionUtils.getStackFrames(e)) {
            errFile.write("\t" + frame + "\n");
        }
        errFile.write("\n");
        // Flush after each failure so the details survive a later crash.
        errFile.flush();
    }
}