package org.wikibrain.sr.milnewitten;
import com.typesafe.config.Config;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.*;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.utils.Leaderboard;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * An SR metric that combines two underlying metrics, held in the fields
 * {@code inlink} and {@code outlink}, by averaging their scores with equal
 * (0.5 / 0.5) weights.
 *
 * <p>Pairwise similarity, cosimilarity matrices and most-similar lists are
 * all computed by querying both sub-metrics and merging the results. Any
 * NaN or infinite sub-score makes the combined pairwise score NaN; in
 * most-similar lists such scores are simply skipped.
 *
 * @author Shilad Sen
 */
public class MilneWittenMetric extends BaseSRMetric {
    private static final Logger LOG = LoggerFactory.getLogger(MilneWittenMetric.class);

    private final SRMetric inlink;
    private final SRMetric outlink;

    // Written by the unsynchronized setter and read inside the synchronized
    // train methods; volatile makes the latest value visible to the reader.
    private volatile boolean trainSubmetrics = true;

    /**
     * Creates a metric that averages the two given sub-metrics.
     *
     * @param name     name forwarded to the superclass constructor
     * @param language language forwarded to the superclass constructor
     * @param dao      page dao forwarded to the superclass constructor
     * @param inlink   first sub-metric (contributes half of every score)
     * @param outlink  second sub-metric (contributes half of every score)
     * @param dab      disambiguator forwarded to the superclass constructor
     */
    public MilneWittenMetric(String name, Language language, LocalPageDao dao, SRMetric inlink, SRMetric outlink, Disambiguator dab) {
        super(name, language, dao, dab);
        this.inlink = inlink;
        this.outlink = outlink;
    }

    /**
     * @return a fresh config with {@code minScore = 0} and {@code maxScore = 1.1f}
     */
    @Override
    public SRConfig getConfig() {
        SRConfig config = new SRConfig();
        config.maxScore = 1.1f;
        config.minScore = 0;
        return config;
    }

    /**
     * Averages the similarity reported by both sub-metrics for a pair of pages.
     *
     * @param pageId1      first page id
     * @param pageId2      second page id
     * @param explanations whether to merge the sub-metrics' explanations into the result
     * @return a NaN-scored result if either sub-result is null or invalid;
     *         otherwise the result of passing the averaged result through
     *         {@code normalize} (not defined in this file; presumably inherited)
     * @throws DaoException propagated from the sub-metrics
     */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        SRResult r1 = inlink.similarity(pageId1, pageId2, explanations);
        SRResult r2 = outlink.similarity(pageId1, pageId2, explanations);
        if (r1 == null || r2 == null || !r1.isValid() || !r2.isValid()) {
            return new SRResult(Double.NaN);
        }

        SRResult finalResult = new SRResult(0.5 * r1.getScore() + 0.5 * r2.getScore());
        if (explanations) {
            List<Explanation> explanationList = new ArrayList<Explanation>();
            // A sub-metric may not attach an explanation list at all; skip it
            // rather than failing with a NullPointerException in addAll.
            if (r1.getExplanations() != null) {
                explanationList.addAll(r1.getExplanations());
            }
            if (r2.getExplanations() != null) {
                explanationList.addAll(r2.getExplanations());
            }
            finalResult.setExplanations(explanationList);
        }
        return normalize(finalResult);
    }

    /**
     * Computes the averaged cosimilarity matrix of both sub-metrics.
     *
     * <p>The matrix returned by the {@code inlink} sub-metric is overwritten
     * in place and returned. A cell is NaN whenever either sub-score is NaN
     * or infinite; otherwise it is the normalized average of the two.
     *
     * @param rowIds    page ids for the rows
     * @param columnIds page ids for the columns
     * @return a {@code rowIds.length x columnIds.length} matrix of combined scores
     * @throws DaoException propagated from the sub-metrics
     */
    @Override
    public double[][] cosimilarity(int rowIds[], int columnIds[]) throws DaoException {
        double[][] cm1 = inlink.cosimilarity(rowIds, columnIds);
        double[][] cm2 = outlink.cosimilarity(rowIds, columnIds);
        for (int i = 0; i < rowIds.length; i++) {
            for (int j = 0; j < columnIds.length; j++) {
                double s1 = cm1[i][j];
                double s2 = cm2[i][j];
                if (isFiniteScore(s1) && isFiniteScore(s2)) {
                    cm1[i][j] = normalize(s1 * 0.5 + s2 * 0.5);
                } else {
                    cm1[i][j] = Double.NaN;
                }
            }
        }
        return cm1;
    }

    /**
     * Controls whether the train methods also train the two sub-metrics
     * before training this metric. Defaults to {@code true}.
     *
     * @param train {@code true} to train sub-metrics as well
     */
    public void setTrainSubmetrics(boolean train) {
        trainSubmetrics = train;
    }

    /**
     * Trains both sub-metrics (if enabled) and then delegates to the superclass.
     *
     * @param dataset training data passed through unchanged
     * @throws DaoException propagated from the sub-metrics or the superclass
     */
    @Override
    public synchronized void trainSimilarity(Dataset dataset) throws DaoException {
        if (trainSubmetrics) {
            inlink.trainSimilarity(dataset);
            outlink.trainSimilarity(dataset);
        }
        super.trainSimilarity(dataset);
    }

    /**
     * Trains both sub-metrics (if enabled) and then delegates to the superclass.
     *
     * @param dataset    training data passed through unchanged
     * @param numResults passed through unchanged
     * @param validIds   passed through unchanged
     */
    @Override
    public synchronized void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        if (trainSubmetrics) {
            inlink.trainMostSimilar(dataset, numResults, validIds);
            outlink.trainMostSimilar(dataset, numResults, validIds);
        }
        super.trainMostSimilar(dataset, numResults, validIds);
    }

    /**
     * Writes both sub-metrics and then this metric's own state via the superclass.
     *
     * @throws IOException propagated from the sub-metrics or the superclass
     */
    @Override
    public void write() throws IOException {
        inlink.write();
        outlink.write();
        super.write();
    }

    /**
     * Reads both sub-metrics and then this metric's own state via the superclass.
     *
     * @throws IOException propagated from the sub-metrics or the superclass
     */
    @Override
    public void read() throws IOException {
        inlink.read();
        outlink.read();
        super.read();
    }

    /**
     * Merges the most-similar lists of both sub-metrics into one ranked list.
     *
     * <p>Each sub-metric is asked for {@code maxResults * 2} candidates. Every
     * finite score contributes half its value to the candidate's total. A
     * candidate that appears in only one list receives half of the other
     * list's missing score in place of the absent score. The top
     * {@code maxResults} totals are normalized and returned.
     *
     * @param pageId     page to find neighbours for
     * @param maxResults maximum number of results to return
     * @param validIds   passed through to the sub-metrics unchanged
     * @return the normalized top results
     * @throws DaoException propagated from the sub-metrics
     */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        TIntDoubleMap scores = new TIntDoubleHashMap(maxResults * 4);

        SRResultList l1 = inlink.mostSimilar(pageId, maxResults * 2, validIds);
        TIntSet inList1 = accumulateHalfScores(l1, scores);

        SRResultList l2 = outlink.mostSimilar(pageId, maxResults * 2, validIds);
        TIntSet inList2 = accumulateHalfScores(l2, scores);

        // A NaN/infinite missing score would poison every total it is added to,
        // so it is treated like an absent list (0.0), mirroring the filtering
        // already applied to individual scores.
        double missingScore1 = finiteOrZero((l1 == null) ? 0.0 : l1.getMissingScore());
        double missingScore2 = finiteOrZero((l2 == null) ? 0.0 : l2.getMissingScore());

        backfillMissing(inList1, inList2, missingScore2, scores);
        backfillMissing(inList2, inList1, missingScore1, scores);

        Leaderboard leaderboard = new Leaderboard(maxResults);
        for (int id : scores.keys()) {
            leaderboard.tallyScore(id, scores.get(id));
        }
        return normalize(leaderboard.getTop());
    }

    /** Returns true when the value is neither NaN nor infinite. */
    private static boolean isFiniteScore(double value) {
        return !Double.isNaN(value) && !Double.isInfinite(value);
    }

    /** Returns the value unchanged when finite, otherwise 0.0. */
    private static double finiteOrZero(double value) {
        return isFiniteScore(value) ? value : 0.0;
    }

    /**
     * Adds half of every finite score in {@code results} to {@code scores}
     * and returns the ids that contributed; a null list contributes nothing.
     */
    private static TIntSet accumulateHalfScores(SRResultList results, TIntDoubleMap scores) {
        TIntSet contributed = new TIntHashSet();
        if (results == null) {
            return contributed;
        }
        for (int i = 0; i < results.numDocs(); i++) {
            double s = results.getScore(i);
            if (isFiniteScore(s)) {
                int id = results.getId(i);
                scores.adjustOrPutValue(id, 0.5 * s, 0.5 * s);
                contributed.add(id);
            }
        }
        return contributed;
    }

    /**
     * For every id in {@code present} that is absent from {@code other}, adds
     * half of {@code otherMissingScore} to its accumulated total.
     */
    private static void backfillMissing(TIntSet present, TIntSet other, double otherMissingScore, TIntDoubleMap scores) {
        for (int id : present.toArray()) {
            if (!other.contains(id)) {
                scores.adjustValue(id, 0.5 * otherMissingScore);
            }
        }
    }

    /**
     * Configuration provider that builds a {@link MilneWittenMetric} from a
     * config entry whose {@code type} is {@code "milnewitten"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds the metric from the config keys {@code inlink},
         * {@code outlink} and {@code disambiguator}, each resolved through
         * the configurator for the requested language.
         *
         * @param name          name given to the new metric
         * @param config        config entry for this metric
         * @param runtimeParams must contain a {@code "language"} entry
         * @return the configured metric, or {@code null} when the config's
         *         {@code type} is not {@code "milnewitten"}
         * @throws IllegalArgumentException if the {@code "language"} runtime parameter is absent
         * @throws ConfigurationException   propagated from the configurator
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("milnewitten")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            SRMetric inlink = getConfigurator().get(
                    SRMetric.class, config.getString("inlink"),
                    "language", language.getLangCode());
            SRMetric outlink = getConfigurator().get(
                    SRMetric.class, config.getString("outlink"),
                    "language", language.getLangCode());
            Disambiguator dab = getConfigurator().get(
                    Disambiguator.class, config.getString("disambiguator"),
                    "language", language.getLangCode());
            MilneWittenMetric mw = new MilneWittenMetric(
                    name,
                    language,
                    getConfigurator().get(LocalPageDao.class),
                    inlink,
                    outlink,
                    dab
            );
            // configureBase is not defined in this file; presumably inherited
            // and applies the shared base-metric settings from config.
            configureBase(getConfigurator(), mw, config);
            return mw;
        }
    }
}