package org.wikibrain.sr.synrank;
import com.typesafe.config.Config;
import gnu.trove.list.TIntList;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.*;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalLink;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.disambig.Disambiguator;
import java.util.Arrays;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* An implementation of the SynRank SR metric as described in
*
* G. Quercini, H. Samet Uncovering the spatial relatedness in Wikipedia.
* In Y. Huang, M. Gertz, J. C. Krumm, J. Sankaranarayanan, and M. Schneider, editors,
* Proceedings of SIGSPATIAL 2014.
*
* Soon to be linked at http://www.cs.umd.edu/~hjs/hjsyear.html
*
* @author Shilad Sen
*/
public class SynRank extends BaseSRMetric {
private static final Logger LOG = LoggerFactory.getLogger(SynRank.class);
private final LocalLinkDao linkDao;
private final int numArticles;
private final TIntHashSet dabs;
public SynRank(String name, Language language, LocalPageDao pageDao, Disambiguator dab, LocalLinkDao linkDao, LocalCategoryMemberDao catDao) throws DaoException {
super(name, language, pageDao, dab);
this.linkDao = linkDao;
LOG.info("calculating number of articles");
this.numArticles = pageDao.getCount(DaoFilter.normalPageFilter(language));
LOG.info("found " + this.numArticles + " articles");
this.dabs = new TIntHashSet();
// TODO: fix dabs!
LocalPage dabCat = null;
for (String title: Arrays.asList("Category:Disambiguation pages", "Category:Disambiguation")) {
dabCat = pageDao.getByTitle(getLanguage(), NameSpace.CATEGORY, title);
if (dabCat != null) break;
}
if (dabCat == null) throw new IllegalArgumentException();
for (int id : catDao.getCategoryMemberIds(dabCat)) {
dabs.add(id);
}
LOG.info("identified " + dabs.size() + " disambiguation pages");
}
@Override
public SRConfig getConfig() {
SRConfig config = new SRConfig();
config.maxScore = 1.1f;
config.minScore = 0;
return config;
}
@Override
public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
if (pageId1 == pageId2) {
return new SRResult(1.0);
}
if (dabs.contains(pageId1) || dabs.contains(pageId2)) {
return null;
}
TIntSet ids1 = getLinksTo(pageId1);
TIntSet ids2 = getLinksTo(pageId2);
TIntSet both = new TIntHashSet(ids1);
both.retainAll(ids2);
if (both.isEmpty()) {
return new SRResult(0.0);
}
double pmi = 1.0 * numArticles * both.size() / (ids1.size() * ids2.size());
double boost = Math.log10(both.size());
double graphDistance = graphDistance(pageId1, pageId2);
// String t1 = getLocalPageDao().getById(getLanguage(), pageId1).getTitle().toString();
// String t2 = getLocalPageDao().getById(getLanguage(), pageId2).getTitle().toString();
// System.err.println(String.format("Values for %s, %s, are %.4f, %.4f, %.4f\n", t1, t2, pmi, boost, graphDistance));
return new SRResult(pmi * boost / graphDistance);
}
private TIntSet getLinksTo(int pageId) throws DaoException {
TIntSet ids = new TIntHashSet();
for (LocalLink ll : linkDao.get(new DaoFilter().setDestIds(pageId).setLanguages(getLanguage()))) {
ids.add(ll.getSourceId());
}
return ids;
}
private TIntSet getLinksFrom(int pageId) throws DaoException {
TIntSet ids = new TIntHashSet();
for (LocalLink ll : linkDao.get(new DaoFilter().setSourceIds(pageId).setLanguages(getLanguage()))) {
ids.add(ll.getDestId());
}
return ids;
}
private int graphDistance(int pageId1, int pageId2) throws DaoException {
if (getLinksTo(pageId1).contains(pageId2) || getLinksFrom(pageId1).contains(pageId2)) {
return 1;
} else {
return 2;
}
}
@Override
public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
throw new UnsupportedOperationException();
}
public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
super(configurator, config);
}
@Override
public Class getType() {
return SRMetric.class;
}
@Override
public String getPath() {
return "sr.metric.local";
}
@Override
public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
if (!config.getString("type").equals("synrank")) {
return null;
}
if (runtimeParams == null || !runtimeParams.containsKey("language")) {
throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
}
Language language = Language.getByLangCode(runtimeParams.get("language"));
Disambiguator dab = getConfigurator().get(Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode());
try {
SynRank sr = new SynRank(
name,
language,
getConfigurator().get(LocalPageDao.class),
dab,
getConfigurator().get(LocalLinkDao.class),
getConfigurator().get(LocalCategoryMemberDao.class)
);
configureBase(getConfigurator(), sr, config);
return sr;
} catch (DaoException e) {
throw new ConfigurationException(e);
}
}
}
}