package org.wikibrain.sr.vector;
import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.matrix.*;
import org.wikibrain.sr.*;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.utils.*;
import java.io.File;
import java.io.IOException;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* An SR metric that represents phrases and pages as sparse numeric vectors.
* SR scores are computed by applying a VectorSimilarity measure to pairs of
* these vectors. MilneWitten, ESA, and the pairwise metrics all use this
* representation.
*
* <p>
* The metric requires two subcomponents:
* <ul>
* <li>A VectorGenerator class that generates the sparse vectors.</li>
* <li>A VectorSimilarity class that generates SR scores given two vectors.</li>
* </ul>
*
* <p>
*
* This class also manages a feature matrix and its transpose. The matrices are
* required for calls to mostSimilar(). They are not required for calls to
* similarity(), but will be used to speed it up when available. The matrices
* are built when trainMostSimilar() is called, but can also be built explicitly
* by calling buildFeatureAndTransposeMatrices().
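*
* <p>
* A minimal usage sketch (illustrative only: the metric name and language code
* below are hypothetical, and applications typically obtain the Configurator
* from their runtime environment rather than constructing it by hand):
* <pre>{@code
* Configurator conf = new Configurator(new Configuration());
* SRMetric sr = conf.get(SRMetric.class, "milnewitten",  // hypothetical metric name
*         "language", "simple");                         // hypothetical language code
* SRResult score = sr.similarity("apple", "orange", false);
* SRResultList top = sr.mostSimilar("apple", 10, null);
* }</pre>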
*
* @author Shilad Sen
* @see SparseVectorGenerator
* @see org.wikibrain.sr.vector.VectorSimilarity
*/
public class SparseVectorSRMetric extends BaseSRMetric {
private static final Logger LOG = LoggerFactory.getLogger(SparseVectorSRMetric.class);
protected final SparseVectorGenerator generator;
protected final VectorSimilarity similarity;
protected final SRConfig config;
private FeatureFilter featureFilter = null;
private SparseMatrix featureMatrix;
private SparseMatrix transposeMatrix;
public SparseVectorSRMetric(String name, Language language, LocalPageDao dao, Disambiguator disambig, SparseVectorGenerator generator, VectorSimilarity similarity) {
super(name, language, dao, disambig);
this.generator = generator;
this.similarity = similarity;
this.config = new SRConfig();
this.config.minScore = (float) similarity.getMinValue();
this.config.maxScore = (float) similarity.getMaxValue();
}
@Override
public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
if (featureFilter != null) {
throw new UnsupportedOperationException();
}
TIntFloatMap vector1 = null;
TIntFloatMap vector2 = null;
// try using phrases directly
try {
vector1 = generator.getVector(phrase1);
vector2 = generator.getVector(phrase2);
} catch (UnsupportedOperationException e) {
// try using other methods
}
if (vector1 == null || vector2 == null) {
return super.similarity(phrase1, phrase2, explanations);
} else {
SRResult result = new SRResult(similarity.similarity(vector1, vector2));
if (explanations) {
result.setExplanations(generator.getExplanations(phrase1, phrase2, vector1, vector2, result));
}
return normalize(result);
}
}
@Override
public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
try {
if (hasFeatureMatrix()) {
// Optimization that matters: Avoid building page vectors if possible.
SparseMatrixRow row1 = featureMatrix.getRow(pageId1);
SparseMatrixRow row2 = featureMatrix.getRow(pageId2);
if (row1 == null || row2 == null) {
return null;
} else {
if (featureFilter != null) {
row1 = featureFilter.filter(pageId1, row1);
row2 = featureFilter.filter(pageId2, row2);
}
SRResult result = new SRResult(similarity.similarity(row1, row2));
if (explanations) {
TIntFloatHashMap tfm1 = row1.asTroveMap();
TIntFloatHashMap tfm2 = row2.asTroveMap();
result.setExplanations(generator.getExplanations(pageId1, pageId2, tfm1, tfm2, result));
}
return normalize(result);
}
} else {
// feature filter gets applied in getPageVector if necessary
TIntFloatMap vector1 = getPageVector(pageId1);
TIntFloatMap vector2 = getPageVector(pageId2);
if (vector1 == null || vector2 == null) {
return null;
}
return normalize(new SRResult(similarity.similarity(vector1, vector2)));
}
} catch (IOException e) {
throw new DaoException(e);
}
}
@Override
public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
if (featureFilter != null) {
throw new UnsupportedOperationException();
}
TIntFloatMap vector = null;
// try using phrases directly
try {
vector = generator.getVector(phrase);
} catch (UnsupportedOperationException e) {
// try using other methods
}
if (vector == null) {
// fall back on parent's phrase resolution algorithm
return super.mostSimilar(phrase, maxResults, validIds);
} else {
try {
return normalize(similarity.mostSimilar(vector, maxResults, validIds));
} catch (IOException e) {
throw new DaoException(e);
}
}
}
@Override
public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
if (featureFilter != null) {
throw new UnsupportedOperationException();
}
try {
TIntFloatMap vector = getPageVector(pageId);
if (vector == null) return null;
return normalize(similarity.mostSimilar(vector, maxResults, validIds));
} catch (IOException e) {
throw new DaoException(e);
}
}
/**
* Train the similarity() function.
* The KnownSims may already be associated with Wikipedia ids (check wpId1 and wpId2).
*
* @param dataset A gold standard dataset
*/
@Override
public void trainSimilarity(Dataset dataset) throws DaoException {
super.trainSimilarity(dataset); // no vector-specific training; defer to the base class
}
/**
* @see org.wikibrain.sr.SRMetric#trainMostSimilar(org.wikibrain.sr.dataset.Dataset, int, gnu.trove.set.TIntSet)
*/
@Override
public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
try {
buildFeatureAndTransposeMatrices(validIds);
super.trainMostSimilar(dataset, numResults, validIds);
} catch (IOException e) {
LOG.error("training failed", e);
throw new RuntimeException(e); // somewhat unexpected...
}
}
@Override
public double[][] cosimilarity(int pageIds[]) throws DaoException {
return cosimilarity(pageIds, pageIds);
}
@Override
public double[][] cosimilarity(String phrases[]) throws DaoException {
return cosimilarity(phrases, phrases);
}
/**
* Computes the cosimilarity matrix between phrases.
* First tries to use the generator to build phrase vectors directly; some generators
* do not support this. In that case it falls back on the parent implementation,
* which disambiguates the phrases to page ids.
*
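* <p>
* For example, given an instance {@code sr} of this metric (the phrases are
* hypothetical):
* <pre>{@code
* double[][] sims = sr.cosimilarity(
*         new String[] {"cat", "dog"},
*         new String[] {"pet", "animal"});
* }</pre>
*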
* @param rowPhrases phrases corresponding to the rows of the result
* @param colPhrases phrases corresponding to the columns of the result
* @return a matrix of normalized similarity scores, one row per rowPhrase and one column per colPhrase
* @throws DaoException
*/
@Override
public double[][] cosimilarity(String rowPhrases[], String colPhrases[]) throws DaoException {
if (featureFilter != null) {
throw new UnsupportedOperationException();
}
if (rowPhrases.length == 0 || colPhrases.length == 0) {
return new double[rowPhrases.length][colPhrases.length];
}
List<TIntFloatMap> rowVectors = new ArrayList<TIntFloatMap>();
List<TIntFloatMap> colVectors = new ArrayList<TIntFloatMap>();
try {
// Try to use strings directly, but generator may not support them, so fall back on disambiguation
Map<String, TIntFloatMap> vectors = new HashMap<String, TIntFloatMap>();
for (String s : ArrayUtils.addAll(rowPhrases, colPhrases)) {
if (!vectors.containsKey(s)) {
vectors.put(s, generator.getVector(s));
}
}
for (String s : rowPhrases) {
rowVectors.add(vectors.get(s));
}
for (String s : colPhrases) {
colVectors.add(vectors.get(s));
}
} catch (UnsupportedOperationException e) {
// the generator can't build phrase vectors directly; fall back on page-based cosimilarity below
}
if (rowVectors.isEmpty() || colVectors.isEmpty()) {
return super.cosimilarity(rowPhrases, colPhrases);
} else {
return cosimilarity(rowVectors, colVectors);
}
}
/**
* Computes the cosimilarity matrix between pages.
* @param rowIds page ids corresponding to the rows of the result
* @param colIds page ids corresponding to the columns of the result
* @return a matrix of normalized similarity scores, one row per rowId and one column per colId
* @throws DaoException
*/
@Override
public double[][] cosimilarity(int rowIds[], int colIds[]) throws DaoException {
if (hasFeatureMatrix()) {
// special optimized case
TIntObjectMap<SparseMatrixRow> vectors = new TIntObjectHashMap<SparseMatrixRow>(rowIds.length + colIds.length);
for (int id : ArrayUtils.addAll(rowIds, colIds)) {
if (vectors.containsKey(id)) {
continue;
}
SparseMatrixRow row = null;
try {
row = featureMatrix.getRow(id);
} catch (IOException e) {
throw new DaoException(e);
}
if (row != null) {
if (featureFilter != null) {
row = featureFilter.filter(id, row);
}
vectors.put(id, row);
}
}
double results[][] = new double[rowIds.length][colIds.length];
for (int i = 0; i < rowIds.length; i++) {
SparseMatrixRow row1 = vectors.get(rowIds[i]);
if (row1 != null) {
for (int j = 0; j < colIds.length; j++) {
SparseMatrixRow row2 = vectors.get(colIds[j]);
if (row2 != null) {
results[i][j] = normalize(similarity.similarity(row1, row2));
}
}
}
}
return results;
} else {
// Build up vectors for unique pages
Map<Integer, TIntFloatMap> vectors = new HashMap<Integer, TIntFloatMap>();
for (int pageId : ArrayUtils.addAll(colIds, rowIds)) {
if (!vectors.containsKey(pageId)) {
try {
vectors.put(pageId, getPageVector(pageId));
} catch (IOException e) {
throw new DaoException(e);
}
}
}
List<TIntFloatMap> rowVectors = new ArrayList<TIntFloatMap>();
for (int rowId : rowIds) {
rowVectors.add(vectors.get(rowId));
}
List<TIntFloatMap> colVectors = new ArrayList<TIntFloatMap>();
for (int colId : colIds) {
colVectors.add(vectors.get(colId));
}
return cosimilarity(rowVectors, colVectors);
}
}
/**
* Computes the cosimilarity matrix between two lists of sparse feature vectors.
* @param rowVectors vectors corresponding to the rows of the result
* @param colVectors vectors corresponding to the columns of the result
* @return a matrix of normalized similarity scores
*/
protected double[][] cosimilarity(List<TIntFloatMap> rowVectors, List<TIntFloatMap> colVectors) {
if (featureFilter != null) {
throw new UnsupportedOperationException();
}
double results[][] = new double[rowVectors.size()][colVectors.size()];
for (int i = 0; i < rowVectors.size(); i++) {
for (int j = 0; j < colVectors.size(); j++) {
TIntFloatMap vi = rowVectors.get(i);
TIntFloatMap vj = colVectors.get(j);
results[i][j] = normalize(similarity.similarity(vi, vj));
}
}
return results;
}
/**
* Rebuilds the feature and transpose matrices from the generator's page vectors,
* writes them to the metric's data directory, and registers them with the
* VectorSimilarity so that mostSimilar() can use them.
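* <p>
* For example (a sketch; {@code metric} is an instance of this class):
* <pre>{@code
* // Rebuild matrices over every non-redirect, non-disambiguation article.
* metric.buildFeatureAndTransposeMatrices(null);
* }</pre>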
* @param validIds page ids to include, or null to include all main-namespace articles
* @throws IOException
*/
public synchronized void buildFeatureAndTransposeMatrices(TIntSet validIds) throws IOException {
if (validIds == null) {
validIds = getAllPageIds();
}
IOUtils.closeQuietly(featureMatrix);
IOUtils.closeQuietly(transposeMatrix);
featureMatrix = null;
transposeMatrix = null;
getDataDir().mkdirs();
ValueConf vconf = new ValueConf((float)similarity.getMinValue(),
(float)similarity.getMaxValue());
final SparseMatrixWriter writer = new SparseMatrixWriter(getFeatureMatrixPath(), vconf);
ParallelForEach.loop(
WbArrayUtils.toList(validIds.toArray()),
WpThreadUtils.getMaxThreads(),
new Procedure<Integer>() {
public void call(Integer pageId) throws IOException {
TIntFloatMap scores = getPageVector(pageId);
if (scores != null && !scores.isEmpty()) {
writer.writeRow(new SparseMatrixRow(writer.getValueConf(), pageId, scores));
}
}
}, 10000);
writer.finish();
// Reload the feature matrix
featureMatrix = new SparseMatrix(getFeatureMatrixPath());
getDataDir().mkdirs();
new SparseMatrixTransposer(featureMatrix, getTransposeMatrixPath())
.transpose();
transposeMatrix = new SparseMatrix(getTransposeMatrixPath());
similarity.setMatrices(featureMatrix, transposeMatrix, getDataDir());
}
private TIntSet getAllPageIds() throws IOException {
DaoFilter filter = new DaoFilter()
.setLanguages(getLanguage())
.setDisambig(false)
.setRedirect(false)
.setNameSpaces(NameSpace.ARTICLE);
TIntSet validIds = new TIntHashSet();
try {
for (LocalPage page : (Iterable<LocalPage>)getLocalPageDao().get(filter)) {
validIds.add(page.getLocalId());
}
} catch (DaoException e) {
throw new IOException(e);
}
return validIds;
}
protected File getFeatureMatrixPath() {
return new File(getDataDir(), "feature.matrix");
}
protected File getTransposeMatrixPath() {
return new File(getDataDir(), "featureTranspose.matrix");
}
@Override
public void read() throws IOException {
super.read();
if (getFeatureMatrixPath().isFile() && getTransposeMatrixPath().isFile()) {
IOUtils.closeQuietly(featureMatrix);
IOUtils.closeQuietly(transposeMatrix);
featureMatrix = new SparseMatrix(getFeatureMatrixPath());
transposeMatrix = new SparseMatrix(getTransposeMatrixPath());
similarity.setMatrices(featureMatrix, transposeMatrix, getDataDir());
}
}
/**
* Returns the sparse feature vector associated with a page, or null if the page has no vector.
* Reads the vector from the feature matrix when one is available; otherwise asks the generator.
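* <p>
* For example (a sketch; {@code metric} is an instance of this class and
* {@code pageId} is a local page id):
* <pre>{@code
* TIntFloatMap vector = metric.getPageVector(pageId);
* if (vector != null) {
*     System.out.println("page " + pageId + " has " + vector.size() + " features");
* }
* }</pre>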
* @param pageId the local id of the page
* @return the page's sparse feature vector, or null
* @throws IOException if the feature matrix or the generator fails
*/
public TIntFloatMap getPageVector(int pageId) throws IOException {
if (hasFeatureMatrix()) {
SparseMatrixRow row = featureMatrix.getRow(pageId);
if (row == null) {
return null;
} else if (featureFilter != null) {
return featureFilter.filter(pageId, row.asTroveMap());
} else {
return row.asTroveMap();
}
} else {
try {
if (featureFilter != null) {
return featureFilter.filter(pageId, generator.getVector(pageId));
} else {
return generator.getVector(pageId);
}
} catch (DaoException e) {
throw new IOException(e);
}
}
}
protected boolean hasFeatureMatrix() {
return featureMatrix != null && featureMatrix.getNumRows() > 0;
}
protected boolean hasTransposeMatrix() {
return transposeMatrix != null && transposeMatrix.getNumRows() > 0;
}
public SparseVectorGenerator getGenerator() {
return generator;
}
public VectorSimilarity getSimilarity() {
return similarity;
}
public void setFeatureFilter(FeatureFilter filter) {
this.featureFilter = filter;
}
@Override
public SRConfig getConfig() {
return config;
}
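/**
* Configurator provider that builds SparseVectorSRMetric instances from entries
* under the "sr.metric.local" configuration path. An illustrative sketch of the
* keys read by get() (the entry name and component names are hypothetical, and
* the generator / similarity sub-blocks depend on the components chosen):
* <pre>{@code
* myMetric {
*     type : sparsevector
*     generator : { ... }        // constructed as a SparseVectorGenerator
*     similarity : { ... }       // constructed as a VectorSimilarity
*     pageDao : default          // name of a LocalPageDao
*     disambiguator : topResult  // name of a Disambiguator
* }
* }</pre>
*/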
public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
super(configurator, config);
}
@Override
public Class getType() {
return SRMetric.class;
}
@Override
public String getPath() {
return "sr.metric.local";
}
@Override
public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
if (!config.getString("type").equals("sparsevector")) {
return null;
}
if (runtimeParams == null || !runtimeParams.containsKey("language")) {
throw new IllegalArgumentException("SparseVectorSRMetric requires 'language' runtime parameter.");
}
Language language = Language.getByLangCode(runtimeParams.get("language"));
Map<String, String> params = new HashMap<String, String>();
params.put("language", language.getLangCode());
SparseVectorGenerator generator = getConfigurator().construct(
SparseVectorGenerator.class, null, config.getConfig("generator"), params);
VectorSimilarity similarity = getConfigurator().construct(
VectorSimilarity.class, null, config.getConfig("similarity"), params);
SparseVectorSRMetric sr = new SparseVectorSRMetric(
name,
language,
getConfigurator().get(LocalPageDao.class, config.getString("pageDao")),
getConfigurator().get(Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode()),
generator,
similarity
);
configureBase(getConfigurator(), sr, config);
return sr;
}
}
}