package org.wikibrain.sr.phrasesim;

import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.cli.*;
import org.apache.commons.io.FileUtils;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;

import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Command-line evaluator for the "known-phrase" similarity metric.
 *
 * Phrases are grouped into bundles. The evaluator loads every phrase of every
 * bundle into a freshly rebuilt {@link KnownPhraseSim}, then repeatedly picks a
 * random bundle, asks the metric for the nearest neighbors of the bundle's
 * first phrase, and counts how many neighbors belong to the same bundle.
 *
 * @author Shilad Sen
 */
public class PhraseSimEvaluator {
    /** When true, the neighbors found for every sample are printed to stdout. */
    private boolean debug = true;

    /** Built-in bundles used by {@link #main(String[])} when no bundle file is given. */
    private static final List<List<String>> TEST_BUNDLES = Arrays.asList(
            makeSet("jazz music blues"),
            makeSet("music math statistics"),
            makeSet("music brain"),
            makeSet("brain mind"),
            makeSet("brain statistics algorithm")
    );

    /** Number of neighbors scored per sampled phrase. */
    private int k = 10;

    private final Env env;

    /**
     * @param env environment that supplies the default language, the base
     *            directory and the configurator used to obtain the metric
     */
    public PhraseSimEvaluator(Env env) {
        this.env = env;
    }

    /**
     * Rebuilds the known-phrase model from {@code bundles}, draws 1000 random
     * samples and prints the evaluation statistics to stdout.
     *
     * Precision is reported as hits / recommended neighbors, recall as
     * hits / (bundle members other than the target, summed over samples).
     *
     * @param bundles groups of phrases that are considered related to each other
     * @throws IllegalArgumentException if {@code bundles} is empty
     * @throws IllegalStateException    if the metric's data directory is not the
     *                                  directory that was just cleared
     * @throws ConfigurationException   if the metric cannot be configured
     * @throws IOException              declared for callers; propagated from the metric
     */
    public void evaluate(final List<List<String>> bundles) throws ConfigurationException, IOException {
        // Fail before the data directory is deleted: sampling below would throw
        // the same exception type from Random.nextInt(0), but only after the
        // existing model had already been destroyed and retrained.
        if (bundles.isEmpty()) {
            throw new IllegalArgumentException("bundles must not be empty");
        }

        String lc = env.getDefaultLanguage().getLangCode();

        // NOTE(review): the path hard-codes "en" although the metric is requested
        // for language lc; for other default languages the directory check below
        // presumably fails - confirm whether lc should be used here.
        File dir = FileUtils.getFile(env.getBaseDir(), "dat/sr/known-phrase/en");
        FileUtils.deleteQuietly(dir);

        final KnownPhraseSim sim = (KnownPhraseSim) env.getConfigurator().get(
                SRMetric.class, "known-phrase", "language", lc);
        if (!sim.getDataDir().equals(dir)) {
            throw new IllegalStateException("expected dir " + dir + ", found " + sim.getDataDir());
        }

        // Maps a normalized phrase to its integer id (0, 1, 2, ... in order of first sight).
        final Map<String, Integer> ids = new ConcurrentHashMap<String, Integer>();
        ParallelForEach.loop(bundles, new Procedure<List<String>>() {
            @Override
            public void call(List<String> bundle) throws Exception {
                for (String phrase : bundle) {
                    String s = sim.normalize(phrase);
                    int id;
                    // Lookup and insert must be one atomic step: with a bare
                    // containsKey/put pair two workers could read the same
                    // ids.size() and assign one id to two different phrases.
                    synchronized (ids) {
                        Integer existing = ids.get(s);
                        if (existing == null) {
                            existing = ids.size();
                            ids.put(s, existing);
                        }
                        id = existing;
                    }
                    // Kept outside the lock so phrases are still added in parallel.
                    sim.addPhrase(phrase, id);
                }
            }
        });
        sim.flushCosimilarity();
        sim.trainNormalizer();

        int numSamples = 0;
        int numSampleHits = 0;
        int numRecommended = 0;
        int numRecommendedHits = 0;
        int possible = 0;
        int numErrors = 0;

        long before = System.currentTimeMillis();
        Random rand = new Random();
        for (int i = 0; i < 1000; i++) {
            // Select a random bundle
            List<String> bundle = bundles.get(rand.nextInt(bundles.size()));
            if (bundle.isEmpty()) {
                continue;
            }
            numSamples++;

            TIntSet bundleIds = new TIntHashSet();
            for (String p : bundle) {
                bundleIds.add(ids.get(sim.normalize(p)));
            }

            // The first phrase of the bundle is the query.
            String target = bundle.iterator().next();
            int targetId = ids.get(sim.normalize(target));

            int j = 0;
            boolean hasHit = false;
            StringBuilder line = new StringBuilder(target).append(": ");

            // One extra neighbor is requested because the target itself is
            // skipped when it shows up in its own result list.
            SRResultList neighbors = sim.mostSimilar(target, k + 1);
            if (neighbors == null) {
                numErrors++;
                continue;
            }
            for (SRResult r : neighbors) {
                if (r.getId() != targetId) {
                    if (this.debug) {
                        line.append(String.format("%s %.3f, ", sim.getPhrase(r.getId()), r.getScore()));
                    }
                    if (bundleIds.contains(r.getId())) {
                        hasHit = true;
                        numRecommendedHits++;
                    }
                    numRecommended++;
                    if (++j >= k) {
                        break;
                    }
                }
            }
            if (this.debug) {
                System.out.println(line);
            }

            // Every bundle member except the target could have been recommended.
            possible += bundleIds.size();
            if (bundleIds.contains(targetId)) {
                possible--;
            }
            if (hasHit) {
                numSampleHits++;
            }
        }
        long after = System.currentTimeMillis();

        System.out.println("for " + bundles.size() + ", top " + k);
        System.out.println("Total samples: " + numSamples);
        System.out.println("Total errors: " + numErrors);
        System.out.println("Total seconds: " + ((after - before) / 1000.0));
        System.out.println("Total samples with hits: " + numSampleHits);
        System.out.println("Total related items: " + numRecommended);
        System.out.println("Total related items with hits: " + numRecommendedHits);
        System.out.println("Precision: " + (1.0 * numRecommendedHits / numRecommended));
        System.out.println("Recall: " + (1.0 * numRecommendedHits / possible));
    }

    /**
     * Sets how many neighbors are scored for each sampled phrase.
     *
     * @param k number of neighbors
     */
    public void setTopK(int k) {
        this.k = k;
    }

    /**
     * Splits a space-separated line into a mutable list of phrases.
     *
     * @param line phrases separated by single spaces
     * @return the phrases in their original order
     */
    static private List<String> makeSet(String line) {
        return new ArrayList<String>(Arrays.asList(line.split(" ")));
    }

    /**
     * Reads bundles from a file with one bundle per line and tab-separated phrases.
     * Each phrase is trimmed; lines that yield fewer than two phrases are dropped.
     *
     * @param f bundle file
     * @return the bundles in file order
     * @throws IOException if the file cannot be read
     */
    public static List<List<String>> readBundles(File f) throws IOException {
        List<List<String>> bundles = new ArrayList<List<String>>();
        for (String line : FileUtils.readLines(f)) {
            List<String> bundle = new ArrayList<String>();
            for (String token : line.split("\t")) {
                bundle.add(token.trim());
            }
            if (bundle.size() >= 2) {
                bundles.add(bundle);
            }
        }
        return bundles;
    }

    /**
     * Command-line entry point.
     *
     * Options: {@code -b/--bundles <file>} selects a bundle file (the built-in
     * test bundles are used otherwise) and {@code -k/--topk <n>} sets the number
     * of neighbors per phrase, in addition to the standard options registered by
     * {@link EnvBuilder#addStandardOptions(Options)}.
     *
     * @param args command-line arguments
     * @throws ConfigurationException if the environment or metric cannot be configured
     * @throws IOException            if the bundle file cannot be read
     */
    public static void main(String args[]) throws ConfigurationException, IOException {
        Options options = new Options();
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("bundles")
                        .withDescription("bundle file with tab separated phrases")
                        .hasArg()
                        .create("b"));
        options.addOption(
                new DefaultOptionBuilder()
                        .withLongOpt("topk")
                        .withDescription("number neighbors per phrase")
                        .hasArg()
                        .create("k"));
        EnvBuilder.addStandardOptions(options);

        CommandLineParser parser = new PosixParser();
        CommandLine cmd;
        try {
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.err.println("Invalid option usage: " + e.getMessage());
            new HelpFormatter().printHelp("PhraseSimEvaluator", options);
            return;
        }

        Env env = new EnvBuilder(cmd).build();
        PhraseSimEvaluator eval = new PhraseSimEvaluator(env);

        List<List<String>> bundles;
        if (cmd.hasOption("b")) {
            bundles = readBundles(new File(cmd.getOptionValue("b")));
        } else {
            bundles = TEST_BUNDLES;
        }
        if (cmd.hasOption("k")) {
            eval.setTopK(Integer.parseInt(cmd.getOptionValue("k")));
        }
        eval.evaluate(bundles);
    }
}