package org.wikibrain.sr.ensemble;
import com.typesafe.config.Config;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.procedure.TIntDoubleProcedure;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.utils.Leaderboard;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
 * A simple linear ensemble of SR metrics that requires no training or configuration of its own.
 *
 * <p>Each sub-metric is paired with a fixed coefficient. Pairwise {@code similarity} returns the
 * coefficient-weighted mean of the sub-metrics' valid scores. {@code cosimilarity} and
 * {@code mostSimilar} rank by the coefficient-weighted sum of the sub-metrics' finite scores.
 *
 * <p>The normalizer and I/O methods are trivial no-op implementations. The training methods do no
 * ensemble-level training; they only delegate to the sub-metrics, and only while
 * {@link #setTrainSubmetrics(boolean)} is left at its default of {@code true}.
 *
 * @author Shilad Sen
 */
public class SimpleEnsembleMetric implements SRMetric {

    /** A sub-metric together with the weight it receives in the ensemble. */
    private static final class SubMetric {
        final SRMetric metric;
        final double coefficient;

        SubMetric(SRMetric metric, double coefficient) {
            this.metric = metric;
            this.coefficient = coefficient;
        }
    }

    /** Accumulates a coefficient-weighted mean over the valid scores of sub-metric results. */
    private static final class WeightedMean {
        private double sum = 0.0;
        private double weight = 0.0;

        /** Adds {@code result}'s score with weight {@code coefficient}; null or invalid results are ignored. */
        void add(SRResult result, double coefficient) {
            if (result != null && result.isValid()) {
                sum += coefficient * result.getScore();
                weight += coefficient;
            }
        }

        /** Returns the weighted mean, or a NaN-scored result when no positive weight was accumulated. */
        SRResult toResult() {
            return new SRResult((weight > 0) ? (sum / weight) : Double.NaN);
        }
    }

    /**
     * Fraction of the candidate pool that each sub-metric is asked to rank when re-scoring phrase
     * queries. Candidates it does not return (the remaining ~20%) share a floor score.
     */
    private static final double PHRASE_RERANK_FRACTION = 0.8;

    private static final Logger LOG = LoggerFactory.getLogger(SimpleEnsembleMetric.class);

    private final String name;
    private final Language language;
    private final SubMetric[] metrics;
    private boolean trainSubmetrics = true;
    private double numCandidateMultiplier = 2.0;

    /**
     * Creates an ensemble over the given sub-metrics.
     *
     * @param name         the name reported by {@link #getName()}
     * @param language     the language reported by {@link #getLanguage()}
     * @param metrics      the sub-metrics; must be non-empty
     * @param coefficients one weight per sub-metric, in the same order as {@code metrics}
     * @throws IllegalArgumentException if the two lists differ in size or are empty
     */
    public SimpleEnsembleMetric(String name, Language language, List<SRMetric> metrics, List<Double> coefficients) {
        if (metrics.size() != coefficients.size()) {
            throw new IllegalArgumentException("Got " + metrics.size() + " metrics but "
                    + coefficients.size() + " coefficients; they must match one-to-one.");
        }
        if (metrics.isEmpty()) {
            throw new IllegalArgumentException("Must supply at least one metric to the simple ensemble.");
        }
        this.metrics = new SubMetric[metrics.size()];
        for (int i = 0; i < metrics.size(); i++) {
            this.metrics[i] = new SubMetric(metrics.get(i), coefficients.get(i));
        }
        this.name = name;
        this.language = language;
    }

    @Override
    public String getName() {
        return name;
    }

    @Override
    public Language getLanguage() {
        return language;
    }

    /**
     * Returns the coefficient-weighted mean of the sub-metrics' valid scores for two pages.
     * The score is NaN if no sub-metric produced a valid result with positive total weight.
     */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        // TODO: Handle explanations
        WeightedMean mean = new WeightedMean();
        for (SubMetric m : metrics) {
            mean.add(m.metric.similarity(pageId1, pageId2, false), m.coefficient);
        }
        return mean.toResult();
    }

    /**
     * Returns the coefficient-weighted mean of the sub-metrics' valid scores for two phrases.
     * The score is NaN if no sub-metric produced a valid result with positive total weight.
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        WeightedMean mean = new WeightedMean();
        for (SubMetric m : metrics) {
            mean.add(m.metric.similarity(phrase1, phrase2, false), m.coefficient);
        }
        return mean.toResult();
    }

    @Override
    public SRResultList mostSimilar(int pageId, int maxResults) throws DaoException {
        return mostSimilar(pageId, maxResults, null);
    }

    /**
     * Pools the top {@code maxResults * numCandidateMultiplier} results of every sub-metric, then
     * ranks the pooled candidates by ensemble {@link #cosimilarity(int[], int[])} with {@code pageId}.
     */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        int numCandidates = (int) (maxResults * numCandidateMultiplier);
        TIntSet candidateSet = new TIntHashSet();
        for (SubMetric m : metrics) {
            SRResultList rl = m.metric.mostSimilar(pageId, numCandidates, validIds);
            if (rl != null) {
                for (SRResult r : rl) {
                    candidateSet.add(r.getId());
                }
            }
        }

        Leaderboard top = new Leaderboard(maxResults);
        if (!candidateSet.isEmpty()) {
            int[] candidates = candidateSet.toArray();
            double[][] cosims = cosimilarity(new int[]{pageId}, candidates);
            for (int i = 0; i < candidates.length; i++) {
                top.tallyScore(candidates[i], cosims[0][i]);
            }
        }
        return top.getTop();
    }

    @Override
    public SRResultList mostSimilar(String phrase, int maxResults) throws DaoException {
        return mostSimilar(phrase, maxResults, null);
    }

    /**
     * Pools the top {@code maxResults * numCandidateMultiplier} results of every sub-metric for
     * {@code phrase}, then re-scores every pooled candidate with each sub-metric and ranks the
     * candidates by the coefficient-weighted sum of those scores.
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        int numCandidates = (int) (maxResults * numCandidateMultiplier);
        TIntSet candidateSet = new TIntHashSet();
        for (SubMetric m : metrics) {
            SRResultList rl = m.metric.mostSimilar(phrase, numCandidates, validIds);
            if (rl != null) {
                for (SRResult r : rl) {
                    candidateSet.add(r.getId());
                }
            }
        }

        final Leaderboard top = new Leaderboard(maxResults);
        if (candidateSet.isEmpty()) {
            return top.getTop();
        }
        int[] candidates = candidateSet.toArray();

        // Hack: because there's no way to compare a phrase query and articles,
        // we need to re-ask mostSimilar with the specified candidate list.
        TIntDoubleMap scores = new TIntDoubleHashMap();
        int numToRank = (int) Math.ceil(candidates.length * PHRASE_RERANK_FRACTION);
        for (SubMetric m : metrics) {
            SRResultList rl = m.metric.mostSimilar(phrase, numToRank, candidateSet);
            if (rl == null || rl.numDocs() == 0) {
                continue;
            }
            TIntFloatMap subscores = rl.asTroveMap();
            // Hack: every candidate this sub-metric did not rank (the bottom ~20%), or scored
            // non-finitely, gets the same score just below the lowest score it returned.
            double floor = scoreJustBelow(rl.getScore(rl.numDocs() - 1));
            // Iterate over ALL candidates (not just the returned ids) so the floor actually applies.
            for (int id : candidates) {
                double s = floor;
                if (subscores.containsKey(id)) {
                    s = subscores.get(id);
                    if (!isFinite(s)) {
                        s = floor;
                    }
                }
                s *= m.coefficient;
                scores.adjustOrPutValue(id, s, s);
            }
        }

        scores.forEachEntry(new TIntDoubleProcedure() {
            @Override
            public boolean execute(int id, double score) {
                top.tallyScore(id, score);
                return true;
            }
        });
        return top.getTop();
    }

    /**
     * Returns the coefficient-weighted sum of the sub-metrics' cosimilarity matrices.
     * Non-finite sub-metric entries contribute nothing.
     */
    @Override
    public double[][] cosimilarity(int[] wpRowIds, int[] wpColIds) throws DaoException {
        double[][] result = new double[wpRowIds.length][wpColIds.length];
        for (SubMetric m : metrics) {
            addWeighted(result, m.metric.cosimilarity(wpRowIds, wpColIds), m.coefficient);
        }
        return result;
    }

    /**
     * Returns the coefficient-weighted sum of the sub-metrics' cosimilarity matrices.
     * Non-finite sub-metric entries contribute nothing.
     */
    @Override
    public double[][] cosimilarity(String[] rowPhrases, String[] colPhrases) throws DaoException {
        double[][] result = new double[rowPhrases.length][colPhrases.length];
        for (SubMetric m : metrics) {
            addWeighted(result, m.metric.cosimilarity(rowPhrases, colPhrases), m.coefficient);
        }
        return result;
    }

    @Override
    public double[][] cosimilarity(int[] ids) throws DaoException {
        return cosimilarity(ids, ids);
    }

    @Override
    public double[][] cosimilarity(String[] phrases) throws DaoException {
        return cosimilarity(phrases, phrases);
    }

    /**
     * Adds {@code coefficient * src[i][j]} into {@code dest[i][j]} for every finite source entry.
     * {@code src} is assumed to have at least {@code dest}'s dimensions. A null {@code src} is
     * treated like a matrix of non-finite values and contributes nothing.
     */
    private static void addWeighted(double[][] dest, double[][] src, double coefficient) {
        if (src == null) {
            return;
        }
        for (int i = 0; i < dest.length; i++) {
            for (int j = 0; j < dest[i].length; j++) {
                double s = src[i][j];
                if (isFinite(s)) {
                    dest[i][j] += s * coefficient;
                }
            }
        }
    }

    /** True when {@code s} is neither NaN nor infinite. */
    private static boolean isFinite(double s) {
        return !Double.isNaN(s) && !Double.isInfinite(s);
    }

    /**
     * Returns a value slightly below {@code score}: 1% of its magnitude lower. For non-negative
     * scores this equals {@code score * 0.99}; unlike that product, it stays below negative scores.
     */
    private static double scoreJustBelow(double score) {
        return score - 0.01 * Math.abs(score);
    }

    /** Controls whether the training methods delegate to (and check) the sub-metrics. */
    public void setTrainSubmetrics(boolean trainSubmetrics) {
        this.trainSubmetrics = trainSubmetrics;
    }

    @Override
    public Normalizer getMostSimilarNormalizer() { return null; }

    @Override
    public void setMostSimilarNormalizer(Normalizer n) {}

    @Override
    public Normalizer getSimilarityNormalizer() { return null; }

    @Override
    public void setSimilarityNormalizer(Normalizer n) {}

    @Override
    public File getDataDir() { return null; }

    @Override
    public void setDataDir(File dir) {}

    @Override
    public void write() throws IOException {}

    @Override
    public void read() throws IOException {}

    /** Trains each sub-metric's similarity on {@code dataset} when sub-metric training is enabled. */
    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        if (trainSubmetrics) {
            for (SubMetric m : metrics) {
                m.metric.trainSimilarity(dataset);
            }
        }
    }

    /** Trains each sub-metric's mostSimilar on {@code dataset} when sub-metric training is enabled. */
    @Override
    public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        if (trainSubmetrics) {
            for (SubMetric m : metrics) {
                m.metric.trainMostSimilar(dataset, numResults, validIds);
            }
        }
    }

    /** True unless sub-metric training is enabled and some sub-metric's similarity is untrained. */
    @Override
    public boolean similarityIsTrained() {
        if (trainSubmetrics) {
            for (SubMetric m : metrics) {
                if (!m.metric.similarityIsTrained()) {
                    return false;
                }
            }
        }
        return true;
    }

    /** True unless sub-metric training is enabled and some sub-metric's mostSimilar is untrained. */
    @Override
    public boolean mostSimilarIsTrained() {
        if (trainSubmetrics) {
            for (SubMetric m : metrics) {
                if (!m.metric.mostSimilarIsTrained()) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Sets how many candidates each sub-metric contributes to the pool in {@code mostSimilar},
     * as a multiple of the requested {@code maxResults} (default 2.0).
     */
    public void setNumCandidateMultiplier(double numCandidateMultiplier) {
        this.numCandidateMultiplier = numCandidateMultiplier;
    }

    /** Builds {@link SimpleEnsembleMetric}s from configuration entries of type "simple-ensemble". */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Creates the ensemble described by {@code config}. The entry must list base metric names
         * under "metrics" and one weight per metric under "coefficients". Base metrics that fail to
         * load are logged and skipped, together with their coefficient.
         *
         * @return the ensemble, or null if {@code config} is not of type "simple-ensemble"
         * @throws IllegalArgumentException if the "language" runtime parameter is missing
         * @throws ConfigurationException   if "metrics" or "coefficients" is missing, or if fewer
         *                                  coefficients than metrics are configured
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("simple-ensemble")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            if (!config.hasPath("metrics")) {
                throw new ConfigurationException("Ensemble metric has no base metrics to use.");
            }
            if (!config.hasPath("coefficients")) {
                throw new ConfigurationException("Ensemble metric " + name + " has no coefficients.");
            }
            List<String> metricNames = config.getStringList("metrics");
            List<Double> allCoefficients = config.getDoubleList("coefficients");
            if (allCoefficients.size() < metricNames.size()) {
                throw new ConfigurationException("Ensemble metric " + name + " lists " + metricNames.size()
                        + " metrics but only " + allCoefficients.size() + " coefficients.");
            }
            if (allCoefficients.size() > metricNames.size()) {
                LOG.warn("Ensemble metric " + name + " lists more coefficients (" + allCoefficients.size()
                        + ") than metrics (" + metricNames.size() + "); ignoring the extras.");
            }

            List<SRMetric> metrics = new ArrayList<SRMetric>();
            List<Double> activeCoefficients = new ArrayList<Double>();
            for (int i = 0; i < metricNames.size(); i++) {
                String metricName = metricNames.get(i);
                // Read the coefficient before loading, so a metric and its weight are always added together.
                Double coefficient = allCoefficients.get(i);
                try {
                    SRMetric metric = getConfigurator().get(SRMetric.class, metricName,
                            "language", language.getLangCode());
                    metrics.add(metric);
                    activeCoefficients.add(coefficient);
                } catch (Exception e) {
                    LOG.error("Loading of metric " + metricName + " failed. Skipping it! Error:", e);
                }
            }
            return new SimpleEnsembleMetric(name, language, metrics, activeCoefficients);
        }
    }
}