package org.wikibrain.sr.vector;
import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.FileUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import java.io.File;
import java.io.IOException;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Generates sparse vectors for pages from a base {@link SRMetric}.
 *
 * The vector for a page is the Trove-map form of the result list returned by
 * the base metric's {@code mostSimilar} call for that page. The call is made
 * with the configured {@code numConcepts} and, when a concept file has been
 * installed via {@link #setConcepts(File)}, with that set of concept ids.
 *
 * Phrase-based vectors and phrase-based explanations are not supported.
 *
 * @author Shilad Sen
 */
public class MostSimilarConceptsGenerator implements SparseVectorGenerator {
    private static final Logger LOG = LoggerFactory.getLogger(MostSimilarConceptsGenerator.class);

    /** Size of the leaderboard used when building explanations. TODO: make configurable. */
    private static final int MAX_EXPLANATIONS = 5;

    private final Language language;
    private final LocalPageDao pageDao;
    private final SRMetric baseMetric;
    private final int numConcepts;

    /**
     * Concept ids handed to the base metric's {@code mostSimilar} call.
     * Stays {@code null} until a concept file is successfully installed.
     */
    private TIntSet conceptIds = null;

    /**
     * @param language    language used for page lookups and log messages
     * @param pageDao     dao used to resolve page ids when building explanations
     * @param baseMetric  metric whose {@code mostSimilar} results form the vectors
     * @param numConcepts value passed as the second argument of {@code baseMetric.mostSimilar}
     */
    public MostSimilarConceptsGenerator(Language language, LocalPageDao pageDao, SRMetric baseMetric, int numConcepts) {
        this.language = language;
        this.pageDao = pageDao;
        this.baseMetric = baseMetric;
        this.numConcepts = numConcepts;
    }

    /**
     * Builds the vector for a page from the base metric's most-similar results.
     *
     * @param pageId id of the page to build a vector for
     * @return the result list as a Trove map, or {@code null} if the base metric returned {@code null}
     * @throws DaoException if the base metric fails
     */
    @Override
    public TIntFloatMap getVector(int pageId) throws DaoException {
        SRResultList mostSimilar = baseMetric.mostSimilar(pageId, numConcepts, conceptIds);
        return (mostSimilar == null) ? null : mostSimilar.asTroveMap();
    }

    /**
     * Not supported by this generator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public TIntFloatMap getVector(String phrase) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this generator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Explanation> getExplanations(String phrase1, String phrase2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        throw new UnsupportedOperationException();
    }

    /**
     * Installs the concept ids read from {@code file}, one integer id per line.
     * Surrounding whitespace is ignored and blank lines are skipped.
     *
     * If {@code file} is not a regular file, a warning is logged and the
     * concept filter is reset to {@code null} (its state when no concepts have
     * been configured), matching the "defaulting to all concepts" warning.
     *
     * The previously installed ids are only replaced once the whole file has
     * been read and parsed successfully.
     *
     * @param file file containing the concept ids
     * @throws IOException           if the file cannot be read
     * @throws NumberFormatException if a non-blank line is not a valid integer
     */
    public void setConcepts(File file) throws IOException {
        if (!file.isFile()) {
            LOG.warn("concept path {} not a file; defaulting to all concepts", file);
            // An empty set here would contradict the warning above; null is the
            // value the field holds when no concept filter is configured.
            conceptIds = null;
            return;
        }
        TIntSet ids = new TIntHashSet();
        for (String line : FileUtils.readLines(file)) {
            String wpId = line.trim();
            if (wpId.isEmpty()) {
                continue;   // tolerate blank lines, e.g. a trailing newline
            }
            ids.add(Integer.parseInt(wpId));
        }
        conceptIds = ids;
        LOG.info("installed {} concepts for {}", ids.size(), language);
    }

    /**
     * Explains the relatedness of two pages by the ids their vectors share.
     *
     * Every id present in both vectors is tallied on a leaderboard of size
     * {@link #MAX_EXPLANATIONS} with the product of its two values as score.
     * Each id in the leaderboard's top list that resolves to a page yields one
     * "Both ? and ? are similar to ?" explanation. If the top list is empty, a
     * single "share no similar pages" explanation is returned. A {@code null}
     * vector is treated as sharing nothing.
     *
     * @param pageID1 id of the first page
     * @param pageID2 id of the second page
     * @param vector1 vector of the first page; may be {@code null}
     * @param vector2 vector of the second page; may be {@code null}
     * @param result  not used by this implementation
     * @return the explanations; may be empty if none of the top ids resolve to a page
     * @throws DaoException if a page lookup fails
     */
    @Override
    public List<Explanation> getExplanations(int pageID1, int pageID2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        LocalPage page1 = pageDao.getById(language, pageID1);
        LocalPage page2 = pageDao.getById(language, pageID2);

        Leaderboard lb = new Leaderboard(MAX_EXPLANATIONS);
        // getVector(int) may return null, so guard before dereferencing.
        if (vector1 != null && vector2 != null) {
            for (int id : vector1.keys()) {
                if (vector2.containsKey(id)) {
                    lb.tallyScore(id, vector1.get(id) * vector2.get(id));
                }
            }
        }

        SRResultList top = lb.getTop();
        if (top.numDocs() == 0) {
            return Arrays.asList(new Explanation("? and ? share no similar pages", page1, page2));
        }

        List<Explanation> explanations = new ArrayList<Explanation>();
        for (int i = 0; i < top.numDocs(); i++) {
            LocalPage p = pageDao.getById(language, top.getId(i));
            if (p != null) {   // skip ids that do not resolve to a page
                explanations.add(new Explanation("Both ? and ? are similar to ?", page1, page2, p));
            }
        }
        return explanations;
    }

    /**
     * Configuration provider that builds {@link MostSimilarConceptsGenerator}
     * instances from entries under {@code sr.metric.sparsegenerator} whose
     * {@code type} is {@code mostsimilarconcepts}.
     *
     * Config keys read: {@code type}, {@code basemetric}, optional
     * {@code numConcepts} (default 500) and optional {@code concepts}, a path
     * under which a {@code <langCode>.txt} concept file is looked up.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SparseVectorGenerator> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class<SparseVectorGenerator> getType() {
            return SparseVectorGenerator.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.sparsegenerator";
        }

        /**
         * Builds a generator for the given configuration.
         *
         * @param name          name of the configured component
         * @param config        configuration of the component
         * @param runtimeParams runtime parameters; must contain {@code language}
         * @return the generator, or {@code null} if {@code type} is not {@code mostsimilarconcepts}
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
         * @throws ConfigurationException   if a dependency cannot be built or the concept file cannot be read
         */
        @Override
        public SparseVectorGenerator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("mostsimilarconcepts")) {
                return null;
            }
            // A null parameter map is reported the same way as a missing key.
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            SRMetric baseMetric = getConfigurator().get(
                    SRMetric.class,
                    config.getString("basemetric"),
                    "language",
                    language.getLangCode());
            int numConcepts = config.hasPath("numConcepts") ? config.getInt("numConcepts") : 500;
            MostSimilarConceptsGenerator generator = new MostSimilarConceptsGenerator(
                    language,
                    getConfigurator().get(LocalPageDao.class),
                    baseMetric,
                    numConcepts
            );
            if (config.hasPath("concepts")) {
                try {
                    generator.setConcepts(FileUtils.getFile(
                            config.getString("concepts"),
                            language.getLangCode() + ".txt"));
                } catch (IOException e) {
                    throw new ConfigurationException(e);
                }
            }
            return generator;
        }
    }
}