package org.wikibrain.phrases;
import com.google.code.externalsorting.ExternalSort;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.StringNormalizer;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.Title;
import org.wikibrain.utils.WpIOUtils;
import java.io.*;
import java.nio.charset.Charset;
import java.text.DecimalFormat;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Base implementation of a phrase analyzer.
* Concrete implementations extending this class need only implement a getCorpus() method.
*/
public abstract class BasePhraseAnalyzer implements PhraseAnalyzer {
// Shared logger for this class.
// NOTE(review): the logger is registered under PhraseAnalyzer rather than
// BasePhraseAnalyzer -- confirm this category is intentional.
private static final Logger LOG = LoggerFactory.getLogger(PhraseAnalyzer.class);
/**
* An entry in the phrase corpus.
* Some implementations may have a local id.
* Others will only have a title.
*/
public static class Entry {
    /** Language the phrase and page belong to. */
    Language language;
    /** Local page id; stays -1 when the entry was built from a title only. */
    int localId = -1;
    /** Page title; stays null when the entry was built from a local id. */
    String title = null;
    /** Raw phrase text. */
    String phrase;
    /** Number of occurrences reported for this (page, phrase) pair. */
    int count;

    /**
     * Creates an entry that identifies its page by local id.
     */
    public Entry(Language language, int localId, String phrase, int count) {
        this(language, phrase, count);
        this.localId = localId;
    }

    /**
     * Creates an entry that identifies its page by title only.
     */
    public Entry(Language language, String title, String phrase, int count) {
        this(language, phrase, count);
        this.title = title;
    }

    /** Sets the fields common to both public constructors. */
    private Entry(Language language, String phrase, int count) {
        this.language = language;
        this.phrase = phrase;
        this.count = count;
    }
}
// Pruner handed to the by-page load pass, where counts are keyed by phrase.
private final PrunedCounts.Pruner<String> phrasePruner;
// Pruner handed to the by-phrase load pass, where counts are keyed by local page id.
private final PrunedCounts.Pruner<Integer> pagePruner;
// Normalizer obtained from the phrase dao in the constructor; used to build phrase keys.
private final StringNormalizer normalizer;
// Destination for the page counts and phrase counts produced by this analyzer.
protected final PhraseAnalyzerDao phraseDao;
// Used to resolve entry titles to local page ids.
protected final LocalPageDao pageDao;
/**
 * Creates an analyzer backed by the given daos and pruners.
 *
 * @param phraseDao    dao that stores the counts; also supplies the string normalizer
 * @param pageDao      dao used to resolve titles to local page ids
 * @param phrasePruner pruner for per-page counts keyed by phrase
 * @param pagePruner   pruner for per-phrase counts keyed by local page id
 */
public BasePhraseAnalyzer(PhraseAnalyzerDao phraseDao, LocalPageDao pageDao, PrunedCounts.Pruner<String> phrasePruner, PrunedCounts.Pruner<Integer> pagePruner) {
    this.phraseDao = phraseDao;
    this.pageDao = pageDao;
    this.normalizer = phraseDao.getStringNormalizer();
    this.phrasePruner = phrasePruner;
    this.pagePruner = pagePruner;
}
/**
 * Supplies the raw corpus entries that {@code loadCorpus} consumes.
 * Concrete implementations must override this method to determine what phrases
 * are stored.
 *
 * Entries may identify their page either by local id or by title only;
 * {@code loadCorpus} resolves titles to ids, and skips entries whose language is
 * not in {@code langs}, whose phrase is blank, or whose page cannot be resolved.
 *
 * @param langs the languages being loaded
 * @return the corpus entries to process
 * @throws IOException if the underlying corpus cannot be read
 * @throws DaoException if a dao lookup needed to build the corpus fails
 */
protected abstract Iterable<Entry> getCorpus(LanguageSet langs) throws IOException, DaoException;
/**
 * Loads the corpus returned by {@link #getCorpus(LanguageSet)} into the dao.
 *
 * Every retained entry is written to two temp files, one keyed by page and one
 * keyed by normalized phrase. Each file is then sorted so records with the same
 * key are adjacent, and streamed into the dao through {@code loadFromFile}.
 * The temp files are deleted before returning, and the dao is closed once
 * loading succeeds.
 *
 * @param langs the languages to load; entries in other languages are skipped
 * @return the number of entries retained, capped at {@code Integer.MAX_VALUE}
 * @throws DaoException if a dao operation fails
 * @throws IOException if a temp file cannot be written, sorted or read
 */
@Override
public int loadCorpus(LanguageSet langs) throws DaoException, IOException {
    // create temp files for storing corpus entries by phrase and local id.
    // these will ultimately be sorted to group together records with the same phrase / id.
    File byWpIdFile = File.createTempFile("wp_phrases_by_id", ".txt");
    byWpIdFile.deleteOnExit();
    File byPhraseFile = File.createTempFile("wp_phrases_by_phrase", ".txt");
    byPhraseFile.deleteOnExit();

    long numEntriesRetained;
    try {
        // Close both writers even when the corpus iteration or a write fails.
        BufferedWriter byWpId = WpIOUtils.openWriter(byWpIdFile);
        try {
            BufferedWriter byPhrase = WpIOUtils.openWriter(byPhraseFile);
            try {
                numEntriesRetained = writeCorpusEntries(langs, byWpId, byPhrase);
            } finally {
                byPhrase.close();
            }
        } finally {
            byWpId.close();
        }

        // sort phrases by phrase / id and load them
        sortInPlace(byWpIdFile);
        loadFromFile(RecordType.PAGES, byWpIdFile, phrasePruner);
        sortInPlace(byPhraseFile);
        loadFromFile(RecordType.PHRASES, byPhraseFile, pagePruner);
    } finally {
        // The files can be large; do not keep them around until JVM exit.
        deleteTempFile(byWpIdFile);
        deleteTempFile(byPhraseFile);
    }
    phraseDao.close();
    return (int) Math.min(Integer.MAX_VALUE, numEntriesRetained);
}

/**
 * Filters the corpus and writes each retained entry to both keyed files.
 * Throws away entries in languages we don't care about, entries with a blank
 * phrase, and entries whose page cannot be resolved to a positive local id.
 *
 * @return the number of entries written
 */
private long writeCorpusEntries(LanguageSet langs, BufferedWriter byWpId, BufferedWriter byPhrase) throws IOException, DaoException {
    long numEntries = 0;
    long numEntriesRetained = 0;
    for (Entry e : getCorpus(langs)) {
        if (++numEntries % 1000000 == 0) {
            double p = 100.0 * numEntriesRetained / numEntries;
            LOG.info("processing entry: " + numEntries +
                    ", retained " + numEntriesRetained +
                    "(" + new DecimalFormat("#.#").format(p) + "%)");
        }
        if (!langs.containsLanguage(e.language)) {
            continue;
        }
        if (e.phrase == null || e.phrase.trim().isEmpty()) {
            continue;
        }
        // Resolve titles to ids if necessary.
        if (e.title != null && e.localId < 0) {
            int localId = pageDao.getIdByTitle(new Title(e.title, e.language));
            e.localId = (localId <= 0) ? -1 : localId;
        }
        if (e.localId < 0) {
            continue;
        }
        numEntriesRetained++;

        // Records are line- and tab-delimited. Carriage returns must go too,
        // because readLine() treats a lone '\r' as a line terminator.
        e.phrase = e.phrase.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ');

        // The phrase goes last so the reader can split on the first four tabs only.
        String line = e.language.getLangCode() + "\t" + e.localId + "\t" + e.count + "\t" + e.phrase + "\n";
        byPhrase.write(e.language.getLangCode() + ":" + normalize(e.language, e.phrase) + "\t" + line);
        byWpId.write(e.language.getLangCode() + ":" + e.localId + "\t" + line);
    }
    return numEntriesRetained;
}

/** Deletes a temp file now; deleteOnExit() remains the fallback if this fails. */
private static void deleteTempFile(File file) {
    if (file.exists() && !file.delete()) {
        LOG.warn("could not delete temp file " + file + "; it will be removed at JVM exit");
    }
}
/**
 * Applies the string normalizer to the text, then replaces each run of adjacent
 * whitespace with a single space.
 *
 * @param lang language passed to the normalizer
 * @param text text to normalize
 * @return the normalized text with whitespace runs collapsed
 */
private String normalize(Language lang, String text) {
    String normalized = normalizer.normalize(lang, text);
    return normalized.replaceAll("\\s+", " ");
}
private static enum RecordType {
PAGES, PHRASES
}
/**
 * Streams a sorted record file into the dao, one key group at a time.
 *
 * Each line must hold five tab-separated fields: key, language code, local id,
 * count and phrase. Consecutive lines with the same key are buffered and then
 * flushed through {@code writePage} or {@code writePhrase}, depending on
 * {@code ltype}. Lines with the wrong number of fields, or with a non-numeric
 * id or count, are logged and skipped.
 *
 * @param ltype  whether groups are flushed as pages or as phrases
 * @param input  the sorted file to read
 * @param pruner pruner handed to the flush method for every group
 * @throws IOException if the file cannot be read
 * @throws DaoException if saving a group fails
 */
protected void loadFromFile(RecordType ltype, File input, PrunedCounts.Pruner pruner) throws IOException, DaoException {
    BufferedReader reader = WpIOUtils.openBufferedReader(input);
    try {
        String lastKey = null;
        int maxBufferSize = 1000;
        List<Entry> buffer = new ArrayList<Entry>();
        String line;
        while ((line = reader.readLine()) != null) {
            String[] tokens = line.split("\t", 5);
            if (tokens.length != 5) {
                LOG.warn("invalid line in file " + input + ": " + line);
                continue;
            }
            int localId;
            int count;
            try {
                localId = Integer.parseInt(tokens[2]);
                count = Integer.parseInt(tokens[3]);
            } catch (NumberFormatException nfe) {
                // Treat unparsable numbers like any other malformed line.
                LOG.warn("invalid line in file " + input + ": " + line);
                continue;
            }

            // if new id, write out buffer and clear it
            if (lastKey != null && !tokens[0].equals(lastKey)) {
                flushBuffer(ltype, buffer, pruner);
                buffer.clear();
            }
            buffer.add(new Entry(Language.getByLangCode(tokens[1]), localId, tokens[4], count));

            // Warn each time a single group grows 50% beyond the largest seen so far.
            if (buffer.size() > maxBufferSize * 3 / 2) {
                LOG.warn("large buffer observed: " + buffer.size() + " for string " + tokens[0]);
                maxBufferSize = buffer.size();
            }
            lastKey = tokens[0];
        }
        // Flush the final group; the write methods ignore an empty buffer.
        flushBuffer(ltype, buffer, pruner);
    } finally {
        reader.close();
    }
}

/** Hands one buffered key group to the write method matching the record type. */
private void flushBuffer(RecordType ltype, List<Entry> buffer, PrunedCounts.Pruner pruner) throws DaoException {
    if (ltype == RecordType.PAGES) {
        writePage(buffer, pruner);
    } else {
        writePhrase(buffer, pruner);
    }
}
/**
 * Aggregates the phrase counts of a single page and saves the pruned result.
 *
 * All entries must share the language and local id of the first entry. Counts
 * for the same phrase are summed. Nothing is saved when the list is empty or
 * when the pruner returns null.
 *
 * @param pageCounts entries belonging to one page
 * @param pruner     pruner applied to the phrase-to-count map
 * @throws DaoException if saving the counts fails
 * @throws IllegalStateException if an entry belongs to a different page or language
 */
protected void writePage(List<Entry> pageCounts, PrunedCounts.Pruner pruner) throws DaoException {
    if (pageCounts.isEmpty()) {
        return;
    }
    Entry first = pageCounts.get(0);
    Language lang = first.language;
    int wpId = first.localId;
    Map<String, Integer> counts = new HashMap<String, Integer>();
    for (Entry e : pageCounts) {
        if (e.localId != wpId) {
            throw new IllegalStateException(
                    "page group for " + lang + ":" + wpId + " contains an entry for page " + e.localId);
        }
        if (e.language != lang) {
            throw new IllegalStateException(
                    "page group for " + lang + ":" + wpId + " contains an entry in language " + e.language);
        }
        Integer previous = counts.get(e.phrase);
        counts.put(e.phrase, previous == null ? e.count : previous + e.count);
    }
    PrunedCounts<String> pruned = pruner.prune(counts);
    if (pruned != null) {
        phraseDao.savePageCounts(lang, wpId, pruned);
    }
}
/**
 * Aggregates the page counts of a single normalized phrase and saves the pruned result.
 *
 * The language and normalized phrase of the first entry define the group.
 * Entries that disagree with either are logged but still counted. Counts for
 * the same local id are summed. Nothing is saved when the list is empty or
 * when the pruner returns null.
 *
 * @param pageCounts entries belonging to one phrase
 * @param pruner     pruner applied to the id-to-count map
 * @throws DaoException if saving the counts fails
 */
protected void writePhrase(List<Entry> pageCounts, PrunedCounts.Pruner pruner) throws DaoException {
    if (pageCounts.isEmpty()) {
        return;
    }
    Entry head = pageCounts.get(0);
    Language lang = head.language;
    String phrase = normalize(lang, head.phrase);

    Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
    for (Entry e : pageCounts) {
        String candidate = normalize(lang, e.phrase);
        if (!candidate.equals(phrase)) {
            LOG.warn("disagreement between phrases " + phrase + " and " + e.phrase);
        }
        if (e.language != lang) {
            LOG.warn("disagreement between languages " + lang+ " and " + e.language);
        }
        Integer soFar = counts.get(e.localId);
        counts.put(e.localId, soFar == null ? e.count : soFar + e.count);
    }

    PrunedCounts<Integer> pruned = pruner.prune(counts);
    if (pruned != null) {
        phraseDao.savePhraseCounts(lang, phrase, pruned);
    }
}
/**
 * Sorts the lines of a file lexicographically, overwriting the file with the result.
 *
 * The file is split into sorted batches by the external sorter and the batches
 * are merged back into the same file. The batch limit is the file length
 * divided by one twentieth of the JVM's maximum memory, with a floor of 100.
 *
 * @param file the UTF-8 text file to sort
 * @throws IOException if reading, sorting or merging fails
 */
private void sortInPlace(File file) throws IOException {
    // Guard the divisor and clamp before narrowing so the cast can never wrap.
    long bytesPerBatch = Math.max(1L, Runtime.getRuntime().maxMemory() / 20);
    long wantedFiles = Math.max(100L, file.length() / bytesPerBatch);
    int maxFiles = (int) Math.min(Integer.MAX_VALUE, wantedFiles);
    LOG.info("sorting " + file + " using max of " + maxFiles);

    final Charset utf8 = Charset.forName("utf-8");
    Comparator<String> comparator = new Comparator<String>() {
        @Override
        public int compare(String r1, String r2) {
            return r1.compareTo(r2);
        }
    };
    List<File> batches = ExternalSort.sortInBatch(file, comparator, maxFiles, utf8, null, false);
    LOG.info("merging " + file);
    ExternalSort.mergeSortedFiles(batches, file, comparator, utf8);
    LOG.info("finished sorting " + file);
}
/**
 * Returns the phrases recorded for a page, each mapped to its count divided by
 * the total reported by the dao's counts.
 *
 * Iteration follows the key order of the dao's counts and stops once
 * {@code maxPhrases} results have been collected.
 *
 * @return the phrase fractions, or null when the dao has no counts for the page
 * @throws DaoException if the dao lookup fails
 */
@Override
public LinkedHashMap<String, Float> describe(Language language, LocalPage page, int maxPhrases) throws DaoException {
    PrunedCounts<String> phraseCounts = phraseDao.getPageCounts(language, page.getLocalId(), maxPhrases);
    if (phraseCounts == null) {
        return null;
    }
    LinkedHashMap<String, Float> fractions = new LinkedHashMap<String, Float>();
    for (String phrase : phraseCounts.keySet()) {
        float fraction = (float) 1.0 * phraseCounts.get(phrase) / phraseCounts.getTotal();
        fractions.put(phrase, fraction);
        if (fractions.size() >= maxPhrases) {
            break;
        }
    }
    return fractions;
}
/**
 * Returns the pages recorded for a phrase, each mapped to its count divided by
 * the total reported by the dao's counts.
 *
 * Iteration follows the key order of the dao's counts and stops once
 * {@code maxPages} results have been collected.
 *
 * @return the page fractions, or null when the dao has no counts for the phrase
 * @throws DaoException if the dao lookup fails
 */
@Override
public LinkedHashMap<LocalId, Float> resolve(Language language, String phrase, int maxPages) throws DaoException {
    PrunedCounts<Integer> pageCounts = phraseDao.getPhraseCounts(language, phrase, maxPages);
    if (pageCounts == null) {
        return null;
    }
    LinkedHashMap<LocalId, Float> fractions = new LinkedHashMap<LocalId, Float>();
    for (Integer wpId : pageCounts.keySet()) {
        float fraction = (float) 1.0 * pageCounts.get(wpId) / pageCounts.getTotal();
        fractions.put(new LocalId(language, wpId), fraction);
        if (fractions.size() >= maxPages) {
            break;
        }
    }
    return fractions;
}
/**
 * @return the phrase dao this analyzer reads from and writes to
 */
public PhraseAnalyzerDao getDao() {
    return this.phraseDao;
}
}