package org.wikibrain.sr.evaluation;
import org.apache.commons.cli.*;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.dataset.DatasetDao;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Command-line entry point that evaluates a semantic relatedness (SR) metric
 * against one or more gold standard datasets.
 *
 * <p>A metric ({@code -m}) and at least one dataset ({@code -g}) are required.
 * Optional flags select the cross-validation mode ({@code -x}), the prediction
 * mode ({@code -p}), the number of folds ({@code -k}), the output directory
 * ({@code -o}), and whether an already trained metric is used ({@code -a}).</p>
 *
 * @author Shilad Sen
 */
public class EvaluationMain {
    /** Number of cross-validation folds used when {@code -k} is not supplied. */
    private static final int DEFAULT_FOLDS = 7;

    /** Program name printed in the usage banner. */
    private static final String PROGRAM_NAME = "EvaluationMain";

    // Accepted values for the -x (cross-validation mode) option.
    private static final String MODE_NONE = "none";
    private static final String MODE_WITHIN = "within-dataset";
    private static final String MODE_ACROSS = "across-dataset";
    /** Spelling historically advertised by the help text; treated as {@link #MODE_ACROSS}. */
    private static final String MODE_ACROSS_ALIAS = "between-dataset";

    // Accepted values for the -p (prediction mode) option.
    private static final String PREDICT_SIMILARITY = "similarity";
    private static final String PREDICT_MOST_SIMILAR = "mostsimilar";

    /**
     * Parses the command line, builds the requested evaluator and runs it.
     *
     * <p>Everything that depends only on the command line is validated before the
     * environment is built, so invalid arguments are reported without first
     * paying for configuration or dataset loading. Invalid arguments print a
     * message plus usage to stderr and terminate the JVM with exit status 1.</p>
     *
     * @param args command-line arguments
     * @throws UnsupportedOperationException if {@code -u} or {@code -r} is given;
     *         neither is implemented yet
     */
    public static void main(String[] args) throws ConfigurationException, DaoException, IOException, WikiBrainException {
        Options options = buildOptions();

        CommandLine cmd;
        try {
            cmd = new PosixParser().parse(options, args);
        } catch (ParseException e) {
            exitWithUsage(options, "Invalid option usage: " + e.getMessage());
            return; // to appease the compiler
        }

        // ---- Validation that needs nothing but the command line ----

        if (cmd.hasOption("u")) { // TODO: support universal evaluations
            throw new UnsupportedOperationException("universal metrics (-u) are not yet supported");
        }
        // -u has been rejected above, so a local metric is the only possibility.
        if (!cmd.hasOption("m")) {
            exitWithUsage(options, "Must specify a metric to evaluate.");
            return; // to appease the compiler
        }
        if (cmd.hasOption("r")) { // TODO: support reloading stored splits
            throw new UnsupportedOperationException("reloading stored split datasets (-r) is not yet supported");
        }
        if (!cmd.hasOption("g")) {
            exitWithUsage(options, "Must specify a dataset using either -g or -r");
            return; // to appease the compiler
        }

        int folds = DEFAULT_FOLDS;
        if (cmd.hasOption("k")) {
            String rawFolds = cmd.getOptionValue("k");
            try {
                folds = Integer.parseInt(rawFolds);
            } catch (NumberFormatException e) {
                exitWithUsage(options, "Invalid number of folds: " + rawFolds);
                return; // to appease the compiler
            }
        }

        String mode = cmd.hasOption("x") ? cmd.getOptionValue("x") : MODE_WITHIN;
        if (mode.equals(MODE_ACROSS_ALIAS)) {
            mode = MODE_ACROSS;
        }
        if (!mode.equals(MODE_NONE) && !mode.equals(MODE_WITHIN) && !mode.equals(MODE_ACROSS)) {
            exitWithUsage(options, "Unknown mode: " + mode);
            return; // to appease the compiler
        }

        String prediction = cmd.hasOption("p") ? cmd.getOptionValue("p") : PREDICT_SIMILARITY;
        if (!prediction.equals(PREDICT_SIMILARITY) && !prediction.equals(PREDICT_MOST_SIMILAR)) {
            exitWithUsage(options, "Invalid prediction mode. usage:");
            return; // to appease the compiler
        }

        // ---- Environment and data ----

        // Training is enabled unless the caller asked for a pretrained metric (-a).
        Env env = new EnvBuilder(cmd)
                .setProperty("sr.metric.training", !cmd.hasOption("a"))
                .build();
        Configurator c = env.getConfigurator();
        DatasetDao dsDao = c.get(DatasetDao.class);
        Language lang = env.getLanguages().getDefaultLanguage();

        // Each -g value may name a single dataset or a group of datasets.
        List<Dataset> datasets = new ArrayList<Dataset>();
        for (String dsName : cmd.getOptionValues("g")) {
            datasets.addAll(dsDao.getDatasetOrGroup(lang, dsName));
        }

        String outputDir = cmd.hasOption("o")
                ? cmd.getOptionValue("o")
                : c.getConf().get().getString("sr.dataset.records");

        // ---- Evaluator ----

        Evaluator evaluator;
        if (prediction.equals(PREDICT_SIMILARITY)) {
            evaluator = new SimilarityEvaluator(new File(outputDir));
        } else {
            MostSimilarEvaluator mostSimilar = new MostSimilarEvaluator(new File(outputDir));
            if (cmd.hasOption("z")) {
                mostSimilar.setBuildCosimilarityMatrix(true);
            }
            evaluator = mostSimilar;
        }
        if (cmd.hasOption("v")) {
            evaluator.setResolvePhrases(true);
        }

        if (mode.equals(MODE_NONE)) {
            // No cross-validation: the merged dataset is passed as both datasets of a single split.
            Dataset all = new Dataset(datasets);
            evaluator.addSplit(new Split(all.getName(), all.getName(), all, all));
        } else if (mode.equals(MODE_WITHIN)) {
            // Cross-validate each dataset independently.
            for (Dataset ds : datasets) {
                evaluator.addCrossfolds(ds, folds);
            }
        } else {
            // MODE_ACROSS: cross-validate over the union of all datasets.
            evaluator.addCrossfolds(new Dataset(datasets), folds);
        }

        // ---- Metric factory ----

        MonolingualSRFactory factory;
        if (cmd.hasOption("a")) {
            SRMetric sr = c.get(
                    SRMetric.class,
                    cmd.getOptionValue("m"),
                    "language",
                    lang.getLangCode()
            );
            factory = new PretrainedSRFactory(sr);
        } else {
            factory = new ConfigMonolingualSRFactory(lang, c, cmd.getOptionValue("m"));
        }

        evaluator.evaluate(factory);
    }

    /**
     * Builds the full set of command-line options understood by this program,
     * including the standard environment options added by {@link EnvBuilder}.
     */
    private static Options buildOptions() {
        Options options = new Options();
        // Universal metric (declared but not yet supported)
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("universal")
                        .withDescription("set a universal metric")
                        .create("u"));
        // Gold standard dataset(s), comma separated
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArgs()
                        .withValueSeparator(',')
                        .withLongOpt("gold")
                        .withDescription("the set of gold standard datasets to train on")
                        .create("g"));
        // Use an existing pre-trained metric
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("pretrained")
                        .withDescription("use an existing pretrained metric")
                        .create("a"));
        // Local metric to evaluate
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("metric")
                        .withDescription("set a local metric")
                        .create("m"));
        // Output directory
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("outputDir")
                        .withDescription("Specify the output directory")
                        .create("o"));
        // Reload a saved dataset (declared but not yet supported)
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("reload")
                        .withDescription("reload a previously stored split dataset")
                        .create("r"));
        // Cross-validation mode
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("cross-validation-mode")
                        .withDescription("Set cross validation mode (none, within-dataset, across-dataset;"
                                + " between-dataset is accepted as an alias for across-dataset)")
                        .create("x"));
        // Prediction mode
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("prediction-mode")
                        .withDescription("Set prediction mode (similarity, mostsimilar)")
                        .create("p"));
        // Number of folds
        options.addOption(
                new DefaultOptionBuilder()
                        .hasArg()
                        .withLongOpt("folds")
                        .withDescription("set the number of folds to evaluate on")
                        .create("k"));
        // Resolve phrases to ids
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("resolve")
                        .withDescription("resolve phrases to ids")
                        .create("v"));
        // Build most-similar cache matrices (only consulted in mostsimilar prediction mode)
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("buildMostSimilarCache")
                        .withDescription("build most similar cache matrices")
                        .create("z"));
        EnvBuilder.addStandardOptions(options);
        return options;
    }

    /**
     * Prints {@code message} and the usage summary, then terminates the JVM
     * with exit status 1. Callers still need a {@code return} afterwards because
     * the compiler cannot know this method never returns normally.
     */
    private static void exitWithUsage(Options options, String message) {
        System.err.println(message);
        new HelpFormatter().printHelp(PROGRAM_NAME, options);
        System.exit(1);
    }
}