package org.wikibrain.sr.milnewitten; import com.google.common.collect.Maps; import com.typesafe.config.Config; import com.typesafe.config.ConfigValueFactory; import org.apache.commons.collections.CollectionUtils; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.DaoFilter; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LocalId; import org.wikibrain.core.lang.LocalString; import org.wikibrain.core.model.NameSpace; import org.wikibrain.phrases.PhraseAnalyzer; import org.wikibrain.sr.SRMetric; import org.wikibrain.sr.disambig.Disambiguator; import org.wikibrain.utils.WbMathUtils; import org.wikibrain.utils.WpCollectionUtils; import java.util.*; /** * @author Shilad Sen * * A literal implementation of the disambiguator described in * http://www.aaai.org/Papers/Workshops/2008/WS-08-15/WS08-15-005.pdf * * This is naively extended to support multiple strings and contexts. */ public class MilneWittenDisambiguator extends Disambiguator { private final Language language; private final LocalPageDao pageDao; private final PhraseAnalyzer analyzer; private final SRMetric metric; private final int numPages; /** * Construct a new disambiguator that uses a particular metric. * * @param pageDao * @param analyzer * @param metric * @throws DaoException */ public MilneWittenDisambiguator(LocalPageDao pageDao, PhraseAnalyzer analyzer, SRMetric metric) throws DaoException { this.language = metric.getLanguage(); this.pageDao = pageDao; this.analyzer = analyzer; this.metric = metric; this.numPages = pageDao.getCount( new DaoFilter().setLanguages(language) .setNameSpaces(NameSpace.ARTICLE) .setRedirect(false) .setDisambig(false)); } /** * Given * @param phrases The target phrases being disambiguated. * @param context Other phrases (in the same language as the target phrase) * related to the target phrase being disambiguated that may * aid disambiguation. * @return * @throws DaoException */ @Override public List<LinkedHashMap<LocalId, Float>> disambiguate(List<LocalString> phrases, Set<LocalString> context) throws DaoException { // Collect all phrases List<LocalString> allPhrases = new ArrayList<LocalString>( (context == null) ? phrases : CollectionUtils.union(phrases, context)); // Step 1: calculate most frequent candidate senses for each phrase Map<LocalString, LinkedHashMap<LocalId, Float>> candidates = Maps.newHashMap(); for (LocalString s : allPhrases) { if (!s.getLanguage().equals(language)) { throw new IllegalArgumentException("Disambiguator only supports language " + language); } candidates.put(s, analyzer.resolve(s.getLanguage(), s.getString(), 100)); } // Step 1.5: Build mapping from local id to phrases that are candidates Map<LocalId, Set<LocalString>> idsToPhrases = new HashMap<LocalId, Set<LocalString>>(); for (LocalString s : candidates.keySet()) { for (LocalId lid : candidates.get(s).keySet()) { if (!idsToPhrases.containsKey(lid)) { idsToPhrases.put(lid, new HashSet<LocalString>()); } idsToPhrases.get(lid).add(s); } } // Step 2: calculate the sum of cosimilarities for each page Map<LocalId, Float> pageSims = getCosimilaritySums(candidates); // Step 3: Choose the best options for each phrase List<LinkedHashMap<LocalId, Float>> result = new ArrayList<LinkedHashMap<LocalId, Float>>(); for (LocalString p : phrases) { result.add( disambiguateOnePhrase(p, candidates.get(p), idsToPhrases, pageSims)); } return result; } private LinkedHashMap<LocalId, Float> disambiguateOnePhrase(LocalString phrase, LinkedHashMap<LocalId, Float> candidates, Map<LocalId, Set<LocalString>> idsToPhrases, Map<LocalId, Float> pageSims) throws DaoException { // Identify the highest similarity for each page float maxSimilarity = Float.NEGATIVE_INFINITY; for (LocalId lid : candidates.keySet()) { maxSimilarity = Math.max(maxSimilarity, pageSims.get(lid)); } // Identify the most popular senses within 40% of the top similarity Map<LocalId, Float> scores = new HashMap<LocalId, Float>(); double scoreSum = 0.0; for (LocalId lid : candidates.keySet()) { double sim = pageSims.get(lid); if (sim < 0.4 * maxSimilarity) { continue; } double pop = candidates.get(lid); double phraseBonus = 0.0; /* TODO: figure out phraseBonus for multiple phrases int numPhrases = 0; int sumCounts = 0; for (LocalString ls2 : idsToPhrases.get(lid)) { numPhrases++; sumCounts += getPhraseCount(phrase.getString() + " " + ls2.getString()); sumCounts += getPhraseCount(ls2.getString() + " " + phrase.getString()); } if (sumCounts > 0) { int maxExpectedValue = numPhrases * numPages / 50; phraseBonus = Math.log(sumCounts + 1) / Math.log(maxExpectedValue); phraseBonus = Math.min(0.5, phraseBonus); } System.err.println("phrase bonus for " + phrase + ", " + idsToPhrases.get(lid) + " is " + phraseBonus); */ double score = pop + phraseBonus; scores.put(lid, (float) score); scoreSum += score; } LinkedHashMap<LocalId, Float> pageResult = new LinkedHashMap<LocalId, Float>(); for (LocalId key : WpCollectionUtils.sortMapKeys(scores, true)) { pageResult.put(key, (float) (scores.get(key) / scoreSum)); } return pageResult; } /* private int getPhraseCount(String phrase) throws DaoException { PrunedCounts<Integer> pages = analyzer.getDao().getPhraseCounts(language, phrase, 1); if (pages == null) { return 0; } else { return pages.getTotal(); } } */ /** * Return the sum of cosimilarity scores for all unique pages among the candidates. * @param candidates * @return * @throws DaoException */ private Map<LocalId, Float> getCosimilaritySums(Map<LocalString, LinkedHashMap<LocalId, Float>> candidates) throws DaoException { // Step 1: compute the page cosimilarity matrix Set<LocalId> uniques = new HashSet<LocalId>(); for (LinkedHashMap<LocalId, Float> prob : candidates.values()) { uniques.addAll(prob.keySet()); } List<LocalId> pages = new ArrayList<LocalId>(uniques); double[][] cosim; if (pages.isEmpty()){ cosim = new double[0][0]; } else { int[] pageIds = new int[pages.size()]; for (int i=0; i<pages.size(); i++){ pageIds[i] = pages.get(i).getId(); } cosim = metric.cosimilarity(pageIds); } // Step 2: calculate the sum of cosimilarities for each page Map<LocalId, Float> pageSims = new HashMap<LocalId, Float>(); for (int i = 0; i < pages.size(); i++) { double sum = 0.0; for (int j = 0; j < pages.size(); j++) { if (i != j && WbMathUtils.isReal(cosim[i][j])) { sum += Math.max(0, cosim[i][j]); // Hack: no negative numbers } } // add 0.0001 to give every candidate a tiny chance and avoid divide by zero errors when there are no good options pageSims.put(pages.get(i), (float)(sum + 0.0001)); } return pageSims; } public static class Provider extends org.wikibrain.conf.Provider<Disambiguator>{ public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator,config); } @Override public Class getType(){ return Disambiguator.class; } @Override public String getPath(){ return "sr.disambig"; } @Override public Disambiguator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException{ if (!config.getString("type").equals("milnewitten")){ return null; } if (runtimeParams == null || !runtimeParams.containsKey("language")){ throw new IllegalArgumentException("SimpleMilneWitten requires 'language' runtime parameter."); } Language lang = Language.getByLangCode(runtimeParams.get("language")); PhraseAnalyzer pa = getConfigurator().get(PhraseAnalyzer.class, config.getString("phraseAnalyzer")); LocalPageDao pageDao = getConfigurator().get(LocalPageDao.class); // Create override config for sr metric and load it. String srName = config.getString("metric"); Config newConfig = getConfig().get().getConfig("sr.metric.local." + srName) .withValue("disambiguator", ConfigValueFactory.fromAnyRef("topResult")); Map<String, String> srRuntimeParams = new HashMap<String, String>(); srRuntimeParams.put("language", lang.getLangCode()); SRMetric sr = getConfigurator().construct(SRMetric.class, srName, newConfig, srRuntimeParams); try { return new MilneWittenDisambiguator(pageDao, pa, sr); } catch (DaoException e) { throw new ConfigurationException(e); } } } }