package org.wikibrain.phrases;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.lang3.StringUtils;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.nlp.StringTokenizer;
import org.wikibrain.core.nlp.Token;
import org.wikibrain.utils.WpCollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
 * Splits text into phrase tokens. A phrase is a maximal run of adjacent words that
 * commonly occurs as a link anchor, as judged by a {@link LinkProbabilityDao}.
 * Words not covered by any sufficiently likely phrase are emitted as single-word tokens.
 *
 * @author Shilad Sen
 */
public class PhraseTokenizer {
    /** Source of link-probability and subgram statistics. */
    private final LinkProbabilityDao dao;

    /** Multi-word spans scoring at or below this link probability are discarded. */
    private static final double MIN_LINK_PROBABILITY_FOR_PHRASES = 1E-5;

    public PhraseTokenizer(LinkProbabilityDao dao) {
        this.dao = dao;
    }

    /**
     * Tokenizes a sentence and returns the detected phrases as plain strings.
     *
     * @param language language used for word tokenization
     * @param sentence raw sentence text
     * @return phrase strings in sentence order
     * @throws DaoException if the underlying dao lookup fails
     */
    public List<String> makePhrases(Language language, String sentence) throws DaoException {
        List<String> result = new ArrayList<String>();
        for (Token phrase : makePhraseTokens(language, sentence)) {
            result.add(phrase.getToken());
        }
        return result;
    }

    /**
     * Groups an already-tokenized word list into phrases. The words are joined with
     * single spaces to reconstruct the full text that the tokens index into.
     *
     * @param language language used for phrase detection
     * @param words individual words in order; must not themselves contain spaces
     * @return phrase strings in sentence order
     * @throws DaoException if the underlying dao lookup fails
     * @throws IllegalStateException if a reconstructed span does not round-trip to its word
     */
    public List<String> makePhrases(Language language, List<String> words) throws DaoException {
        String text = StringUtils.join(words, " ");
        int i = 0;
        List<Token> tokens = new ArrayList<Token>();
        for (String w : words) {
            // Skip the joining space before every word after the first.
            // (Checking tokens.isEmpty() rather than i > 0 stays correct even
            // when the first word is the empty string, which leaves i at 0.)
            if (!tokens.isEmpty()) {
                i++;
            }
            Token t = new Token(i, i + w.length(), text);
            if (!t.getToken().equals(w)) {
                throw new IllegalStateException(
                        "token mismatch at offset " + i + ": expected '" + w
                        + "' but reconstructed '" + t.getToken() + "'");
            }
            tokens.add(t);
            i = t.getEnd();
        }
        List<String> result = new ArrayList<String>();
        for (Token phrase : makePhraseTokens(language, tokens)) {
            result.add(phrase.getToken());
        }
        return result;
    }

    /**
     * Tokenizes the text of a sentence token and groups the words into phrase tokens.
     *
     * @throws DaoException if the underlying dao lookup fails
     */
    public List<Token> makePhraseTokens(Language language, Token sentence) throws DaoException {
        List<Token> words = new StringTokenizer().getWordTokens(language, sentence);
        return makePhraseTokens(language, words);
    }

    /**
     * Tokenizes a sentence string and groups the words into phrase tokens.
     *
     * @throws DaoException if the underlying dao lookup fails
     */
    public List<Token> makePhraseTokens(Language language, String sentence) throws DaoException {
        List<Token> words = new StringTokenizer().getWordTokens(language, sentence);
        return makePhraseTokens(language, words);
    }

    /**
     * A candidate phrase: a contiguous range of word indexes plus its link probability.
     */
    private static class Mention {
        /** Word indexes (into the sorted word list) covered by this mention. */
        final TIntList tokens;
        /** Link probability of the covered span; primitive to avoid boxing. */
        final double probability;

        Mention(int beg, int end, double probability) {
            tokens = new TIntArrayList();
            for (int i = beg; i <= end; i++) {
                tokens.add(i);
            }
            this.probability = probability;
        }

        /** Returns true if any word index covered by this mention is already claimed. */
        boolean intersects(TIntSet used) {
            for (int i : tokens.toArray()) {
                if (used.contains(i)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Greedily partitions word tokens into non-overlapping phrases:
     * candidate spans are scored by link probability, the highest-scoring
     * non-overlapping spans win, and leftover words become one-word phrases.
     *
     * @param language language of the text (currently unused by the dao lookups)
     * @param words word tokens; sorted by position if not already
     * @return phrase tokens in sentence order, covering every input word exactly once
     * @throws DaoException if the underlying dao lookup fails
     */
    public List<Token> makePhraseTokens(Language language, List<Token> words) throws DaoException {
        if (words.isEmpty()) {
            return new ArrayList<Token>();
        }
        if (!WpCollectionUtils.isSorted(words)) {
            // Copy before sorting so the caller's list is never mutated.
            words = new ArrayList<Token>(words);
            Collections.sort(words);
        }

        // Pass 1: collect every span with a non-negligible link probability.
        List<Mention> possibles = new ArrayList<Mention>();
        for (int i = 0; i < words.size(); i++) {
            StringBuilder buffer = new StringBuilder();
            for (int j = i; j < words.size(); j++) {
                if (buffer.length() > 0) {
                    buffer.append(' ');
                }
                buffer.append(words.get(j).getToken());
                String phrase = buffer.toString();  // built once, reused for both lookups
                double prob = dao.getLinkProbability(phrase);
                if (prob > MIN_LINK_PROBABILITY_FOR_PHRASES) {
                    possibles.add(new Mention(i, j, prob));
                }
                // Stop extending this span once no longer phrase can be known.
                if (!dao.isSubgram(phrase, true)) {
                    break;
                }
            }
        }

        // Pass 2: greedily keep the highest-probability non-overlapping mentions.
        Collections.sort(possibles, new Comparator<Mention>() {
            @Override
            public int compare(Mention o1, Mention o2) {
                return Double.compare(o2.probability, o1.probability);  // descending
            }
        });
        List<Mention> result = new ArrayList<Mention>();
        TIntSet used = new TIntHashSet();
        for (Mention m : possibles) {
            if (!m.intersects(used)) {
                used.addAll(m.tokens);
                result.add(m);
            }
        }

        // Pass 3: every word not claimed by a phrase becomes a single-word mention.
        for (int i = 0; i < words.size(); i++) {
            if (!used.contains(i)) {
                result.add(new Mention(i, i, 0.1));
            }
        }

        // Pass 4: restore sentence order and convert mentions to character-offset tokens.
        Collections.sort(result, new Comparator<Mention>() {
            @Override
            public int compare(Mention o1, Mention o2) {
                // Integer.compare avoids the subtraction-overflow anti-pattern.
                return Integer.compare(o1.tokens.min(), o2.tokens.min());
            }
        });
        List<Token> phrases = new ArrayList<Token>(result.size());
        for (Mention m : result) {
            int begToken = m.tokens.min();
            int endToken = m.tokens.max();
            phrases.add(
                    new Token(
                            words.get(begToken).getBegin(),
                            words.get(endToken).getEnd(),
                            words.get(begToken).getFullText()
                    )
            );
        }
        return phrases;
    }
}