package org.wikibrain.sr.phrasesim;
import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TLongFloatMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.procedure.TIntFloatProcedure;
import gnu.trove.procedure.TLongFloatProcedure;
import gnu.trove.set.TIntSet;
import org.apache.commons.io.FileUtils;
import org.mapdb.DB;
import org.mapdb.DBMaker;
import org.mapdb.HTreeMap;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.StringNormalizer;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.normalize.IdentityNormalizer;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.normalize.PercentileNormalizer;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.vector.SparseVectorSRMetric;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WpIOUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This metric is intended to support very fast SR for a known (expandable) set of phrases.
* This metric ONLY operates on phrases, not on pages (TODO: fix this).
*
* The SR metric must be made aware of phrases using the addPhrase() method.
* All phrases have an application-generated integer ID associated with them.
* "Unknown" phrases (i.e. those not added) will not be returned by mostSimilar() and
* will fail for, e.g., similarity() and cosimilarity().
*
* SR methods (e.g. similarity, cosimilarity, mostSimilar) that typically take or return
* local article ids will instead take or return phrase ids. For example, mostSimilar()
* returns the scores and ids associated with known phrases.
*
* Phrases are represented as vectors, which are cached in two ways. First, a full cosimilarity
* matrix is maintained, ensuring that all SR methods on existing phrases are fast.
* Second, an inverted index for the vector representations is maintained so that
* all cosimilarities for a new phrase can be calculated very quickly.
*
* The universe of known phrases and associated data structures is serialized dynamically
* to files in the specified data directory. However, the full cosimilarity matrix is only
* written out when the write() method (or the flushCosimilarity() method) is called.
*
* The normalizer should be retrained for internal phrases (using trainNormalizer())
* periodically. It initially defaults to the "identity" normalizer.
*
* This means that space complexity is O(N^2) for N phrases, and each new phrase adds O(N)
* cells - to be specific, 4 bytes are needed for each element in the cosimilarity matrix.
* In exchange for this, similarity is O(1), cosimilarity is O(m*n) for m phrases by n phrases,
* and mostSimilar is O(n). The complexity of addPhrase scales with the sparsity of the feature
* vector matrix. To be specific, the complexity of addPhrase(phrase) is linear in the number
* of non-zero cells in the full (all-phrase) feature matrix for each of the phrase's features.
*
* All elements of this metric are thread-safe.
*
* @author Shilad Sen
*/
public class KnownPhraseSim implements SRMetric {
    private static final Logger LOGGER = LoggerFactory.getLogger(KnownPhraseSim.class);

    private final StringNormalizer stringNormalizer;
    private final HTreeMap<Object, Object> db;
    private final PhraseCreator creator;
    private final Language language;
    private final File dir;
    private final String name;

    // Score normalizer. It starts as the identity and is replaced by a previously saved
    // normalizer (constructor) or by trainNormalizer().
    private Normalizer scoreNormalizer = new IdentityNormalizer();

    // Regular index
    private ConcurrentHashMap<String, KnownPhrase> byPhrase;
    private ConcurrentHashMap<Integer, KnownPhrase> byId;

    // Keeps an inverted index (feature id -> phrase id -> feature value) for fast mostSimilar
    // performance. Each per-feature map is guarded by synchronizing on the map itself.
    private ConcurrentHashMap<Long, TIntFloatMap> invertedIndex = new ConcurrentHashMap<Long, TIntFloatMap>();

    // Caches the full cosimilarity matrix.
    private CosimilarityMatrix cosim = new CosimilarityMatrix();

    private DB phraseDb;

    /**
     * Creates a metric named "known-phrase-sim".
     *
     * @see #KnownPhraseSim(String, Language, PhraseCreator, File, StringNormalizer)
     */
    public KnownPhraseSim(Language language, PhraseCreator creator, File dir, StringNormalizer stringNormalizer) throws IOException {
        this("known-phrase-sim", language, creator, dir, stringNormalizer);
    }

    /**
     * Creates the metric and loads any state persisted in {@code dir}: the phrase MapDB
     * ("phrases.mapdb"), the cosimilarity cache ("cosimilarity.bin", rebuilt for missing
     * phrases) and a previously trained normalizer ("scoreNormalizer.bin"), if present.
     *
     * @param name             metric name returned by {@link #getName()}
     * @param language         language used when normalizing phrases
     * @param creator          builds the feature vector for a phrase
     * @param dir              data directory (created if it does not exist)
     * @param stringNormalizer maps raw phrases to their canonical form
     * @throws IOException if persisted state cannot be read
     */
    public KnownPhraseSim(String name, Language language, PhraseCreator creator, File dir, StringNormalizer stringNormalizer) throws IOException {
        this.name = name;
        this.language = language;
        this.creator = creator;
        this.stringNormalizer = stringNormalizer;
        this.dir = dir;
        this.dir.mkdirs();
        this.phraseDb = DBMaker
                .newFileDB(new File(dir, "phrases.mapdb"))
                .mmapFileEnable()
                .transactionDisable()
                .asyncWriteEnable()
                .asyncWriteFlushDelay(100)
                .make();
        this.db = phraseDb.getHashMap("phrases");
        this.readPhrases();
        this.readCosimilarity();
        File f = new File(dir, "scoreNormalizer.bin");
        if (f.isFile()) {
            scoreNormalizer = (Normalizer) WpIOUtils.readObjectFromFile(f);
        }
    }

    @Override
    public void read() {
        throw new UnsupportedOperationException("Metric cannot be re-read after creation");
    }

    /**
     * Write simply flushes the cache. All other writes happen asynchronously.
     * @throws IOException
     */
    @Override
    public void write() throws IOException {
        flushCosimilarity();
    }

    /**
     * Loads the cosimilarity cache from "cosimilarity.bin". If that fails for any reason,
     * it starts from an empty matrix. Rows are then computed (in parallel) for every known
     * phrase whose id is not in {@code cosim.getCompleted()}.
     */
    private void readCosimilarity() throws IOException {
        File f = new File(dir, "cosimilarity.bin");
        try {
            this.cosim = (CosimilarityMatrix) WpIOUtils.readObjectFromFile(f);
        } catch (Exception e) {
            LOGGER.info("Reading cosim file " + f + " failed... rebuilding it from scratch");
            cosim = new CosimilarityMatrix();
        }
        final TIntSet built = cosim.getCompleted();
        ParallelForEach.loop(byId.values(), new Procedure<KnownPhrase>() {
            @Override
            public void call(KnownPhrase p) throws Exception {
                if (!built.contains(p.getId())) {
                    SRResultList neighbors = indexedMostSimilar(p.getVector(), byId.size(), null);
                    cosim.update(p.getId(), neighbors);
                }
            }
        });
    }

    /**
     * Writes the cosimilarity matrix to "cosimilarity.bin" and commits the phrase store.
     *
     * @throws IOException if the matrix cannot be written
     */
    public void flushCosimilarity() throws IOException {
        WpIOUtils.writeObjectToFile(new File(dir, "cosimilarity.bin"), cosim);
        db.getEngine().commit();
    }

    /**
     * Rebuilds the in-memory phrase indexes (byId, byPhrase, invertedIndex) from the MapDB store.
     *
     * @throws IllegalStateException if a stored key does not match its phrase's normalized form
     */
    private void readPhrases() {
        byId = new ConcurrentHashMap<Integer, KnownPhrase>();
        byPhrase = new ConcurrentHashMap<String, KnownPhrase>();
        for (Map.Entry<Object, Object> entry : db.entrySet()) {
            String key = (String) entry.getKey();
            KnownPhrase val = (KnownPhrase) entry.getValue();
            if (!key.equals(val.getNormalizedPhrase())) {
                throw new IllegalStateException();
            }
            byId.put(val.getId(), val);
            byPhrase.put(val.getNormalizedPhrase(), val);
            for (String version : val.getVersions()) {
                byPhrase.put(version, val);
            }
            PhraseVector v = val.getVector();
            for (int i = 0; i < v.ids.length; i++) {
                indexFeature(val.getId(), v.ids[i], v.vals[i]);
            }
        }
    }

    /**
     * Records {@code featureVal} for phrase {@code phraseId} in the inverted index entry of
     * {@code featureId}, creating that entry if needed.
     */
    private void indexFeature(int phraseId, long featureId, float featureVal) {
        invertedIndex.putIfAbsent(featureId, new TIntFloatHashMap());
        TIntFloatMap index = invertedIndex.get(featureId);
        synchronized (index) {
            index.put(phraseId, featureVal);
        }
    }

    /**
     * Adds a particular phrase to the internal SR model.
     * Multiple calls to addPhrase() for the same phrase are safe
     * (the phrase's frequency will be incremented).
     *
     * If the phrase creator cannot build a vector for a new phrase, the phrase is not
     * registered. A later call will retry.
     *
     * @param phrase raw phrase text; it is normalized before being indexed
     * @param id     application-generated id for the phrase
     */
    public void addPhrase(String phrase, final int id) {
        KnownPhrase ifAbsent = new KnownPhrase(id, phrase, normalize(phrase));
        String key = ifAbsent.getNormalizedPhrase();
        KnownPhrase old = byPhrase.putIfAbsent(key, ifAbsent);
        if (old != null) {
            old.increment(phrase);
            db.put(key, new KnownPhrase(old));
            return;
        }
        TLongFloatMap vector = creator.getVector(phrase);
        if (vector == null) {
            // Drop the placeholder so the phrase is not left half-registered (resolvable by
            // getId() but absent from byId and the indexes). Otherwise later calls would take
            // the increment branch above instead of retrying vector creation.
            byPhrase.remove(key, ifAbsent);
            return;
        }
        ifAbsent.setVector(new PhraseVector(vector));
        byId.put(id, ifAbsent);
        db.put(key, new KnownPhrase(ifAbsent));
        vector.forEachEntry(new TLongFloatProcedure() {
            @Override
            public boolean execute(long k, float v) {
                indexFeature(id, k, v);
                return true;
            }
        });
        if (cosim != null) {
            SRResultList neighbors = indexedMostSimilar(ifAbsent.getVector(), byId.size(), null);
            cosim.update(id, neighbors);
        }
    }

    public void rebuild() {
        throw new UnsupportedOperationException();
    }

    /**
     * Trains the normalizer on the existing phrases.
     * The normalizer is (right now) ALWAYS a
     * percentile normalizer to the power of 10.
     *
     * Does nothing (beyond logging a warning) when no phrases are known. On failure, the
     * previous normalizer is restored.
     *
     * @throws IOException if the trained normalizer cannot be written to disk
     */
    public void trainNormalizer() throws IOException {
        List<Integer> ids = new ArrayList<Integer>(byId.keySet());
        if (ids.isEmpty()) {
            LOGGER.warn("No known phrases; not training the score normalizer");
            return;
        }
        Normalizer restored = this.scoreNormalizer;
        // Sample raw (un-normalized) scores by temporarily using the identity normalizer.
        this.scoreNormalizer = new IdentityNormalizer();
        try {
            Random random = new Random();
            PercentileNormalizer newNormalizer = new PercentileNormalizer();
            newNormalizer.setPower(10);
            newNormalizer.setSampleSize(100000);
            for (int i = 0; i < 1000; i++) {
                int id = ids.get(random.nextInt(ids.size()));
                for (SRResult r : mostSimilar(id, ids.size())) {
                    newNormalizer.observe(r.getScore());
                }
            }
            newNormalizer.observationsFinished();
            File f = new File(dir, "scoreNormalizer.bin");
            WpIOUtils.writeObjectToFile(f, newNormalizer);
            restored = newNormalizer;
        } finally {
            this.scoreNormalizer = restored;
        }
    }

    /**
     * Returns the phrase associated with a particular id (or null).
     * @param id phrase id
     * @return the canonical version of the phrase, or null if the id is unknown
     */
    public String getPhrase(int id) {
        KnownPhrase kp = byId.get(id);
        return (kp == null) ? null : kp.getCanonicalVersion();
    }

    /**
     * Returns the id of a phrase, looked up by its normalized form.
     * @param phrase raw phrase text
     * @return the phrase id, or null if the phrase is unknown
     */
    public Integer getId(String phrase) {
        KnownPhrase kp = byPhrase.get(normalize(phrase));
        if (kp == null) {
            return null;
        } else {
            return kp.getId();
        }
    }

    /**
     * Return the normalized (i.e. canonical) string associated with a phrase.
     * @param phrase raw phrase text
     * @return the phrase normalized by the configured string normalizer for this language
     */
    public String normalize(String phrase) {
        return stringNormalizer.normalize(language, phrase);
    }

    @Override
    public SRResult similarity(int id1, int id2, boolean explanations) {
        // NOTE(review): unlike mostSimilar(), this does not apply scoreNormalizer - confirm intended.
        return new SRResult(cosim.similarity(id1, id2));
    }

    /**
     * Similarity between two phrases. Returns NaN if either phrase is unknown.
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        Integer id1 = getId(phrase1);
        Integer id2 = getId(phrase2);
        if (id1 == null || id2 == null) {
            return new SRResult(Double.NaN);
        }
        return similarity(id1, id2, explanations);
    }

    /**
     * Most similar known phrases to {@code phrase}, or null if the phrase is unknown.
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet candidateIds) {
        Integer id = getId(phrase);
        if (id == null) {
            return null;
        } else {
            return mostSimilar(id, maxResults, candidateIds);
        }
    }

    @Override
    public SRResultList mostSimilar(String phrase, int maxResults) {
        return mostSimilar(phrase, maxResults, null);
    }

    /**
     * Most similar known phrases to phrase {@code id}, normalized by the score normalizer.
     * Returns null if the id is unknown.
     *
     * @param candidateIds if non-null, only these phrase ids may appear in the results
     */
    @Override
    public SRResultList mostSimilar(int id, int maxResults, TIntSet candidateIds) {
        KnownPhrase p = byId.get(id);
        if (p == null) {
            return null;
        }
        SRResultList results;
        if (cosim != null) {
            results = cosim.mostSimilar(id, maxResults, candidateIds);
        } else if (candidateIds != null && candidateIds.size() < 10) {
            // Few candidates: compare directly instead of walking the inverted index.
            results = mostSimilar(p.getVector(), maxResults, candidateIds);
        } else {
            results = indexedMostSimilar(p.getVector(), maxResults, candidateIds);
        }
        // Normalize on every path so scores are comparable regardless of whether the cache is used.
        return scoreNormalizer.normalize(results);
    }

    /**
     * Computes raw cosine similarities between {@code v1} and every indexed phrase sharing a
     * feature with it, using the inverted index to accumulate dot products.
     *
     * @param candidateIds if non-null, only these phrase ids are scored
     */
    private SRResultList indexedMostSimilar(PhraseVector v1, int maxResults, final TIntSet candidateIds) {
        final TIntDoubleHashMap dots = new TIntDoubleHashMap(maxResults * 5);
        for (int i = 0; i < v1.ids.length; i++) {
            long featureId = v1.ids[i];
            final float featureVal = v1.vals[i];
            TIntFloatMap index = invertedIndex.get(featureId);
            if (index == null) continue;
            synchronized (index) {
                index.forEachEntry(new TIntFloatProcedure() {
                    @Override
                    public boolean execute(int id, float val) {
                        if (candidateIds == null || candidateIds.contains(id)) {
                            dots.adjustOrPutValue(id, val * featureVal, val * featureVal);
                        }
                        return true;
                    }
                });
            }
        }
        final Leaderboard leaderboard = new Leaderboard(maxResults);
        double l1 = v1.norm2();
        int keys[] = dots.keys();
        for (int i = 0; i < keys.length; i++) {
            int id = keys[i];
            double l2 = byId.get(id).getVector().norm2();
            double dot = dots.get(id);
            double sim = dot / (l1 * l2);
            leaderboard.tallyScore(id, sim);
        }
        return leaderboard.getTop();
    }

    /**
     * Brute-force raw cosine similarity between {@code v1} and each candidate (or every known
     * phrase when {@code candidateIds} is null). Unknown candidate ids are skipped.
     */
    private SRResultList mostSimilar(PhraseVector v1, int maxResults, TIntSet candidateIds) {
        final Leaderboard top = new Leaderboard(maxResults);
        if (candidateIds != null) {
            for (int id2 : candidateIds.toArray()) {
                KnownPhrase p2 = byId.get(id2);
                if (p2 != null) {
                    top.tallyScore(id2, v1.cosineSim(p2.getVector()));
                }
            }
        } else {
            for (KnownPhrase p2 : byId.values()) {
                top.tallyScore(p2.getId(), v1.cosineSim(p2.getVector()));
            }
        }
        return top.getTop();
    }

    @Override
    public SRResultList mostSimilar(int id, int maxResults) {
        return mostSimilar(id, maxResults, null);
    }

    /**
     * Cosimilarity between two lists of phrases.
     *
     * @throws IllegalArgumentException if any phrase is unknown
     */
    @Override
    public double[][] cosimilarity(String rows[], String columns[]) {
        return cosimilarity(phrasesToIds(rows), phrasesToIds(columns));
    }

    /**
     * Resolves each phrase to its id.
     *
     * @throws IllegalArgumentException if any phrase is unknown
     */
    private int[] phrasesToIds(String[] phrases) {
        int[] ids = new int[phrases.length];
        for (int i = 0; i < phrases.length; i++) {
            Integer id = getId(phrases[i]);
            if (id == null) {
                throw new IllegalArgumentException("Unknown phrase: " + phrases[i]);
            }
            ids[i] = id;
        }
        return ids;
    }

    /**
     * Square cosimilarity matrix of the given phrase ids with themselves.
     */
    @Override
    public double[][] cosimilarity(int[] ids) throws DaoException {
        return cosimilarity(ids, ids);
    }

    /**
     * Square cosimilarity matrix of the given phrases with themselves.
     *
     * @throws IllegalArgumentException if any phrase is unknown
     */
    @Override
    public double[][] cosimilarity(String[] phrases) throws DaoException {
        return cosimilarity(phrases, phrases);
    }

    /**
     * Returns {@code cosim.getVector(id)} for the phrase, or null if the phrase is unknown.
     */
    public float[] getPhraseVector(String phrase) {
        Integer id = getId(phrase);
        if (id == null) {
            return null;
        } else {
            return getPhraseVector(id);
        }
    }

    /**
     * Returns the vector the cosimilarity matrix stores for {@code id}.
     * NOTE(review): this comes from CosimilarityMatrix, not the phrase's feature vector - confirm.
     */
    public float[] getPhraseVector(int id) {
        return cosim.getVector(id);
    }

    /**
     * Cosimilarity between phrase ids. When the cosimilarity cache is unavailable, normalized
     * cosine similarities are computed directly; cells for unknown ids are left at 0.0.
     */
    @Override
    public double[][] cosimilarity(int rows[], int columns[]) {
        if (cosim != null) {
            return cosim.cosimilarity(rows, columns);
        }
        double cosims[][] = new double[rows.length][columns.length];
        List<PhraseVector> colVectors = new ArrayList<PhraseVector>(columns.length);
        for (int i = 0; i < columns.length; i++) {
            KnownPhrase kp = byId.get(columns[i]);
            colVectors.add(kp == null ? null : kp.getVector());
        }
        for (int i = 0; i < rows.length; i++) {
            KnownPhrase kp = byId.get(rows[i]);
            if (kp == null) {
                continue; // leave sims as their default value of 0.0
            }
            PhraseVector v1 = kp.getVector();
            for (int j = 0; j < columns.length; j++) {
                PhraseVector v2 = colVectors.get(j);
                if (v2 != null) {
                    cosims[i][j] = scoreNormalizer.normalize(v1.cosineSim(v2));
                }
            }
        }
        return cosims;
    }

    @Override public String getName() { return name; }
    @Override public Language getLanguage() { return language; }
    @Override public File getDataDir() { return dir; }
    @Override public void setDataDir(File dir) { throw new UnsupportedOperationException(); }

    public Normalizer getScoreNormalizer() {
        return scoreNormalizer;
    }

    @Override public Normalizer getMostSimilarNormalizer() { return scoreNormalizer; }
    @Override public void setMostSimilarNormalizer(Normalizer n) { throw new UnsupportedOperationException(); }
    @Override public Normalizer getSimilarityNormalizer() { return scoreNormalizer; }
    @Override public void setSimilarityNormalizer(Normalizer n) { throw new UnsupportedOperationException(); }
    @Override public void trainSimilarity(Dataset dataset) {}
    @Override public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {}
    @Override public boolean similarityIsTrained() { return false; }
    @Override public boolean mostSimilarIsTrained() { return false; }

    /**
     * Builds a KnownPhraseSim for configs of type "knownphrase". It combines the configured
     * sparse-vector metrics with the configured coefficients via an EnsemblePhraseCreator.
     * It requires a "language" runtime parameter.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("knownphrase")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            List<String> names = config.getStringList("metrics");
            SparseVectorSRMetric[] metrics = new SparseVectorSRMetric[names.size()];
            for (int i = 0; i < names.size(); i++) {
                metrics[i] = (SparseVectorSRMetric) getConfigurator().get(
                        SRMetric.class, names.get(i),
                        "language", language.getLangCode());
            }
            PhraseCreator creator = new EnsemblePhraseCreator(
                    metrics,
                    toPrimitive(config.getDoubleList("coefficients")));
            String stringNormalizerName = null;
            if (config.hasPath("stringnormalizer")) {
                stringNormalizerName = config.getString("stringnormalizer");
            }
            StringNormalizer normalizer = getConfigurator().get(StringNormalizer.class, stringNormalizerName);
            File dir = FileUtils.getFile(
                    getConfig().getString("sr.metric.path"),
                    name,
                    language.getLangCode());
            try {
                return new KnownPhraseSim(name, language, creator, dir, normalizer);
            } catch (IOException e) {
                throw new ConfigurationException(e);
            }
        }
    }

    /** Unboxes a list of doubles into a primitive array. */
    private static double[] toPrimitive(List<Double> l) {
        double [] result = new double[l.size()];
        for (int i = 0; i < l.size(); i++) {
            result[i] = l.get(i);
        }
        return result;
    }
}