package org.wikibrain.sr.vector;

import com.typesafe.config.Config;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;
import org.wikibrain.matrix.knn.KNNFinder;
import org.wikibrain.matrix.knn.Neighborhood;
import org.wikibrain.matrix.knn.RandomProjectionKNNFinder;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SimUtils;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An SR metric that represents phrases and pages using dense numeric vectors.
 * SR scores are the result of cosine similarity between vectors. Word2Vec uses this
 * representation.
 *
 * <p>The metric requires a DenseVectorGenerator class that generates the vectors.</p>
 *
 * @author Shilad Sen
 * @see org.wikibrain.sr.vector.DenseVectorGenerator
 */
public class DenseVectorSRMetric extends BaseSRMetric {

    private static final Logger LOG = LoggerFactory.getLogger(DenseVectorSRMetric.class);

    /** Name of the file (inside the metric's data directory) that holds the saved KNN accelerator. */
    private static final String KNN_FILE_NAME = "knn.bin";

    protected final DenseVectorGenerator generator;
    protected final SRConfig config;

    /** Feature matrix obtained from the generator; guaranteed non-null by the constructor. */
    private final DenseMatrix articleFeatures;

    /** Optional nearest-neighbor index; when null, mostSimilar falls back to an exhaustive scan. */
    private KNNFinder accelerator;

    // Multiple of requested mostSimilar results that will be considered
    private double acceleratorMultiplier = 100.0;

    // Lower bound on the number of candidates requested from the accelerator
    private int minAcceleratorCandidates = 500;

    /**
     * Creates the metric on top of the feature matrix supplied by the generator.
     *
     * @param name      name of the metric
     * @param language  language the metric operates on
     * @param dao       page dao handed to the base metric
     * @param disambig  disambiguator handed to the base metric
     * @param generator source of phrase / page vectors and of the feature matrix
     * @throws IllegalArgumentException if the generator does not provide a feature matrix
     */
    public DenseVectorSRMetric(String name, Language language, LocalPageDao dao,
                               Disambiguator disambig, DenseVectorGenerator generator) {
        super(name, language, dao, disambig);
        this.generator = generator;
        this.articleFeatures = generator.getFeatureMatrix();
        if (articleFeatures == null) {
            // TODO: build the article features if necessary.
            throw new IllegalArgumentException(
                    "Generator for metric '" + name + "' did not provide a feature matrix");
        }
        this.config = new SRConfig();
    }

    /**
     * Scores two phrases by the cosine similarity of their generator vectors.
     * If the generator does not support phrase vectors, or yields null for either phrase,
     * the computation is delegated to the base class implementation.
     *
     * @param phrase1      first phrase
     * @param phrase2      second phrase
     * @param explanations whether explanations should be attached to the result
     * @return the normalized similarity result
     * @throws DaoException declared by the overridden method; propagated from the base-class fallback
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        float[] vector1 = null;
        float[] vector2 = null;

        // try using phrases directly
        try {
            vector1 = generator.getVector(phrase1);
            vector2 = generator.getVector(phrase2);
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot embed phrases; the null vectors trigger the fallback below.
        }

        if (vector1 == null || vector2 == null) {
            return super.similarity(phrase1, phrase2, explanations);
        }

        SRResult result = new SRResult(SimUtils.cosineSimilarity(vector1, vector2));
        if (explanations) {
            result.setExplanations(generator.getExplanations(phrase1, phrase2, vector1, vector2, result));
        }
        return normalize(result);
    }

    /**
     * Scores two pages by the normalized cosine similarity of their page vectors.
     *
     * @param pageId1      id of the first page
     * @param pageId2      id of the second page
     * @param explanations whether explanations should be attached to the result
     * @return the similarity result
     * @throws DaoException if looking up a page vector fails with an IOException
     */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        try {
            float[] v1 = getPageVector(pageId1);
            float[] v2 = getPageVector(pageId2);
            SRResult result = new SRResult(normalize(SimUtils.cosineSimilarity(v1, v2)));
            if (explanations) {
                result.setExplanations(generator.getExplanations(pageId1, pageId2, v1, v2, result));
            }
            return result;
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Finds the pages most similar to a phrase. Uses the generator's phrase vector when
     * available and otherwise delegates to the base class implementation.
     *
     * @param phrase     query phrase
     * @param maxResults maximum number of results requested
     * @param validIds   ids that may appear in the result, or null for no restriction
     * @return the most similar pages
     * @throws DaoException if the vector based search fails with an IOException
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        try {
            // try using phrases directly
            float[] vector = generator.getVector(phrase);
            return mostSimilar(vector, maxResults, validIds);
        } catch (UnsupportedOperationException e) {
            // try using other methods
            return super.mostSimilar(phrase, maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Finds the pages most similar to a page, based on the page's vector.
     *
     * @param pageId     id of the query page
     * @param maxResults maximum number of results requested
     * @param validIds   ids that may appear in the result, or null for no restriction
     * @return the most similar pages
     * @throws DaoException if the vector lookup or search fails with an IOException
     */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        try {
            return mostSimilar(getPageVector(pageId), maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Finds the pages whose vectors are most similar to the given vector.
     *
     * <p>When an accelerator is set it is queried for
     * {@code max(minAcceleratorCandidates, maxResults * acceleratorMultiplier)} candidates.
     * Otherwise every row of the feature matrix (or every id in {@code validIds}, when given)
     * is scored with cosine similarity. In the latter case a failure for an individual id is
     * logged and that id is skipped.</p>
     *
     * @param vector     query vector; a null vector yields an empty result list
     * @param maxResults maximum number of results requested
     * @param validIds   ids that may appear in the result, or null for no restriction
     * @return the normalized result list
     * @throws IOException declared for callers; raised by the underlying matrix / index access
     */
    public SRResultList mostSimilar(final float[] vector, int maxResults, TIntSet validIds) throws IOException {
        if (vector == null) {
            return new SRResultList(0);
        }

        SRResultList result;
        if (accelerator != null) {
            int n = (int) Math.max(minAcceleratorCandidates, maxResults * acceleratorMultiplier);
            Neighborhood nhood = accelerator.query(vector, maxResults, n, validIds);
            result = new SRResultList(nhood.size());
            for (int i = 0; i < nhood.size(); i++) {
                result.set(i, nhood.getId(i), nhood.getScore(i));
            }
        } else {
            final Leaderboard board = new Leaderboard(maxResults);
            if (validIds == null) {
                for (DenseMatrixRow row : articleFeatures) {
                    board.tallyScore(row.getRowIndex(), SimUtils.cosineSimilarity(row.getValues(), vector));
                }
            } else {
                validIds.forEach(new TIntProcedure() {
                    @Override
                    public boolean execute(int id) {
                        try {
                            float[] v = getPageVector(id);
                            if (v != null) {
                                board.tallyScore(id, SimUtils.cosineSimilarity(v, vector));
                            }
                        } catch (Exception e) {
                            // Best effort: one bad id must not abort the whole scan.
                            LOG.warn("similarity for " + id + " failed: ", e);
                        }
                        return true; // keep iterating over the remaining ids
                    }
                });
            }
            result = board.getTop();
        }
        return normalize(result);
    }

    /**
     * Train the similarity() function.
     * The KnownSims may already be associated with Wikipedia ids (check wpId1 and wpId2).
     *
     * @param dataset A gold standard dataset
     */
    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        // Nothing metric-specific to train, for now; only the base class training runs.
        super.trainSimilarity(dataset);
    }

    /**
     * Builds a random projection KNN index over the feature matrix, installs it as the
     * accelerator, and then runs the base class training.
     *
     * @throws IllegalStateException if building the index fails with an IOException
     * @see org.wikibrain.sr.SRMetric#trainMostSimilar(org.wikibrain.sr.dataset.Dataset, int, gnu.trove.set.TIntSet)
     */
    @Override
    public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        try {
            RandomProjectionKNNFinder knn = new RandomProjectionKNNFinder(articleFeatures);
            knn.build();
            accelerator = knn;
        } catch (IOException e) {
            // Keep the original exception as the cause so the stack trace is not lost.
            throw new IllegalStateException("Unexpected exception: " + e, e);
        }
        super.trainMostSimilar(dataset, numResults, validIds);
    }

    /**
     * Computes the square cosimilarity matrix of the given pages against themselves.
     */
    @Override
    public double[][] cosimilarity(int pageIds[]) throws DaoException {
        return cosimilarity(pageIds, pageIds);
    }

    /**
     * Computes the square cosimilarity matrix of the given phrases against themselves.
     */
    @Override
    public double[][] cosimilarity(String phrases[]) throws DaoException {
        return cosimilarity(phrases, phrases);
    }

    /**
     * Calculates the cosimilarity matrix between phrases.
     * First tries to use generator to get phrase vectors directly, but some generators will not support this.
     * In that case the computation is delegated to the base class implementation.
     *
     * @param rowPhrases phrases that index the rows of the result
     * @param colPhrases phrases that index the columns of the result
     * @return matrix of normalized similarities, {@code result[i][j]} for row phrase i and column phrase j
     * @throws org.wikibrain.core.dao.DaoException propagated from the base-class fallback
     */
    @Override
    public double[][] cosimilarity(String rowPhrases[], String colPhrases[]) throws DaoException {
        if (rowPhrases.length == 0 || colPhrases.length == 0) {
            return new double[rowPhrases.length][colPhrases.length];
        }

        float[][] rowVectors = new float[rowPhrases.length][];
        float[][] colVectors = new float[colPhrases.length][];
        try {
            for (int i = 0; i < rowPhrases.length; i++) {
                rowVectors[i] = generator.getVector(rowPhrases[i]);
            }
            for (int i = 0; i < colPhrases.length; i++) {
                colVectors[i] = generator.getVector(colPhrases[i]);
            }
        } catch (UnsupportedOperationException e) {
            return super.cosimilarity(rowPhrases, colPhrases);
        }

        return cosimilarity(rowVectors, colVectors);
    }

    /**
     * Computes the cosimilarity matrix between pages.
     *
     * @param rowIds ids of the pages that index the rows of the result
     * @param colIds ids of the pages that index the columns of the result
     * @return matrix of normalized similarities, {@code result[i][j]} for row page i and column page j
     * @throws org.wikibrain.core.dao.DaoException if a page vector lookup fails with an IOException
     */
    @Override
    public double[][] cosimilarity(int rowIds[], int colIds[]) throws DaoException {
        try {
            if (rowIds.length == 0 || colIds.length == 0) {
                return new double[rowIds.length][colIds.length];
            }

            float[][] rowVectors = new float[rowIds.length][];
            float[][] colVectors = new float[colIds.length][];
            for (int i = 0; i < rowIds.length; i++) {
                rowVectors[i] = getPageVector(rowIds[i]);
            }
            for (int i = 0; i < colIds.length; i++) {
                colVectors[i] = getPageVector(colIds[i]);
            }

            return cosimilarity(rowVectors, colVectors);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Fills a matrix with the normalized cosine similarity of every row / column vector pair.
     * Shared by the phrase based and the page based cosimilarity methods.
     */
    private double[][] cosimilarity(float[][] rowVectors, float[][] colVectors) {
        double[][] result = new double[rowVectors.length][colVectors.length];
        for (int i = 0; i < rowVectors.length; i++) {
            for (int j = 0; j < colVectors.length; j++) {
                result[i][j] = normalize(SimUtils.cosineSimilarity(rowVectors[i], colVectors[j]));
            }
        }
        return result;
    }

    /**
     * Reads the base metric state and then tries to load the KNN accelerator from
     * {@code knn.bin} in the data directory. The accelerator is only installed when
     * loading reports success.
     */
    @Override
    public void read() throws IOException {
        super.read();
        RandomProjectionKNNFinder knn = new RandomProjectionKNNFinder(articleFeatures);
        if (knn.load(new File(getDataDir(), KNN_FILE_NAME))) {
            accelerator = knn;
        }
    }

    /**
     * Writes the base metric state and, when an accelerator is present, saves it to
     * {@code knn.bin} in the data directory.
     */
    @Override
    public void write() throws IOException {
        super.write();
        if (accelerator != null) {
            accelerator.save(new File(getDataDir(), KNN_FILE_NAME));
        }
    }

    /**
     * Returns the vector associated with a page, or null.
     *
     * @param pageId id of the page
     * @return the page's row of the feature matrix, or null if the matrix has no such row
     * @throws IOException if the generator lookup fails with a DaoException
     */
    public float[] getPageVector(int pageId) throws IOException {
        if (articleFeatures == null) {
            // NOTE(review): the constructor rejects a null feature matrix, so this branch
            // looks unreachable today; it is kept for the TODO about building features lazily.
            try {
                return generator.getVector(pageId);
            } catch (DaoException e) {
                throw new IOException(e);
            }
        }
        DenseMatrixRow row = articleFeatures.getRow(pageId);
        return row == null ? null : row.getValues();
    }

    public DenseVectorGenerator getGenerator() {
        return generator;
    }

    /** Sets (or, with null, removes) the nearest-neighbor index used by mostSimilar. */
    public void setAccelerator(KNNFinder accelerator) {
        this.accelerator = accelerator;
    }

    /** Sets the multiple of the requested result count that the accelerator considers. */
    public void setAcceleratorMultiplier(double acceleratorMultiplier) {
        this.acceleratorMultiplier = acceleratorMultiplier;
    }

    /** Sets the minimum number of candidates requested from the accelerator. */
    public void setMinAcceleratorCandidates(int minAcceleratorCandidates) {
        this.minAcceleratorCandidates = minAcceleratorCandidates;
    }

    @Override
    public SRConfig getConfig() {
        return config;
    }

    /**
     * Configuration provider that builds a DenseVectorSRMetric for configurations
     * whose "type" is "densevector".
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {

        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds the metric described by the configuration.
         *
         * @param name          name of the metric
         * @param config        configuration of the metric; reads "type", "generator",
         *                      "pageDao" and "disambiguator"
         * @param runtimeParams runtime parameters; must contain "language"
         * @return the configured metric, or null if the configured type is not "densevector"
         * @throws ConfigurationException   if a dependency cannot be constructed
         * @throws IllegalArgumentException if the "language" runtime parameter is missing
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams)
                throws ConfigurationException {
            if (!config.getString("type").equals("densevector")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual requires 'language' runtime parameter.");
            }

            Language language = Language.getByLangCode(runtimeParams.get("language"));
            Map<String, String> params = new HashMap<String, String>();
            params.put("language", language.getLangCode());

            DenseVectorGenerator generator = getConfigurator().construct(
                    DenseVectorGenerator.class, null, config.getConfig("generator"), params);

            DenseVectorSRMetric sr = new DenseVectorSRMetric(
                    name,
                    language,
                    getConfigurator().get(LocalPageDao.class, config.getString("pageDao")),
                    getConfigurator().get(Disambiguator.class, config.getString("disambiguator"),
                            "language", language.getLangCode()),
                    generator
            );
            configureBase(getConfigurator(), sr, config);
            return sr;
        }
    }
}