package org.wikibrain.sr.disambig; import com.google.common.collect.Maps; import com.typesafe.config.Config; import com.typesafe.config.ConfigValueFactory; import org.apache.commons.collections.CollectionUtils; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LocalId; import org.wikibrain.core.lang.LocalString; import org.wikibrain.phrases.PhraseAnalyzer; import org.wikibrain.sr.SRMetric; import org.wikibrain.utils.WbMathUtils; import org.wikibrain.utils.WpCollectionUtils; import java.util.*; public class SimilarityDisambiguator extends Disambiguator { public static final int DEFAULT_NUM_CANDIDATES = 5; protected final PhraseAnalyzer phraseAnalyzer; private int numCandidates = DEFAULT_NUM_CANDIDATES; private Language language; /** * Algorithms for disambiguating similar phrases */ public static enum Criteria { SUM, // select senses with highest sum of popularity + similarity PRODUCT, // select senses with highest sum of popularity * similarity POPULARITY, // select most popular senses SIMILARITY // select most similar senses } // Method for disambiguating similar phrases private Criteria criteria = Criteria.SUM; private final SRMetric metric; public SimilarityDisambiguator(PhraseAnalyzer phraseAnalyzer, SRMetric metric) { this.phraseAnalyzer = phraseAnalyzer; this.metric = metric; this.language = metric.getLanguage(); } @Override public List<LinkedHashMap<LocalId, Float>> disambiguate(List<LocalString> phrases, Set<LocalString> context) throws DaoException { List<LinkedHashMap<LocalId, Float>> results = new ArrayList<LinkedHashMap<LocalId, Float>>(); List<LocalString> allPhrases = new ArrayList<LocalString>( (context == null) ? phrases : CollectionUtils.union(phrases, context)); // Step 0: calculate most frequent candidate senses for each phrase Map<LocalString, LinkedHashMap<LocalId, Float>> candidates = Maps.newHashMap(); for (LocalString s : allPhrases) { candidates.put(s, phraseAnalyzer.resolve(s.getLanguage(), s.getString(), numCandidates)); } // Skip using the sr metric at all! if (criteria == Criteria.POPULARITY) { for (LocalString phrase : phrases) { LinkedHashMap<LocalId, Float> m = new LinkedHashMap<LocalId, Float>(); for (LocalId li : candidates.get(phrase).keySet()) { m.put(li, candidates.get(phrase).get(li)); } results.add(m); } return results; } // Step 2: calculate the sum of cosimilarities for each page Map<LocalId, Float> pageSims = getCosimilaritySums(candidates); // Step 3: multiply background probability by sim sums, choose best product List<LinkedHashMap<LocalId, Float>> result = new ArrayList<LinkedHashMap<LocalId, Float>>(); for (LocalString ls : phrases) { Map<LocalId, Float> phraseCands = candidates.get(ls); LinkedHashMap<LocalId, Float> pageResult = selectFinalPhraseSenses(pageSims, phraseCands); result.add(pageResult); } return result; } private LinkedHashMap<LocalId, Float> selectFinalPhraseSenses(Map<LocalId, Float> pageSims, Map<LocalId, Float> phrasePops) { if (phrasePops == null || phrasePops.isEmpty()) { return null; } double sum = 0.0; for (LocalId lp : phrasePops.keySet()) { float pop = phrasePops.get(lp); float sim = pageSims.get(lp); float score; switch (criteria) { case POPULARITY: score = pop; break; case SIMILARITY: score = sim; break; case SUM: score = pop + sim; break; case PRODUCT: score = pop * sim; break; default: throw new IllegalStateException(); } phrasePops.put(lp, score); sum += score; } LinkedHashMap<LocalId, Float> pageResult = new LinkedHashMap<LocalId, Float>(); for (LocalId key : WpCollectionUtils.sortMapKeys(phrasePops, true)) { pageResult.put(key, (float)(phrasePops.get(key) / sum)); } return pageResult; } /** * Return the sum of cosimilarity scores for all unique pages among the candidates. * @param candidates * @return * @throws DaoException */ private Map<LocalId, Float> getCosimilaritySums(Map<LocalString, LinkedHashMap<LocalId, Float>> candidates) throws DaoException { // Step 1: compute the page cosimilarity matrix Set<LocalId> uniques = new HashSet<LocalId>(); for (LinkedHashMap<LocalId, Float> prob : candidates.values()) { uniques.addAll(prob.keySet()); } List<LocalId> pages = new ArrayList<LocalId>(uniques); double[][] cosim; if (pages.isEmpty()){ cosim = new double[0][0]; } else { if (!pages.get(0).getLanguage().equals(language)) { throw new DaoException("Expected language " + language + ", found " + pages.get(0).getLanguage()); } int[] pageIds = new int[pages.size()]; for (int i=0; i<pages.size(); i++){ pageIds[i] = pages.get(i).getId(); } cosim = metric.cosimilarity(pageIds); } // Step 2: calculate the sum of cosimilarities for each page Map<LocalId, Float> pageSims = new HashMap<LocalId, Float>(); for (int i = 0; i < pages.size(); i++) { double sum = 0.0; for (int j = 0; j < pages.size(); j++) { if (i != j && WbMathUtils.isReal(cosim[i][j])) { sum += Math.max(0, cosim[i][j]); // Hack: no negative numbers } } // add 0.0001 to give every candidate a tiny chance and avoid divide by zero errors when there are no good options pageSims.put(pages.get(i), (float)(sum + 0.0001)); } return pageSims; } public Criteria getCriteria() { return criteria; } public void setCriteria(Criteria criteria) { this.criteria = criteria; } public int getNumCandidates() { return numCandidates; } public void setNumCandidates(int numCandidates) { this.numCandidates = numCandidates; } public static class Provider extends org.wikibrain.conf.Provider<Disambiguator>{ public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator,config); } @Override public Class getType(){ return Disambiguator.class; } @Override public String getPath(){ return "sr.disambig"; } @Override public Disambiguator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException{ if (!config.getString("type").equals("similarity")){ return null; } if (runtimeParams == null || !runtimeParams.containsKey("language")){ throw new IllegalArgumentException("SimpleMilneWitten requires 'language' runtime parameter."); } Language lang = Language.getByLangCode(runtimeParams.get("language")); PhraseAnalyzer pa = getConfigurator().get(PhraseAnalyzer.class, config.getString("phraseAnalyzer")); // Create override config for sr metric and load it. String srName = config.getString("metric"); Config newConfig = getConfig().get().getConfig("sr.metric.local." + srName) .withValue("disambiguator", ConfigValueFactory.fromAnyRef("topResult")); Map<String, String> srRuntimeParams = new HashMap<String, String>(); srRuntimeParams.put("language", lang.getLangCode()); SRMetric sr = getConfigurator().construct(SRMetric.class, srName, newConfig, srRuntimeParams); SimilarityDisambiguator dab = new SimilarityDisambiguator(pa, sr); if (config.hasPath("criteria")) { dab.setCriteria(Criteria.valueOf(config.getString("criteria").toUpperCase())); } return dab; } } }