package org.wikibrain.sr.evaluation;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WpThreadUtils;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * An {@link Evaluator} whose splits are evaluated by calling
 * {@code mostSimilar()} on an {@link SRMetric} for every phrase of the
 * split's test data and recording the outcome in a
 * {@link MostSimilarEvaluationLog}.
 *
 * @see Evaluator
 *
 * @author Shilad Sen
 */
public class MostSimilarEvaluator extends Evaluator<MostSimilarEvaluationLog> {
    private static final Logger LOG = LoggerFactory.getLogger(MostSimilarEvaluator.class);

    /** Prefixes of the per-rank summary fields, in the order they are appended. */
    private static final String[] RANK_FIELD_PREFIXES = {"num-", "mean-", "precision-", "recall-"};

    // Stored by setBuildCosimilarityMatrix(); not read anywhere in this class.
    private boolean buildCosimilarityMatrix = false;

    // These arguments are passed to trainMostSimilar() and mostSimilar()
    private int numMostSimilarResults = 2000;
    private TIntHashSet mostSimilarIds = null;

    // These arguments are applied to the logs built by createResults()
    private double relevanceThreshold = 0.6;
    private int[] precisionRecallRanks = new int[] {1, 5, 10, 20, 50, 100, 500, 1000};

    /**
     * Creates an evaluator, handing {@code outputDir} and the fixed name
     * {@code "local-mostSimilar"} to the {@link Evaluator} constructor.
     *
     * @param outputDir directory passed through to the superclass
     */
    public MostSimilarEvaluator(File outputDir) {
        super(outputDir, "local-mostSimilar");
    }

    /**
     * Adds a crossfold validation of a particular dataset.
     * The dataset is wrapped in a {@link MostSimilarDataset} and split into
     * folds; one split is registered per fold, using that fold as the test
     * set and a dataset built from all remaining folds as the training set.
     * Each split is named {@code <dataset name>-fold-<index>} and its group
     * is set to the name of the dataset.
     *
     * @param ds       the dataset to split
     * @param numFolds the fold count handed to {@code splitIntoDatasets()}
     */
    @Override
    public void addCrossfolds(Dataset ds, int numFolds) {
        List<Dataset> folds = new MostSimilarDataset(ds).splitIntoDatasets(numFolds);
        for (int foldIndex = 0; foldIndex < folds.size(); foldIndex++) {
            // Copy the fold list so pulling out the test fold leaves "folds" intact.
            List<Dataset> training = new ArrayList<Dataset>(folds);
            Dataset heldOut = training.remove(foldIndex);
            String splitName = ds.getName() + "-fold-" + foldIndex;
            addSplit(new Split(splitName, ds.getName(), new Dataset(training), heldOut));
        }
    }

    /**
     * Builds an evaluation log for {@code path} and configures it with this
     * evaluator's precision/recall ranks and relevance threshold.
     *
     * @param path file handed to the {@link MostSimilarEvaluationLog} constructor
     * @return the configured log
     * @throws IOException if constructing the log fails
     */
    @Override
    public MostSimilarEvaluationLog createResults(File path) throws IOException {
        MostSimilarEvaluationLog results = new MostSimilarEvaluationLog(path);
        results.setPrecisionRecallRanks(precisionRecallRanks);
        results.setRelevanceThreshold(relevanceThreshold);
        return results;
    }

    /**
     * Lists the summary field names: a fixed set of leading fields, then
     * {@code num-}, {@code mean-}, {@code precision-} and {@code recall-}
     * entries (each group covering every configured rank, in rank order),
     * and finally {@code metricConfig} and {@code disambigConfig}.
     *
     * @return a new, mutable list of field names
     */
    @Override
    public List<String> getSummaryFields() {
        List<String> fields = new ArrayList<String>(Arrays.asList(
                "date",
                "runNumber",
                "lang",
                "metricName",
                "dataset",
                "successful",
                "missing",
                "failed",
                "resolvePhrases",
                "pearsons",
                "spearmans",
                "ndgc",
                "penalizedNdgc"
        ));
        // Grouped by prefix first, so all "num-*" fields precede all "mean-*" fields, etc.
        for (String prefix : RANK_FIELD_PREFIXES) {
            for (int rank : precisionRecallRanks) {
                fields.add(prefix + rank);
            }
        }
        fields.add("metricConfig");
        fields.add("disambigConfig");
        return fields;
    }

    /**
     * Evaluates a particular split for mostSimilar().
     * A metric is created from the factory and trained on the split's
     * training data. Then, for every phrase of the test data, mostSimilar()
     * is called (by page id when {@code shouldResolvePhrases()} is true,
     * otherwise by phrase) and the result is recorded in the split's log.
     * A phrase whose evaluation throws is recorded as failed and the
     * exception details are appended to {@code err}.
     *
     * <p>Both the log and the error file are closed before this method
     * returns, including when it exits with an exception.
     *
     * @param factory creates the metric to evaluate
     * @param split   supplies the training and test datasets
     * @param log     file handed to the {@link MostSimilarEvaluationLog} constructor
     * @param err     file that receives details of failed phrases
     * @param config  configuration handed to the {@link MostSimilarEvaluationLog} constructor
     * @return the (already closed) evaluation log of this split
     * @throws java.io.IOException
     * @throws org.wikibrain.core.dao.DaoException
     * @throws org.wikibrain.core.WikiBrainException
     */
    @Override
    protected MostSimilarEvaluationLog evaluateSplit(MonolingualSRFactory factory, Split split, File log, final File err, Map<String, String> config) throws IOException, DaoException, WikiBrainException {
        final SRMetric metric = factory.create();
        metric.trainMostSimilar(split.getTrain(), numMostSimilarResults, mostSimilarIds);

        // NOTE(review): unlike createResults(), this log is not given
        // precisionRecallRanks / relevanceThreshold -- confirm that is intended.
        final MostSimilarEvaluationLog splitEval = new MostSimilarEvaluationLog(config, log);
        BufferedWriter errWriter = null;
        try {
            // NOTE(review): FileWriter encodes with the platform default charset.
            errWriter = new BufferedWriter(new FileWriter(err));
            final BufferedWriter errFile = errWriter;
            final MostSimilarDataset msd = new MostSimilarDataset(split.getTest());
            ParallelForEach.iterate(msd.getPhrases().iterator(), WpThreadUtils.getMaxThreads(), 1000, new Procedure<String>() {
                @Override
                public void call(String phrase) throws Exception {
                    KnownMostSim kms = msd.getSimilarities(phrase);
                    try {
                        SRResultList result;
                        if (shouldResolvePhrases()) {
                            result = metric.mostSimilar(kms.getPageId(), numMostSimilarResults, mostSimilarIds);
                        } else {
                            result = metric.mostSimilar(phrase, numMostSimilarResults, mostSimilarIds);
                        }
                        splitEval.record(kms, result);
                    } catch (Exception e) {
                        LOG.warn("Similarity of " + kms.getPhrase() + ", id=" + kms.getPageId() + " failed. Logging error to " + err);
                        splitEval.recordFailed(kms);
                        writeMostSimilarFailure(errFile, phrase, e);
                    }
                }
            }, 100);
        } finally {
            // Release both files even when training data loading or the
            // parallel iteration throws; closeQuietly() tolerates null.
            IOUtils.closeQuietly(splitEval);
            IOUtils.closeQuietly(errWriter);
        }
        return splitEval;
    }

    /**
     * Appends one failure record (phrase, exception message, stack frames,
     * blank line) to the error file and flushes it. The writer's monitor is
     * held for the whole record so records written by concurrent callers
     * do not interleave.
     */
    private static void writeMostSimilarFailure(BufferedWriter errFile, String phrase, Exception e) throws IOException {
        synchronized (errFile) {
            errFile.write("KnownSim failed: " + phrase + "\n");
            errFile.write("\t" + e.getMessage() + "\n");
            for (String frame : ExceptionUtils.getStackFrames(e)) {
                errFile.write("\t" + frame + "\n");
            }
            errFile.write("\n");
            errFile.flush();
        }
    }

    /**
     * Sets the id set passed as the last argument of trainMostSimilar()
     * and mostSimilar(). The default is {@code null}.
     */
    public void setMostSimilarIds(TIntHashSet mostSimilarIds) {
        this.mostSimilarIds = mostSimilarIds;
    }

    /**
     * Sets the result count passed to trainMostSimilar() and mostSimilar().
     * The default is 2000.
     */
    public void setNumMostSimilarResults(int numMostSimilarResults) {
        this.numMostSimilarResults = numMostSimilarResults;
    }

    /**
     * Sets the ranks used for the per-rank summary fields and handed to logs
     * built by createResults(). The array is stored as given, not copied.
     */
    public void setPrecisionRecallRanks(int[] precisionRecallRanks) {
        this.precisionRecallRanks = precisionRecallRanks;
    }

    /**
     * Stores the cosimilarity-matrix flag. Nothing in this class reads it.
     */
    public void setBuildCosimilarityMatrix(boolean buildCosimilarityMatrix) {
        this.buildCosimilarityMatrix = buildCosimilarityMatrix;
    }
}