package org.wikibrain.sr.vector;
import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.set.TIntSet;
import org.apache.commons.lang3.ArrayUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.disambig.Disambiguator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
 * A sparse-vector SR metric that tries to build vectors for whole phrases before falling back
 * on the parent class's phrase resolution.
 *
 * <p>Depending on the {@link PhraseMode}, phrase vectors are requested from the
 * {@link SparseVectorGenerator} directly (which may throw {@link UnsupportedOperationException}
 * if it does not support phrases), from the {@link PhraseVectorCreator}, or from both in that
 * order. When no vector can be obtained, the corresponding {@link SparseVectorSRMetric} method
 * is used instead.
 *
 * @author Shilad Sen
 */
public class FancyPhraseVectorBasedSRMetric extends SparseVectorSRMetric {

    /** Selects which sources are consulted when resolving phrases to vectors. */
    public enum PhraseMode {
        GENERATOR, // try to get phrase vectors from the generator directly
        CREATOR,   // try to get phrase vectors from the phrase vector creator
        BOTH,      // first try the generator, then the creator
        NONE       // don't resolve phrases at all; always defer to the parent class
    }

    private final PhraseVectorCreator phraseVectorCreator;
    private PhraseMode phraseMode = PhraseMode.BOTH;

    /**
     * Creates the metric.
     *
     * @param name       metric name, passed to the parent
     * @param language   language, passed to the parent
     * @param dao        page dao, passed to the parent
     * @param disambig   disambiguator, passed to the parent
     * @param generator  sparse vector generator, passed to the parent
     * @param similarity vector similarity, passed to the parent
     * @param creator    phrase vector creator; may be null, in which case any call that needs
     *                   the creator throws {@link IllegalStateException}
     */
    public FancyPhraseVectorBasedSRMetric(String name, Language language, LocalPageDao dao, Disambiguator disambig, SparseVectorGenerator generator, VectorSimilarity similarity, PhraseVectorCreator creator) {
        super(name, language, dao, disambig, generator, similarity);
        this.phraseVectorCreator = creator;
        // Guarded so that a null creator reaches the explicit IllegalStateException checks
        // below instead of failing with an NPE here.
        if (creator != null) {
            creator.setMetric(this);
        }
    }

    /**
     * Computes the similarity of two phrases, preferring phrase vectors when they can be built.
     *
     * @param phrase1      first phrase
     * @param phrase2      second phrase
     * @param explanations whether to attach explanations from the generator
     * @return the normalized similarity, or the parent's result if no phrase vectors are found
     * @throws DaoException          if a lookup fails
     * @throws IllegalStateException if the creator is needed but was not supplied
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        if (phraseMode == PhraseMode.NONE) {
            return super.similarity(phrase1, phrase2, explanations);
        }
        TIntFloatMap vector1 = null;
        TIntFloatMap vector2 = null;

        // try using phrases directly
        if (modeUsesGenerator()) {
            try {
                vector1 = generator.getVector(phrase1);
                vector2 = generator.getVector(phrase2);
            } catch (UnsupportedOperationException ignored) {
                // The generator cannot build phrase vectors; try the creator next.
            }
        }
        if ((vector1 == null || vector2 == null) && modeUsesCreator()) {
            TIntFloatMap[] vectors = requirePhraseVectorCreator().getPhraseVectors(phrase1, phrase2);
            if (vectors != null) {
                vector1 = vectors[0];
                vector2 = vectors[1];
            }
        }

        if (vector1 == null || vector2 == null) {
            // fall back on parent's phrase resolution algorithm
            return super.similarity(phrase1, phrase2, explanations);
        }
        SRResult result = new SRResult(similarity.similarity(vector1, vector2));
        if (explanations) {
            result.setExplanations(generator.getExplanations(phrase1, phrase2, vector1, vector2, result));
        }
        return normalize(result);
    }

    /**
     * Finds the items most similar to a phrase, preferring a phrase vector when one can be built.
     *
     * @param phrase     query phrase
     * @param maxResults maximum number of results
     * @param validIds   candidate ids passed through to the similarity / parent
     * @return the most similar items, or the parent's result if no phrase vector is found
     * @throws DaoException          if a lookup fails (IOExceptions from the similarity are wrapped)
     * @throws IllegalStateException if the creator is needed but was not supplied
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        if (phraseMode == PhraseMode.NONE) {
            return super.mostSimilar(phrase, maxResults, validIds);
        }
        TIntFloatMap vector = null;

        // try using phrases directly
        if (modeUsesGenerator()) {
            try {
                vector = generator.getVector(phrase);
            } catch (UnsupportedOperationException ignored) {
                // The generator cannot build phrase vectors; try the creator next.
            }
        }
        if (vector == null && modeUsesCreator()) {
            vector = requirePhraseVectorCreator().getPhraseVector(phrase);
        }

        if (vector == null) {
            // fall back on parent's phrase resolution algorithm
            return super.mostSimilar(phrase, maxResults, validIds);
        }
        try {
            return similarity.mostSimilar(vector, maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Calculates the cosimilarity matrix between phrases.
     *
     * <p>Each distinct phrase is resolved once. Vectors are first requested from the generator
     * (if the phrase mode allows it). If the generator does not support phrases or returns no
     * vector for some phrase, the phrase vector creator is tried (if the mode allows it). If
     * neither source yields a vector for every phrase, or the mode is {@code NONE}, the parent
     * class's implementation is used, matching the fallback in {@link #similarity}.
     *
     * @param rowPhrases phrases for the rows of the matrix
     * @param colPhrases phrases for the columns of the matrix
     * @return the cosimilarity matrix; if either input is empty, a zero-filled
     *         {@code rowPhrases.length x colPhrases.length} array
     * @throws DaoException          if a lookup or similarity computation fails
     * @throws IllegalStateException if the creator is needed but was not supplied
     */
    @Override
    public double[][] cosimilarity(String rowPhrases[], String colPhrases[]) throws DaoException {
        if (rowPhrases.length == 0 || colPhrases.length == 0) {
            return new double[rowPhrases.length][colPhrases.length];
        }
        if (phraseMode == PhraseMode.NONE) {
            return super.cosimilarity(rowPhrases, colPhrases);
        }

        // Give each distinct phrase a dense index (first-seen order) so it is resolved only once.
        Map<String, Integer> indexes = new LinkedHashMap<String, Integer>();
        for (String s : ArrayUtils.addAll(rowPhrases, colPhrases)) {
            if (!indexes.containsKey(s)) {
                indexes.put(s, indexes.size());
            }
        }
        String[] unique = indexes.keySet().toArray(new String[indexes.size()]);

        TIntFloatMap[] vectors = null;
        if (modeUsesGenerator()) {
            vectors = generatorPhraseVectors(unique);
        }
        if (vectors == null && modeUsesCreator()) {
            vectors = creatorPhraseVectors(unique);
        }
        if (vectors == null) {
            // fall back on parent's phrase resolution algorithm
            return super.cosimilarity(rowPhrases, colPhrases);
        }
        return cosimilarity(vectorsFor(rowPhrases, indexes, vectors), vectorsFor(colPhrases, indexes, vectors));
    }

    /**
     * Sets which sources are consulted for phrase vectors.
     *
     * @param mode the new mode; must not be null
     * @throws IllegalArgumentException if {@code mode} is null
     */
    public void setPhraseMode(PhraseMode mode) {
        if (mode == null) {
            throw new IllegalArgumentException("mode must not be null");
        }
        this.phraseMode = mode;
    }

    /** Returns true when the current mode allows asking the generator for phrase vectors. */
    private boolean modeUsesGenerator() {
        return phraseMode == PhraseMode.BOTH || phraseMode == PhraseMode.GENERATOR;
    }

    /** Returns true when the current mode allows asking the phrase vector creator. */
    private boolean modeUsesCreator() {
        return phraseMode == PhraseMode.BOTH || phraseMode == PhraseMode.CREATOR;
    }

    /** Returns the phrase vector creator, or throws if the mode needs one but none was supplied. */
    private PhraseVectorCreator requirePhraseVectorCreator() {
        if (phraseVectorCreator == null) {
            throw new IllegalStateException("phraseMode is " + phraseMode + " but phraseVectorCreator is null");
        }
        return phraseVectorCreator;
    }

    /**
     * Asks the generator for one vector per phrase.
     *
     * @return the vectors in phrase order, or null if the generator does not support phrases
     *         or has no vector for at least one phrase
     */
    private TIntFloatMap[] generatorPhraseVectors(String[] phrases) throws DaoException {
        TIntFloatMap[] result = new TIntFloatMap[phrases.length];
        try {
            for (int i = 0; i < phrases.length; i++) {
                result[i] = generator.getVector(phrases[i]);
                if (result[i] == null) {
                    return null;
                }
            }
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot build phrase vectors.
            return null;
        }
        return result;
    }

    /**
     * Asks the phrase vector creator for one vector per phrase.
     *
     * @return the vectors in phrase order, or null if the creator returned nothing, the wrong
     *         number of vectors, or no vector for at least one phrase
     */
    private TIntFloatMap[] creatorPhraseVectors(String[] phrases) throws DaoException {
        TIntFloatMap[] result = requirePhraseVectorCreator().getPhraseVectors(phrases);
        if (result == null || result.length != phrases.length) {
            return null;
        }
        for (TIntFloatMap v : result) {
            if (v == null) {
                return null;
            }
        }
        return result;
    }

    /** Maps each phrase to its resolved vector via its index among the distinct phrases. */
    private static List<TIntFloatMap> vectorsFor(String[] phrases, Map<String, Integer> indexes, TIntFloatMap[] vectors) {
        List<TIntFloatMap> selected = new ArrayList<TIntFloatMap>(phrases.length);
        for (String s : phrases) {
            selected.add(vectors[indexes.get(s)]);
        }
        return selected;
    }

    /**
     * Builds {@link FancyPhraseVectorBasedSRMetric} instances from configuration entries under
     * {@code sr.metric.local} whose {@code type} is {@code fancyphrasevector}.
     *
     * <p>Reads the {@code generator}, {@code similarity}, {@code pageDao}, {@code disambiguator}
     * and {@code phrases} keys, plus the optional {@code phraseMode} key (case-insensitive
     * {@link PhraseMode} name). Requires a {@code language} runtime parameter.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("fancyphrasevector")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual requires 'language' runtime parameter.");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));

            Map<String, String> params = new HashMap<String, String>();
            params.put("language", language.getLangCode());
            SparseVectorGenerator generator = getConfigurator().construct(
                    SparseVectorGenerator.class, null, config.getConfig("generator"), params);
            VectorSimilarity similarity = getConfigurator().construct(
                    VectorSimilarity.class, null, config.getConfig("similarity"), params);
            FancyPhraseVectorBasedSRMetric sr = new FancyPhraseVectorBasedSRMetric(
                    name,
                    language,
                    getConfigurator().get(LocalPageDao.class, config.getString("pageDao")),
                    getConfigurator().get(Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode()),
                    generator,
                    similarity,
                    // NOTE(review): unlike the generator and similarity, the creator is constructed
                    // without the language runtime params — confirm this is intended.
                    getConfigurator().construct(
                            PhraseVectorCreator.class, null, config.getConfig("phrases"), null)
            );
            if (config.hasPath("phraseMode")) {
                // Locale.ROOT keeps the enum-name lookup independent of the JVM default locale.
                sr.setPhraseMode(PhraseMode.valueOf(config.getString("phraseMode").toUpperCase(Locale.ROOT)));
            }
            configureBase(getConfigurator(), sr, config);
            return sr;
        }
    }
}