package org.wikibrain.sr; import com.typesafe.config.Config; import gnu.trove.list.TIntList; import gnu.trove.list.array.TIntArrayList; import gnu.trove.map.TIntDoubleMap; import gnu.trove.map.TIntIntMap; import gnu.trove.map.hash.TIntDoubleHashMap; import gnu.trove.map.hash.TIntIntHashMap; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.DaoFilter; import org.wikibrain.core.dao.LocalLinkDao; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.core.model.LocalLink; import org.wikibrain.sr.dataset.Dataset; import org.wikibrain.sr.disambig.Disambiguator; import org.wikibrain.sr.utils.Leaderboard; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Generates a sparse vector containing both inbound and outbound links * for a page. * * @author Shilad Sen */ public class DirectLinkMetric extends BaseSRMetric { private static final Logger LOG = LoggerFactory.getLogger(DirectLinkMetric.class); private final LocalLinkDao linkDao; public DirectLinkMetric(String name, Language language, LocalPageDao dao, LocalLinkDao linkDao, Disambiguator dab) { super(name, language, dao,dab); this.linkDao = linkDao; } @Override public SRConfig getConfig() { SRConfig config = new SRConfig(); config.maxScore = 1.1f; config.minScore = 0; return config; } @Override public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException { int count = 0; int links1[] = getLinks(pageId1, true); int links2[] = getLinks(pageId2, true); count += hasLink(links1, pageId2); count += hasLink(links2, pageId1); return new SRResult(normalize(1.0 * count / 2.0)); } @Override public double[][] cosimilarity(int rowIds[], int columnIds[]) throws DaoException { int [][] rowLinks = new int[rowIds.length][]; int [][] colLinks = new int[columnIds.length][]; for (int i = 0; i < rowIds.length; i++) rowLinks[i] = getLinks(rowIds[i], true); for (int i = 0; i < columnIds.length; i++) colLinks[i] = getLinks(columnIds[i], true); double result[][] = new double[rowIds.length][columnIds.length]; for (int i = 0; i < rowIds.length; i++) { for (int j = 0; j < columnIds.length; j++) { int has1 = hasLink(rowLinks[i], columnIds[j]); int has2 = hasLink(colLinks[j], rowIds[i]); result[i][j] = normalize(0.5 * has1 + 0.5 * has2); } } return result; } @Override public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException { TIntIntMap scores = new TIntIntHashMap(); for (int id : getLinks(pageId, true)) { if (validIds == null || validIds.contains(id)) scores.adjustOrPutValue(id, 1, 1); } for (int id : getLinks(pageId, false)) { if (validIds == null || validIds.contains(id)) scores.adjustOrPutValue(id, 1, 1); } Leaderboard leaderboard = new Leaderboard(maxResults); for (int id : scores.keys()) { leaderboard.tallyScore(id, scores.get(id) / 2.0); } return normalize(leaderboard.getTop()); } private int hasLink(int [] links, int targetId) { return (Arrays.binarySearch(links, targetId) >= 0) ? 1 : 0; } private int[] getLinks(int pageId1, boolean outLinks) throws DaoException { TIntList result = new TIntArrayList(); for (LocalLink ll : linkDao.getLinks(getLanguage(), pageId1, outLinks)) { result.add(ll.getLocalId()); } result.sort(); return result.toArray(); } public static class Provider extends org.wikibrain.conf.Provider<SRMetric> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class getType() { return SRMetric.class; } @Override public String getPath() { return "sr.metric.local"; } @Override public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (!config.getString("type").equals("directlink")) { return null; } if (runtimeParams == null || !runtimeParams.containsKey("language")) { throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter"); } Language language = Language.getByLangCode(runtimeParams.get("language")); LocalLinkDao linkDao = getConfigurator().get(LocalLinkDao.class); Disambiguator dab = getConfigurator().get(Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode()); DirectLinkMetric mw = new DirectLinkMetric( name, language, getConfigurator().get(LocalPageDao.class), linkDao, dab ); configureBase(getConfigurator(), mw, config); return mw; } } }