package org.wikibrain.sr.evaluation; import gnu.trove.list.TDoubleList; import gnu.trove.list.array.TDoubleArrayList; import org.apache.commons.io.FileUtils; import org.apache.commons.math3.stat.correlation.PearsonsCorrelation; import org.apache.commons.math3.stat.correlation.SpearmansCorrelation; import org.apache.commons.math3.stat.ranking.NaturalRanking; import org.apache.commons.math3.stat.ranking.TiesStrategy; import org.wikibrain.core.lang.Language; import org.wikibrain.sr.SRResult; import org.wikibrain.sr.utils.KnownSim; import java.io.File; import java.io.IOException; import java.text.ParseException; import java.util.*; /** * A class that accumulates similarity() evaluation metrics. * The results can optionally be logged to a file. * * @author Shilad Sen */ public class SimilarityEvaluationLog extends BaseEvaluationLog<SimilarityEvaluationLog> { private final TDoubleList actual = new TDoubleArrayList(); private final TDoubleList estimates = new TDoubleArrayList(); public SimilarityEvaluationLog() throws IOException { super(); } public SimilarityEvaluationLog(File logPath) throws IOException { super(logPath); } public SimilarityEvaluationLog(Map<String, String> config, File logPath) throws IOException { super(config, logPath); } public SimilarityEvaluationLog(Map<String, String> config, File logPath, Date date) throws IOException { super(config, logPath, date); } public synchronized void recordFailed(KnownSim ks) throws IOException { failed++; write(ks, "failed"); } public synchronized void record(KnownSim ks, SRResult estimate) throws IOException { if (estimate == null || Double.isNaN(estimate.getScore()) || Double.isInfinite(estimate.getScore())) { missing++; } else { actual.add(ks.similarity); estimates.add(estimate.getScore()); sucessful++; } write(ks, estimate == null ? "null" : String.valueOf(estimate.getScore())); } private synchronized void write(KnownSim ks, String result) throws IOException { write("entry\t" + ks.language + "\t" + ks.phrase1 + "\t" + ks.phrase2 + "\t" + ks.similarity + "\t" + result +"\n"); } public double getPearsonsCorrelation() { return new PearsonsCorrelation().correlation(actual.toArray(), estimates.toArray()); } public double getSpearmansCorrelation() { return new SpearmansCorrelation().correlation(actual.toArray(), estimates.toArray()); } public double getMeanAbsoluteError() { if (actual.isEmpty()) { return 0.0; } double maeSum = 0.0; for (int i = 0; i < actual.size(); i++) { maeSum += Math.abs(actual.get(i) - estimates.get(i)); } return maeSum / actual.size(); } public double getRootMeanSquareError() { if (actual.isEmpty()) { return 0.0; } double rmsError = 0.0; for (int i = 0; i < actual.size(); i++) { rmsError += (actual.get(i) - estimates.get(i)) * (actual.get(i) - estimates.get(i)); } return Math.sqrt(rmsError / actual.size()); } public List<KnownSimGuess> getGuesses() throws IOException, ParseException { List<KnownSimGuess> guesses = new ArrayList<KnownSimGuess>(); for (String line : FileUtils.readLines(logPath, "utf-8")) { if (line.endsWith("\n")) { line = line.substring(0, line.length() - 1); } String tokens[] = line.split("\t"); if (tokens[0].equals("entry")) { KnownSim ks = new KnownSim(tokens[2], tokens[3], Double.valueOf(tokens[4]), Language.getByFullLangName(tokens[1])); String val = tokens[5]; if (val.equals("failed")) { guesses.add(new KnownSimGuess(ks, Double.NaN)); } else { guesses.add(new KnownSimGuess(ks, Double.valueOf(val))); } } } for (SimilarityEvaluationLog log : getChildEvaluations()) { guesses.addAll(log.getGuesses()); } setRanks(guesses); return guesses; } public static void setRanks(List<KnownSimGuess> guesses) { NaturalRanking nr = new NaturalRanking(TiesStrategy.MAXIMUM); // Part 1: build up pruned lists of actual / estimates excluded NaNs, etc. TDoubleList prunedActual = new TDoubleArrayList(); TDoubleList prunedEstimates = new TDoubleArrayList(); for (KnownSimGuess g : guesses) { if (g.hasGuess()) { prunedActual.add(g.getActual()); prunedEstimates.add((g.getGuess())); } } // Part 2: get ranks double [] actualRanks = nr.rank(prunedActual.toArray()); double [] estimatedRanks = nr.rank(prunedEstimates.toArray()); // Part 3: specify them int i = 0; for (KnownSimGuess g : guesses) { if (g.hasGuess()) { g.setActualRank(1.0 + actualRanks.length - actualRanks[i]); g.setPredictedRank(1.0 + estimatedRanks.length - estimatedRanks[i]); i++; } } if (i != prunedActual.size()) { throw new IllegalStateException(); } } /** * @see BaseEvaluationLog#getSummaryAsMap() * @return */ public Map<String, String> getSummaryAsMap() { Map<String, String> summary = super.getSummaryAsMap(); summary.put("spearmans", Double.toString(getSpearmansCorrelation())); summary.put("pearsons", Double.toString(getPearsonsCorrelation())); summary.put("mae", Double.toString(getMeanAbsoluteError())); summary.put("rms", Double.toString(getRootMeanSquareError())); return summary; } public List<SimilarityEvaluationLog> getChildEvaluations() throws IOException, ParseException { List<SimilarityEvaluationLog> evals = new ArrayList<SimilarityEvaluationLog>(); for (File file : children) { evals.add(read(file)); } return evals; } protected TDoubleList getActual() { return actual; } protected TDoubleList getEstimates() { return estimates; } @Override public void merge(SimilarityEvaluationLog eval) throws IOException { super.merge(eval); SimilarityEvaluationLog seval = (SimilarityEvaluationLog)eval; actual.addAll(seval.actual); estimates.addAll(seval.estimates); } /** * Reads in the similarity evaluation at a particular path. * * @param path * @return */ public static SimilarityEvaluationLog read(File path) throws IOException, ParseException { Date start = null; Map<String, String> config = new HashMap<String, String>(); SimilarityEvaluationLog eval = null; for (String line : FileUtils.readLines(path, "utf-8")) { if (line.endsWith("\n")) { line = line.substring(0, line.length() - 1); } if (line.trim().isEmpty()) { continue; } String tokens[] = line.split("\t"); if (tokens[0].equals("start")) { start = parseDate(tokens[1]); } else if (tokens[0].equals("config")) { config.put(tokens[1], tokens[2]); } else if (tokens[0].equals("merge")) { if (eval == null) { eval = new SimilarityEvaluationLog(config, null, start); } eval.merge(read(new File(tokens[1]))); } else if (tokens[0].equals("entry")) { if (eval == null) { eval = new SimilarityEvaluationLog(config, null, start); } KnownSim ks = new KnownSim(tokens[2], tokens[3], Double.valueOf(tokens[4]), Language.getByFullLangName(tokens[1])); String val = tokens[5]; if (val.equals("failed")) { eval.recordFailed(ks); } else { eval.record(ks, new SRResult(Double.valueOf(val))); } } else { throw new IllegalStateException("invalid event in log " + path + ": " + line); } } eval.logPath = path; return eval; } }