package org.wikibrain.sr.phrasesim;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TLongFloatMap;
import gnu.trove.map.hash.TLongFloatHashMap;
import gnu.trove.procedure.TIntFloatProcedure;
import org.wikibrain.sr.vector.SparseVectorSRMetric;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* TODO: make the individual metrics, and the linear combination of them, trainable.
*
* @author Shilad Sen
*/
public class EnsemblePhraseCreator implements PhraseCreator {
private static final Logger LOGGER = LoggerFactory.getLogger(EnsemblePhraseCreator.class);
private final SparseVectorSRMetric[] metrics;
private double coefficients[];
public EnsemblePhraseCreator(SparseVectorSRMetric metrics[], double coefficients[]) {
if (coefficients.length != metrics.length) {
throw new IllegalArgumentException();
}
this.metrics = metrics;
this.coefficients = coefficients;
}
@Override
public TLongFloatMap getVector(String phrase) {
final TLongFloatMap vector = new TLongFloatHashMap();
for (int i = 0; i < metrics.length; i++) {
try {
TIntFloatMap v = PhraseUtils.getPhraseVector(metrics[i], phrase);
if (v == null) continue;
final int finalI = i;
v.forEachEntry(new TIntFloatProcedure() {
@Override
public boolean execute(int key, float value) {
long id = ((long) finalI) * Integer.MAX_VALUE + key;
vector.put(id, (float) (value * coefficients[finalI]));
return true;
}
});
} catch (Exception e) {
LOGGER.info("SR call for " + phrase + ", metric " + metrics[i].getName() + " failed", e);
}
}
return vector;
}
}