package org.wikibrain.sr.evaluation;
import org.apache.commons.lang3.StringUtils;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.utils.KnownSim;
import java.util.*;
/**
* Groups the similarities for each phrase together to form ranked most-similar lists.
*
* @author Shilad Sen
*/
public class MostSimilarDataset {
    /**
     * Threshold used when the caller does not supply one. No similarity value is
     * less than negative infinity, so it is meant to disable threshold filtering.
     */
    private static final double DEFAULT_THRESHOLD = Double.NEGATIVE_INFINITY;

    private final String name;
    private final Language language;

    /** Maps each phrase to its known most-similar list. */
    private final Map<String, KnownMostSim> data;

    /**
     * Creates an empty dataset; used internally by {@link #pruneSmallLists(int)}
     * and {@link #split(int)}, which populate {@code data} directly.
     *
     * @param language language of the dataset
     * @param name     name of the dataset
     */
    private MostSimilarDataset(Language language, String name) {
        this.language = language;
        this.name = name;
        this.data = new HashMap<String, KnownMostSim>();
    }

    /**
     * Creates a new most similar dataset from a single input dataset.
     *
     * @see #MostSimilarDataset(java.util.List)
     * @param dataset the input dataset
     */
    public MostSimilarDataset(Dataset dataset) {
        this(Arrays.asList(dataset));
    }

    /**
     * Creates a new most similar dataset based on some input datasets.
     * KnownSims with similarity less than DEFAULT_THRESHOLD are ignored.
     *
     * @param datasets the input datasets; must not be empty
     * @throws IllegalArgumentException if {@code datasets} is empty or the datasets
     *                                  do not all share one language
     */
    public MostSimilarDataset(List<Dataset> datasets) {
        this(datasets, DEFAULT_THRESHOLD);
    }

    /**
     * Creates a new most similar dataset based on some input datasets.
     * KnownSims with similarity less than threshold are ignored.
     *
     * Every KnownSim is registered under its {@code phrase1} and, via
     * {@code getReversed()}, a second time in reversed form. Phrases whose
     * resulting most-similar list is empty are left out. Note that each input
     * dataset has {@code normalize()} called on it.
     *
     * The name is the comma-joined names of the inputs, followed by
     * {@code "+threshold=<value>"} when a non-default threshold is given.
     *
     * @param datasets  the input datasets; must not be empty
     * @param threshold minimum similarity, passed on to each {@link KnownMostSim}
     * @throws IllegalArgumentException if {@code datasets} is empty or the datasets
     *                                  do not all share one language
     */
    public MostSimilarDataset(List<Dataset> datasets, double threshold) {
        if (datasets.isEmpty()) {
            throw new IllegalArgumentException("At least one dataset is required");
        }
        this.language = datasets.get(0).getLanguage();
        Map<String, List<KnownSim>> sims = new HashMap<String, List<KnownSim>>();
        List<String> names = new ArrayList<String>();
        for (Dataset ds : datasets) {
            ds.normalize(); // just to be safe
            // Compare by equals() rather than by reference so that equal but
            // distinct Language instances are not rejected.
            if (!sameLanguage(language, ds.getLanguage())) {
                throw new IllegalArgumentException("All datasets must be the same language");
            }
            for (KnownSim ks : ds.getData()) {
                addToMap(sims, ks);
                addToMap(sims, ks.getReversed());
            }
            names.add(ds.getName());
        }
        this.name = StringUtils.join(names, ",") +
                ((threshold == DEFAULT_THRESHOLD) ? "" : ("+threshold=" + threshold));
        this.data = new HashMap<String, KnownMostSim>();
        for (Map.Entry<String, List<KnownSim>> entry : sims.entrySet()) {
            KnownMostSim mostSim = new KnownMostSim(entry.getValue(), threshold);
            if (!mostSim.getMostSimilar().isEmpty()) {
                data.put(entry.getKey(), mostSim);
            }
        }
    }

    /**
     * Returns the phrases that have a most-similar list in this dataset.
     *
     * @return the key set of the internal map; this is a live view, so callers
     *         should treat it as read-only
     */
    public Set<String> getPhrases() {
        return data.keySet();
    }

    /**
     * Returns the known most-similar list for a phrase.
     *
     * @param phrase the phrase to look up
     * @return the entry for the phrase, or {@code null} if the phrase is not present
     */
    public KnownMostSim getSimilarities(String phrase) {
        return data.get(phrase);
    }

    /**
     * Returns a new dataset that only contains phrases with at least n KnownSim entries.
     * The new dataset's name is this dataset's name followed by {@code "+pruned=<n>"}.
     *
     * @param n Minimum number of entries in a phrase's most-similar list
     * @return the pruned dataset; this dataset is not modified
     */
    public MostSimilarDataset pruneSmallLists(int n) {
        MostSimilarDataset pruned = new MostSimilarDataset(language, name + "+pruned=" + n);
        for (Map.Entry<String, KnownMostSim> entry : data.entrySet()) {
            if (entry.getValue().getMostSimilar().size() >= n) {
                pruned.data.put(entry.getKey(), entry.getValue());
            }
        }
        return pruned;
    }

    /**
     * Appends a KnownSim to the list stored under its {@code phrase1},
     * creating that list on first use.
     */
    private static void addToMap(Map<String, List<KnownSim>> sims, KnownSim ks) {
        List<KnownSim> list = sims.get(ks.phrase1);
        if (list == null) {
            list = new ArrayList<KnownSim>();
            sims.put(ks.phrase1, list);
        }
        list.add(ks);
    }

    /** Null-safe equality check for two languages. */
    private static boolean sameLanguage(Language a, Language b) {
        return (a == null) ? (b == null) : a.equals(b);
    }

    /**
     * @return the name of this dataset
     */
    public String getName() {
        return name;
    }

    /**
     * @return the language of this dataset
     */
    public Language getLanguage() {
        return language;
    }

    /**
     * Converts the most similar dataset back to a "normal" dataset by
     * concatenating the most-similar lists of all phrases.
     *
     * @return a new dataset with this dataset's name and language
     */
    public Dataset toDataset() {
        List<KnownSim> sims = new ArrayList<KnownSim>();
        for (KnownMostSim kms : data.values()) {
            sims.addAll(kms.getMostSimilar());
        }
        return new Dataset(name, language, sims);
    }

    /**
     * Returns a list of suitable test cross-validation sets.
     * The splits occur along phrases, so all entries for a particular phrase stay in the
     * same cross-validation split. Phrases are shuffled and then dealt out round-robin;
     * split {@code i} is named with the suffix {@code "+split-<i>"}.
     *
     * @param n number of splits; must be positive
     * @return a list of exactly {@code n} datasets
     * @throws IllegalArgumentException if {@code n} is not positive
     */
    public List<MostSimilarDataset> split(int n) {
        if (n <= 0) {
            // Previously n == 0 failed with ArithmeticException (i % n) and
            // n < 0 with IndexOutOfBoundsException; fail early and clearly.
            throw new IllegalArgumentException("Number of splits must be positive: " + n);
        }
        List<String> phrases = new ArrayList<String>(data.keySet());
        Collections.shuffle(phrases);
        List<MostSimilarDataset> result = new ArrayList<MostSimilarDataset>(n);
        for (int i = 0; i < n; i++) {
            result.add(new MostSimilarDataset(language, name + "+split-" + i));
        }
        for (int i = 0; i < phrases.size(); i++) {
            String p = phrases.get(i);
            result.get(i % n).data.put(p, data.get(p));
        }
        return result;
    }

    /**
     * Splits this dataset and converts every split to a "normal" dataset.
     *
     * @see #split(int)
     * @see #toDataset()
     * @param n number of splits; must be positive
     * @return a list of exactly {@code n} datasets
     * @throws IllegalArgumentException if {@code n} is not positive
     */
    public List<Dataset> splitIntoDatasets(int n) {
        List<Dataset> result = new ArrayList<Dataset>();
        for (MostSimilarDataset msd : split(n)) {
            result.add(msd.toDataset());
        }
        return result;
    }
}