package org.wikibrain.sr.vector;
import com.typesafe.config.Config;
import gnu.trove.function.TFloatFunction;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.*;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalLink;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Generates a sparse vector for a page with one entry per linked page:
 * link destinations when {@code outLinks} is true, link sources otherwise.
 *
 * Each entry starts at 1. When popularity weighting is enabled the entry is
 * instead {@code numPages / linkCount(column)}, optionally passed through
 * {@link Math#log(double)}. All entries are finally divided by the sum of the
 * squared raw values.
 *
 * @author Shilad Sen
 */
public class LinkGenerator implements SparseVectorGenerator {

    /** Link direction labels. Not referenced anywhere inside this class. */
    public static enum LinkType {
        IN,
        OUT,
        DIRECT
    }

    private static final Logger LOG = LoggerFactory.getLogger(LinkGenerator.class);

    /** Maximum number of shared links reported by {@code getExplanations}. */
    private static final int MAX_EXPLANATIONS = 5; // TODO: make 5 configurable

    private boolean outLinks;
    private final LocalLinkDao linkDao;
    private final LocalPageDao pageDao;
    private final Language language;

    /** Cache of per-page link counts; every access is guarded by synchronizing on the map itself. */
    private final TIntIntMap linkCounts = new TIntIntHashMap();

    private boolean weightByPopularity = false;
    private boolean logTransform = false;

    /** Count of non-redirect, non-disambiguation article pages in {@link #language}. */
    private final int numPages;

    /** Page ids that are never emitted as vector columns. */
    private TIntSet blackListSet;
    private final String blackListFilePath;

    /**
     * Creates a generator, counts the article pages of the language and loads the blacklist.
     *
     * @param language language whose pages and links are used
     * @param linkDao source of links and link counts
     * @param pageDao source of page counts and page lookups
     * @param outLinks true to build vectors from outbound links, false for inbound links
     * @param blackListFilePath file of integer page ids to exclude; null or empty disables the blacklist
     * @throws DaoException if the page count query fails
     * @throws FileNotFoundException if the blacklist file cannot be opened
     */
    public LinkGenerator(Language language, LocalLinkDao linkDao, LocalPageDao pageDao, boolean outLinks, String blackListFilePath) throws DaoException, FileNotFoundException {
        this.language = language;
        this.linkDao = linkDao;
        this.outLinks = outLinks;
        this.pageDao = pageDao;
        this.blackListFilePath = blackListFilePath;
        numPages = pageDao.getCount(
                new DaoFilter().setLanguages(language)
                        .setRedirect(false)
                        .setDisambig(false)
                        .setNameSpaces(NameSpace.ARTICLE)
        );
        createBlackListSet();
    }

    /**
     * Fills {@link #blackListSet} with the whitespace-separated integers found in the
     * blacklist file, or leaves it empty when no file is configured.
     *
     * @throws FileNotFoundException if the blacklist file cannot be opened
     */
    private void createBlackListSet() throws FileNotFoundException {
        blackListSet = new TIntHashSet();
        if (blackListFilePath == null || blackListFilePath.equals("")) {
            LOG.info("Skipping blacklist creation; no blacklist file specified.");
            return;
        }
        Scanner scanner = new Scanner(new File(blackListFilePath));
        try {
            // A non-integer token still makes nextInt() throw, so a malformed file is
            // reported rather than silently truncated.
            while (scanner.hasNext()) {
                blackListSet.add(scanner.nextInt());
            }
        } finally {
            // Release the file handle even when parsing fails part-way through.
            scanner.close();
        }
    }

    /**
     * Builds the link vector for a page.
     *
     * @param pageId id of the page; must be positive
     * @return map from linked page id to its scaled weight; empty when the page has no usable links
     * @throws IllegalArgumentException if {@code pageId <= 0}
     * @throws DaoException if a link or link-count query fails
     */
    @Override
    public TIntFloatMap getVector(int pageId) throws DaoException {
        if (pageId <= 0) {
            throw new IllegalArgumentException("Invalid page id: " + pageId);
        }
        TIntFloatMap vector = new TIntFloatHashMap(100);
        double norm2 = 0.0;
        for (LocalLink link : linkDao.getLinks(language, pageId, outLinks)) {
            int columnId = outLinks ? link.getDestId() : link.getSourceId();
            if (columnId < 0 || isBlacklisted(columnId)) {
                continue;
            }
            // A column reached through several links gets the same weight each time;
            // skip repeats so its square is added to the norm only once.
            if (vector.containsKey(columnId)) {
                continue;
            }
            double value = 1;
            if (weightByPopularity) {
                // Floating-point division keeps the fractional part of the ratio, and the
                // clamp avoids dividing by zero when the DAO reports no links for the column.
                value = (double) numPages / Math.max(1, getNumLinks(columnId));
                if (logTransform) {
                    value = Math.log(value);
                }
            }
            vector.put(columnId, (float) value);
            norm2 += value * value;
        }
        // NOTE(review): entries are divided by the sum of squares, not its square root,
        // so the result is not unit length - confirm this is intended before changing it.
        if (norm2 > 0.0) {
            final double n = norm2;
            vector.transformValues(new TFloatFunction() {
                @Override
                public float execute(float value) {
                    return (float) (value / n);
                }
            });
        }
        // When norm2 is 0 every stored value is 0; dividing would turn them into NaN.
        return vector;
    }

    /** Returns true if the page id appears in the blacklist. */
    private boolean isBlacklisted(int pageId) {
        return blackListSet.contains(pageId);
    }

    /**
     * If outLinks is true, returns the number of links to the specified destination.
     * Otherwise, returns number of links FROM the specified source.
     * Results are cached in {@link #linkCounts}.
     *
     * @param wpId page id whose links are counted
     * @return the count reported by the link DAO, or the cached value from an earlier call
     * @throws DaoException if the count query fails
     */
    private int getNumLinks(int wpId) throws DaoException {
        synchronized (linkCounts) {
            if (linkCounts.containsKey(wpId)) {
                return linkCounts.get(wpId);
            }
        }
        // The query runs outside the lock; two threads may both compute the count for
        // the same id and both store it.
        int n;
        if (outLinks) {
            n = linkDao.getCount(new DaoFilter().setLanguages(language).setDestIds(wpId));
        } else {
            n = linkDao.getCount(new DaoFilter().setLanguages(language).setSourceIds(wpId));
        }
        synchronized (linkCounts) {
            linkCounts.put(wpId, n);
        }
        return n;
    }

    /**
     * Phrase vectors are not supported by this generator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public TIntFloatMap getVector(String phrase) {
        throw new UnsupportedOperationException();
    }

    /**
     * Phrase explanations are not supported by this generator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Explanation> getExplanations(String phrase1, String phrase2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        throw new UnsupportedOperationException();
    }

    /**
     * Explains a similarity result by listing the columns present in both vectors with
     * the largest products of their two values.
     *
     * @param pageID1 id of the first page
     * @param pageID2 id of the second page
     * @param vector1 vector of the first page
     * @param vector2 vector of the second page
     * @param result the similarity result being explained (not read here)
     * @return one explanation per top shared column whose page can be resolved, or a
     *         single "share no links" explanation when the vectors have no common column
     * @throws DaoException if a page lookup fails
     */
    @Override
    public List<Explanation> getExplanations(int pageID1, int pageID2, TIntFloatMap vector1, TIntFloatMap vector2, SRResult result) throws DaoException {
        LocalPage page1 = pageDao.getById(language, pageID1);
        LocalPage page2 = pageDao.getById(language, pageID2);
        Leaderboard lb = new Leaderboard(MAX_EXPLANATIONS);
        for (int id : vector1.keys()) {
            if (vector2.containsKey(id)) {
                lb.tallyScore(id, vector1.get(id) * vector2.get(id));
            }
        }
        SRResultList top = lb.getTop();
        if (top.numDocs() == 0) {
            return Arrays.asList(new Explanation("? and ? share no links", page1, page2));
        }
        List<Explanation> explanations = new ArrayList<Explanation>();
        for (int i = 0; i < top.numDocs(); i++) {
            LocalPage p = pageDao.getById(language, top.getId(i));
            if (p == null) {
                // The shared column has no resolvable page; leave it out.
                continue;
            }
            if (outLinks) {
                explanations.add(new Explanation("Both ? and ? link to ?", page1, page2, p));
            } else {
                explanations.add(new Explanation("? links to both ? and ?", p, page1, page2));
            }
        }
        return explanations;
    }

    /** Enables or disables weighting each entry by {@code numPages / linkCount}. */
    public void setWeightByPopularity(boolean weightByPopularity) {
        this.weightByPopularity = weightByPopularity;
    }

    /** Enables or disables the log transform; it is applied only when popularity weighting is on. */
    public void setLogTransform(boolean logTransform) {
        this.logTransform = logTransform;
    }

    /** Configuration provider that builds {@link LinkGenerator} instances for configs of type "links". */
    public static class Provider extends org.wikibrain.conf.Provider<SparseVectorGenerator> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SparseVectorGenerator.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.sparsegenerator";
        }

        /**
         * Builds a generator from configuration.
         *
         * @param name configured component name
         * @param config component configuration; must contain "type" and "outLinks", and may
         *               contain "weightByPopularity" and "logTransform"
         * @param runtimeParams must contain a "language" entry holding a language code
         * @return the generator, or null when the configured type is not "links"
         * @throws IllegalArgumentException if the "language" runtime parameter is missing
         * @throws ConfigurationException if constructing the generator fails for any reason
         */
        @Override
        public SparseVectorGenerator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("links")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            try {
                LinkGenerator lg = new LinkGenerator(
                        language,
                        getConfigurator().get(LocalLinkDao.class),
                        getConfigurator().get(LocalPageDao.class),
                        config.getBoolean("outLinks"),
                        getConfig().get().getString("sr.blacklist.path")
                );
                if (config.hasPath("weightByPopularity")) {
                    lg.setWeightByPopularity(config.getBoolean("weightByPopularity"));
                }
                if (config.hasPath("logTransform")) {
                    lg.setLogTransform(config.getBoolean("logTransform"));
                }
                return lg;
            } catch (Exception e) {
                throw new ConfigurationException(e);
            }
        }
    }
}