package org.wikibrain.sr.evaluation; import gnu.trove.list.TDoubleList; import gnu.trove.list.array.TDoubleArrayList; import org.apache.commons.io.FileUtils; import org.apache.commons.math3.stat.correlation.PearsonsCorrelation; import org.apache.commons.math3.stat.correlation.SpearmansCorrelation; import org.wikibrain.core.lang.Language; import org.wikibrain.sr.SRResultList; import org.wikibrain.sr.utils.KnownSim; import java.io.File; import java.io.IOException; import java.text.ParseException; import java.util.*; /** * @author Shilad Sen */ public class MostSimilarEvaluationLog extends BaseEvaluationLog<MostSimilarEvaluationLog> { private final List<MostSimilarGuess> guesses = new ArrayList<MostSimilarGuess>(); /** * Precision and recall is measured at these ranks */ private int[] precisionRecallRanks = {1, 5, 10, 20, 50, 100, 500, 1000}; /** * The threshold under which items are not considered relevant. */ private double relevanceThreshold = 0.6; public MostSimilarEvaluationLog() throws IOException { super(); } public MostSimilarEvaluationLog(File logPath) throws IOException { super(logPath); } public MostSimilarEvaluationLog(Map<String, String> config, File logPath) throws IOException { super(config, logPath); } public MostSimilarEvaluationLog(Map<String, String> config, File logPath, Date date) throws IOException { super(config, logPath, date); } public synchronized void record(KnownMostSim kms, SRResultList mostSimilar) throws IOException { record(kms, new MostSimilarGuess(kms, mostSimilar)); } public synchronized void record(KnownMostSim kms, MostSimilarGuess guess) throws IOException { write(kms, guess.toString()); sucessful++; guesses.add(guess); } public double getNDCG() { double sumWeights = 0.0; double ndgc = 0.0; for (MostSimilarGuess guess : guesses) { double w = guess.getObservations().size() - 1; double v = guess.getNDGC(); if (w >= 0.99 && !Double.isNaN(v) && !Double.isInfinite(v)) { ndgc += w * v; sumWeights += w; } } return ndgc / sumWeights; } public double getPenalizedNDCG() { double sumWeights = 0.0; double ndgc = 0.0; for (MostSimilarGuess guess : guesses) { double w = guess.getKnown().getMostSimilar().size() - 1; double v = guess.getPenalizedNDGC(); if (w >= 0.99 && !Double.isNaN(v) && !Double.isInfinite(v)) { ndgc += w * v; sumWeights += w; } } return ndgc / sumWeights; } public PrecisionRecallAccumulator getPrecisionRecall(int n, double threshold) { PrecisionRecallAccumulator pr = new PrecisionRecallAccumulator(n, threshold); for (MostSimilarGuess guess : guesses) { pr.merge(guess.getPrecisionRecall(n, threshold)); } return pr; } public synchronized void recordFailed(KnownMostSim kms) throws IOException { failed++; write(kms, "failed"); } /** * @see BaseEvaluationLog#getSummaryAsMap() * @return */ public Map<String, String> getSummaryAsMap() { Map<String, String> summary = super.getSummaryAsMap(); summary.put("pearsons", Double.toString(getPearsonsCorrelation())); summary.put("spearmans", Double.toString(getSpearmansCorrelation())); summary.put("ndgc", Double.toString(getNDCG())); summary.put("penalizedNdgc", Double.toString(getPenalizedNDCG())); for (int n : precisionRecallRanks) { PrecisionRecallAccumulator pr = getPrecisionRecall(n, relevanceThreshold); summary.put("num-"+n, Integer.toString(pr.getRetrievedIrrelevant() + pr.getRetrievedRelevant())); summary.put("mean-"+n, Double.toString(pr.getMeanRelevance())); summary.put("precision-"+n, Double.toString(pr.getPrecision())); summary.put("recall-"+n, Double.toString(pr.getRecall())); } return summary; } @Override public void merge(MostSimilarEvaluationLog eval) throws IOException { super.merge(eval); guesses.addAll(eval.guesses); } public List<MostSimilarEvaluationLog> getChildEvaluations() throws IOException, ParseException { List<MostSimilarEvaluationLog> evals = new ArrayList<MostSimilarEvaluationLog>(); for (File file : children) { evals.add(read(file)); } return evals; } public double getSpearmansCorrelation() { TDoubleList actual = new TDoubleArrayList(); TDoubleList expected = new TDoubleArrayList(); for (MostSimilarGuess msg : guesses) { for (MostSimilarGuess.Observation o : msg.getObservations()) { if (!Double.isInfinite(o.estimate) && !Double.isNaN(o.estimate)) { actual.add(o.actual); expected.add(o.estimate); } } } if (actual.size() < 2) { return Double.NaN; } else { return new SpearmansCorrelation().correlation(actual.toArray(), expected.toArray()); } } public double getPearsonsCorrelation() { TDoubleList actual = new TDoubleArrayList(); TDoubleList expected = new TDoubleArrayList(); for (MostSimilarGuess msg : guesses) { for (MostSimilarGuess.Observation o : msg.getObservations()) { if (!Double.isInfinite(o.estimate) && !Double.isNaN(o.estimate)) { actual.add(o.actual); expected.add(o.estimate); } } } if (actual.size() < 2) { return Double.NaN; } else { return new PearsonsCorrelation().correlation(actual.toArray(), expected.toArray()); } } private synchronized void write(KnownMostSim kms, String result) throws IOException { StringBuilder sb = new StringBuilder(); sb.append("entry\t") .append(kms.getLanguage()) .append("\t") .append(cleanPhrase(kms.getPhrase())) .append("\t") .append(kms.getPageId()) .append("\t"); int rank = 0; for (KnownSim ks : kms.getMostSimilar()) { if (rank > 0) sb.append("|"); sb.append(ks.wpId2) .append("@") .append(ks.similarity) .append("@") .append(cleanPhrase(ks.phrase2)); rank++; } sb.append("\t").append(result); write(sb.toString()); } private String cleanPhrase(String phrase) { return phrase.replace("|", "").replaceAll("\\s+", " "); } static public MostSimilarEvaluationLog read(File path) throws IOException, ParseException { Date start = null; Map<String, String> config = new HashMap<String, String>(); MostSimilarEvaluationLog eval = null; for (String line : FileUtils.readLines(path, "utf-8")) { if (line.endsWith("\n")) { line = line.substring(0, line.length() - 1); } String tokens[] = line.split("\t"); if (tokens[0].equals("start")) { start = SimilarityEvaluationLog.parseDate(tokens[1]); } else if (tokens[0].equals("config")) { config.put(tokens[1], tokens[2]); } else if (tokens[0].equals("merge")) { eval.merge(read(new File(tokens[1]))); } else if (tokens[0].equals("entry")) { if (eval == null) { eval = new MostSimilarEvaluationLog(config, null, start); } List<KnownSim> sims = new ArrayList<KnownSim>(); Language lang = Language.getByLangCode(tokens[1]); String phrase1 = tokens[2]; int localId1 = Integer.valueOf(tokens[3]); for (String ksStr : tokens[4].split("[|]")) { String ksTokens[] = ksStr.split("[@]"); int localId2 = Integer.valueOf(ksTokens[0]); double sim = Double.valueOf(ksTokens[1]); String phrase2 = ksTokens[2]; sims.add(new KnownSim(phrase1, phrase2, localId1, localId2, sim, lang)); } KnownMostSim ks = new KnownMostSim(sims); String val = tokens[5]; if (val.equals("failed")) { eval.recordFailed(ks); } else { eval.record(ks, new MostSimilarGuess(ks, val)); } } else { throw new IllegalStateException("invalid event in log " + path + ": " + line); } } return eval; } public void setPrecisionRecallRanks(int[] precisionRecallRanks) { this.precisionRecallRanks = precisionRecallRanks; } public void setRelevanceThreshold(double relevanceThreshold) { this.relevanceThreshold = relevanceThreshold; } }