package org.wikibrain.sr.vector;
import com.typesafe.config.Config;

import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.lucene.LuceneSearcher;
import org.wikibrain.lucene.QueryBuilder;
import org.wikibrain.lucene.WikiBrainScoreDoc;
import org.wikibrain.lucene.WpIdFilter;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SimUtils;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.*;
/**
 * Generates sparse ESA (Explicit Semantic Analysis) concept vectors for pages
 * and free-text phrases by querying a Lucene full-text index of Wikipedia
 * articles. Vector dimensions are Lucene doc ids; values are normalized
 * retrieval scores.
 *
 * @author Shilad Sen
 */
public class ESAGenerator implements SparseVectorGenerator {
    private static final Logger LOG = LoggerFactory.getLogger(ESAGenerator.class);

    private final LuceneSearcher searcher;
    private final Language language;
    private final LocalPageDao pageDao;

    // Optional whitelist of concept page ids; null means all concepts are allowed.
    private WpIdFilter conceptFilter = null;

    // Page ids that may never be used as concepts. Never null after construction.
    private TIntSet blackListSet;
    private final String blackListFilePath;

    /**
     * @param language          language of the Lucene index to query
     * @param pageDao           dao used to resolve pages for explanations
     * @param searcher          searcher over the Wikipedia text index
     * @param blackListFilePath whitespace-separated file of page ids to exclude,
     *                          or null / empty string to disable blacklisting
     */
    public ESAGenerator(Language language, LocalPageDao pageDao, LuceneSearcher searcher, String blackListFilePath) {
        this.language = language;
        this.pageDao = pageDao;
        this.searcher = searcher;
        this.blackListFilePath = blackListFilePath;
        try {
            createBlackListSet();
        } catch (Exception e) {
            // Best effort: an unreadable blacklist leaves the set empty rather than
            // aborting construction. Log the cause instead of silently dropping it.
            LOG.warn("Could not create blacklist set from " + blackListFilePath, e);
        }
    }

    /**
     * Loads {@link #blackListSet} from {@link #blackListFilePath}.
     * Leaves the set empty when no path is configured.
     *
     * @throws FileNotFoundException if the configured file does not exist
     */
    private void createBlackListSet() throws FileNotFoundException {
        blackListSet = new TIntHashSet();   // never leave the set null
        if (blackListFilePath == null || blackListFilePath.isEmpty()) {
            LOG.info("Skipping blacklist creation; no blacklist file specified.");
            return;
        }
        // try-with-resources closes the underlying file handle even if parsing
        // fails (the original leaked the Scanner on a malformed token).
        try (Scanner scanner = new Scanner(new File(blackListFilePath))) {
            // hasNextInt (was hasNext) so a stray non-numeric token cannot abort
            // loading with an InputMismatchException.
            while (scanner.hasNextInt()) {
                blackListSet.add(scanner.nextInt());
            }
            if (scanner.hasNext()) {
                LOG.warn("Ignoring non-numeric token '{}' in blacklist file {}",
                        scanner.next(), blackListFilePath);
            }
        }
    }

    /**
     * Returns the concept vector for a page, keyed by Lucene doc id.
     * Returns an empty vector for unindexed pages.
     */
    @Override
    public TIntFloatMap getVector(int pageId) throws DaoException {
        int luceneId = searcher.getDocIdFromLocalId(pageId, language);
        if (luceneId < 0) {
            LOG.warn("Unindexed document {} in {}", pageId, language.getEnLangName());
            return new TIntFloatHashMap();
        }
        WikiBrainScoreDoc[] scoreDocs = getQueryBuilder()
                .setMoreLikeThisQuery(luceneId)
                .search();
        return SimUtils.normalizeVector(expandScores(pruneSimilar(scoreDocs)));
    }

    /**
     * Returns the concept vector for a free-text phrase, keyed by Lucene doc id,
     * or null when the phrase cannot be parsed into a query.
     */
    @Override
    public TIntFloatMap getVector(String phrase) {
        QueryBuilder builder = getQueryBuilder().setPhraseQuery(phrase);
        if (!builder.hasQuery()) {
            LOG.warn("Phrase cannot be parsed to get a query. {}", phrase);
            return null;    // callers treat null as "no vector available"
        }
        WikiBrainScoreDoc[] scoreDocs = SimUtils.pruneSimilar(builder.search());
        return SimUtils.normalizeVector(expandScores(scoreDocs));
    }

    /**
     * Restricts the concept space to the page ids listed (one per line) in the
     * given file, minus any blacklisted ids. A missing file resets the filter
     * so all concepts are allowed.
     *
     * @throws IOException if the file cannot be read
     */
    public void setConcepts(File file) throws IOException {
        conceptFilter = null;
        if (!file.isFile()) {
            LOG.warn("concept path {} not a file; defaulting to all concepts", file);
            return;
        }
        TIntSet ids = new TIntHashSet();
        for (String line : FileUtils.readLines(file)) {
            String token = line.trim();
            if (token.isEmpty()) {
                continue;   // tolerate blank lines in the concept file
            }
            int wpId = Integer.parseInt(token);
            if (!isBlacklisted(wpId)) {
                ids.add(wpId);
            }
        }
        conceptFilter = new WpIdFilter(ids.toArray());
        LOG.warn("installed {} concepts for {}", ids.size(), language);
    }

    /** True iff the id appears in the configured blacklist. */
    private boolean isBlacklisted(int wpLocalIDNumb) {
        return blackListSet.contains(wpLocalIDNumb);
    }

    /**
     * Finds the concepts contributing most to the dot product of the two
     * vectors. Vector keys are Lucene doc ids (see {@link #expandScores}).
     */
    private SRResultList topSharedConcepts(TIntFloatMap vector1, TIntFloatMap vector2) {
        Leaderboard lb = new Leaderboard(5); // TODO: make 5 configurable
        for (int id : vector1.keys()) {
            if (vector2.containsKey(id)) {
                lb.tallyScore(id, vector1.get(id) * vector2.get(id));
            }
        }
        return lb.getTop();
    }

    @Override
    public List<Explanation> getExplanations(int pageID1, int pageID2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        LocalPage page1 = pageDao.getById(language, pageID1);
        LocalPage page2 = pageDao.getById(language, pageID2);
        SRResultList top = topSharedConcepts(vector1, vector2);
        if (top.numDocs() == 0) {
            return Arrays.asList(new Explanation("? and ? share no links", page1, page2));
        }
        List<Explanation> explanations = new ArrayList<Explanation>();
        for (int i = 0; i < top.numDocs(); i++) {
            // FIX: vector keys are Lucene doc ids, so translate back to a local
            // page id before the dao lookup (the phrase overload already did this;
            // the original passed the raw Lucene id here).
            LocalPage p = pageDao.getById(language, searcher.getLocalIdFromDocId(top.getId(i), language));
            if (p != null) {
                explanations.add(new Explanation("Both ? and ? have similar text to ?", page1, page2, p));
            }
        }
        return explanations;
    }

    @Override
    public List<Explanation> getExplanations(String phrase1, String phrase2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        SRResultList top = topSharedConcepts(vector1, vector2);
        if (top.numDocs() == 0) {
            return Arrays.asList(new Explanation("? and ? share no tags", phrase1, phrase2));
        }
        List<Explanation> explanations = new ArrayList<Explanation>();
        for (int i = 0; i < top.numDocs(); i++) {
            LocalPage p = pageDao.getById(language, searcher.getLocalIdFromDocId(top.getId(i), language));
            if (p != null) {
                explanations.add(new Explanation("Both ? and ? have similar text to ?", phrase1, phrase2, p));
            }
        }
        return explanations;
    }

    /**
     * Builds a query against this generator's language index, applying the
     * concept filter if one is installed.
     */
    private QueryBuilder getQueryBuilder() {
        QueryBuilder builder = searcher.getQueryBuilderByLanguage(language);
        builder.setResolveWikipediaIds(false);
        if (conceptFilter != null) {
            builder.addFilter(conceptFilter);
        }
        return builder;
    }

    /**
     * Copies (luceneId, score) pairs from the score docs into a sparse map.
     *
     * @param wikibrainScoreDocs search hits to convert
     * @return sparse vector keyed by Lucene doc id
     */
    private TIntFloatMap expandScores(WikiBrainScoreDoc[] wikibrainScoreDocs) {
        TIntFloatMap expanded = new TIntFloatHashMap();
        for (WikiBrainScoreDoc scoreDoc : wikibrainScoreDocs) {
            expanded.put(scoreDoc.luceneId, scoreDoc.score);
        }
        return expanded;
    }

    /**
     * Truncates a score-sorted result array at the first point where the score
     * drop over a 100-document window falls below 0.5% of the top score.
     * Arrays with 100 or fewer docs are returned unchanged.
     *
     * @param wikibrainScoreDocs array of WikiBrainScoreDoc, sorted by descending score
     */
    private WikiBrainScoreDoc[] pruneSimilar(WikiBrainScoreDoc[] wikibrainScoreDocs) {
        if (wikibrainScoreDocs.length == 0) {
            return wikibrainScoreDocs;
        }
        int cutoff = wikibrainScoreDocs.length;
        double threshold = 0.005 * wikibrainScoreDocs[0].score;
        for (int i = 0, j = 100; j < wikibrainScoreDocs.length; i++, j++) {
            float delta = wikibrainScoreDocs[i].score - wikibrainScoreDocs[j].score;
            if (delta < threshold) {
                cutoff = j;
                break;
            }
        }
        if (cutoff < wikibrainScoreDocs.length) {
            wikibrainScoreDocs = ArrayUtils.subarray(wikibrainScoreDocs, 0, cutoff);
        }
        return wikibrainScoreDocs;
    }

    /** Configurator provider that constructs an ESAGenerator for config type "esa". */
    public static class Provider extends org.wikibrain.conf.Provider<SparseVectorGenerator> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SparseVectorGenerator.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.sparsegenerator";
        }

        @Override
        public SparseVectorGenerator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("esa")) {
                return null;
            }
            if (!runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            // The blacklist is optional: a missing key means "no blacklist".
            // (The original called getString unconditionally, which throws
            // ConfigException.Missing when the key is absent.)
            String blackListPath = getConfig().get().hasPath("sr.blacklist.path")
                    ? getConfig().get().getString("sr.blacklist.path")
                    : null;
            ESAGenerator generator = new ESAGenerator(
                    language,
                    getConfigurator().get(LocalPageDao.class),
                    getConfigurator().get(LuceneSearcher.class, config.getString("luceneSearcher")),
                    blackListPath
            );
            if (config.hasPath("concepts")) {
                try {
                    generator.setConcepts(FileUtils.getFile(
                            config.getString("concepts"),
                            language.getLangCode() + ".txt"));
                } catch (IOException e) {
                    throw new ConfigurationException(e);
                }
            }
            return generator;
        }
    }
}