package org.wikibrain.sr.dataset;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.nlp.Dictionary;
import org.wikibrain.sr.utils.KnownSim;
import org.wikibrain.sr.wikify.Corpus;
import org.wikibrain.sr.wikify.WbCorpusLineReader;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;
/**
* <p>
* Creates a fake gold standard dataset. This is useful for training
* SR metrics in language editions that do not have an existing gold standard.
* </p>
*
* <p>
* A gold standard of size n is created by selecting n target words
* and computing the pointwise mutual information (PMI) between each
* target and every candidate word. For each target w, a candidate c is
* chosen from the PMI ranking at an exponentially spaced rank, so
* candidates with high PMI are strongly favored. The similarity value
* assigned to the pair is its percentile within the generated dataset,
* so pairs whose candidates have higher PMI receive higher scores.
* </p>
*
* <p>The n target words are drawn at random from the maxTargetRank most
* frequent words after skipping the stopWordRank most frequent words,
* which are treated as stop words. The candidates are the maxCandidateRank
* most frequent words after skipping stop words; words containing digits
* or beginning with an uppercase letter are excluded from both sets.
* </p>
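*
* <p>A rough usage sketch (the corpus file path and parameter values below
* are illustrative only):</p>
* <pre>{@code
* FakeDatasetCreator creator = new FakeDatasetCreator(
*         Language.getByLangCode("simple"), new File("corpus.txt"));
* creator.setMaxCandidateRank(20000);   // optional: shrink the candidate pool
* Dataset fake = creator.generate(500); // 500 synthetic word pairs
* }</pre>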
* @author Shilad Sen
*/
public class FakeDatasetCreator {
private int stopWordRank = 1000;
private int maxTargetRank = 1000;
private int maxCandidateRank = 30000;
private final Dictionary dictionary;
private final File path;
private final Language lang;
public FakeDatasetCreator(Language lang, File path) throws IOException {
this.lang = lang;
this.path = path;
// Count unigram frequencies over the corpus file, keeping the counts in memory.
this.dictionary = new Dictionary(lang, Dictionary.WordStorage.IN_MEMORY);
this.dictionary.countNormalizedFile(path);
}
public FakeDatasetCreator(Corpus corpus) throws IOException {
this.dictionary = new Dictionary(corpus.getLanguage(), Dictionary.WordStorage.IN_MEMORY);
this.dictionary.read(corpus.getDictionaryFile());
this.path = corpus.getCorpusFile();
this.lang = corpus.getLanguage();
}
public void setStopWordRank(int stopWordRank) {
this.stopWordRank = stopWordRank;
}
public void setMaxTargetRank(int maxTargetRank) {
this.maxTargetRank = maxTargetRank;
}
public void setMaxCandidateRank(int maxCandidateRank) {
this.maxCandidateRank = maxCandidateRank;
}
public Dataset generate(int numPairs) throws IOException {
// Select a list of the most frequent words.
List<String> frequent = new ArrayList<String>();
Pattern p = Pattern.compile(".*[0-9].*");
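// Keep only words that contain no digits and do not start with an uppercase letter (filters out numbers and most proper nouns).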
for (String word : dictionary.getFrequentUnigrams(maxCandidateRank * 3)) {
if (!p.matcher(word).find() && !Character.isUpperCase(word.charAt(0))) {
frequent.add(word);
if (frequent.size() >= maxCandidateRank) break;
}
}
if (frequent.size() < stopWordRank) {
throw new IllegalArgumentException("Corpus vocabulary too small: only " + frequent.size() + " usable words found, but stopWordRank is " + stopWordRank);
}
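// Discard the stopWordRank most frequent words; they are treated as stop words.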
frequent = frequent.subList(stopWordRank, frequent.size());
Map<String, Integer> candidates = new HashMap<String, Integer>();
for (String word : frequent) {
candidates.put(word, candidates.size());
}
// Choose a set of targets. Each one will have an entry in the dataset.
List<String> shuffled = new ArrayList<String>(
(frequent.size() > maxTargetRank) ? frequent.subList(0, maxTargetRank) : frequent);
Collections.shuffle(shuffled);
Set<String> targets = new HashSet<String>(
shuffled.size() <= numPairs ? shuffled : shuffled.subList(0, numPairs));
// Calculate cooccurrence counts.
Map<String, int[]> cocounts = new HashMap<String, int[]>();
for (String word : targets) cocounts.put(word, new int[candidates.size()]);
Set<String> foundTargets = new HashSet<String>();
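// Words cooccur when they appear on the same corpus line: for every target found on a line, count each candidate token on that line.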
for (WbCorpusLineReader.Line line : new WbCorpusLineReader(path)) {
String[] tokens = line.getLine().split("\\s+");
foundTargets.clear();
for (int i = 0; i < tokens.length; i++) {
if (targets.contains(tokens[i])) {
foundTargets.add(tokens[i]);
}
}
if (foundTargets.isEmpty()) continue;
for (String target : foundTargets) {
for (String word : tokens) {
if (candidates.containsKey(word)) {
int i = candidates.get(word);
cocounts.get(target)[i] += 1;
}
}
}
}
// Compute a smoothed PMI score between each target and every candidate, then pick one candidate per target.
shuffled.clear();
shuffled.addAll(cocounts.keySet());
Collections.shuffle(shuffled);
List<KnownSim> pairs = new ArrayList<KnownSim>();
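// base is chosen so that base^numPairs reaches roughly a third of the way down the candidate list; successive targets draw candidates from exponentially deeper PMI ranks.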
double base = Math.pow(frequent.size() / 3, 1.0 / numPairs);
for (int i = 0; i < shuffled.size(); i++) {
String target = shuffled.get(i);
int[] actual = cocounts.get(target);
final double[] pmi = new double[actual.length];
int n1 = dictionary.getUnigramCount(target);
for (int j = 0; j < frequent.size(); j++) {
int n2 = dictionary.getUnigramCount(frequent.get(j));
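// Smoothed PMI-style score: the +5 damps rare words; pairs that never cooccur get log(0) = -Infinity and sort last.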
pmi[j] = Math.log(actual[j] / ((n1 + 5.0) * (n2 + 5.0)));
}
Integer[] range = new Integer[frequent.size()];
for (int j = 0; j < range.length; j++) {
range[j] = j;
}
Arrays.sort(range, new Comparator<Integer>() {
@Override
public int compare(Integer a, Integer b) {
return Double.compare(pmi[b], pmi[a]);
}
});
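// range now holds candidate indexes sorted by descending PMI. The i-th target (in shuffled order) takes the candidate at rank ~base^i, so later targets receive progressively less related candidates.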
int index = (int)Math.round(Math.pow(base, i));
String choice = frequent.get(range[index]);
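// The assigned similarity is the pair's percentile: pairs whose candidate came from a higher PMI rank score closer to 1.0.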
double percentile = 1.0 - 1.0 * i / shuffled.size();
pairs.add(new KnownSim(target.replace('_', ' '), choice.replace('_', ' '), percentile, lang));
}
return new Dataset("fake", lang, pairs);
}
}