package org.wikibrain.sr.ensemble;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
*@author Matt Lesicko
*/
public class LinearEnsemble implements Ensemble{
private static final Logger LOG = LoggerFactory.getLogger(LinearEnsemble.class);
final int numMetrics;
private final int numTrainingCandidateArticles;
TDoubleArrayList simlarityCoefficients;
TDoubleArrayList mostSimilarCoefficients;
Interpolator similarityInterpolator;
Interpolator mostSimilarInterpolator;
public LinearEnsemble(int numMetrics, int numTrainingCandidateArticles){
this.numTrainingCandidateArticles = numTrainingCandidateArticles;
this.numMetrics = numMetrics;
simlarityCoefficients = new TDoubleArrayList();
simlarityCoefficients.add(0.0);
for (int i=0; i<numMetrics; i++){
simlarityCoefficients.add(1.0 / numMetrics);
}
mostSimilarCoefficients = new TDoubleArrayList();
mostSimilarCoefficients.add(0.0);
for (int i=0; i<numMetrics; i++){
mostSimilarCoefficients.add(1.0/numMetrics);
mostSimilarCoefficients.add(0);
}
similarityInterpolator = new Interpolator(numMetrics);
mostSimilarInterpolator = new Interpolator(numMetrics);
}
public String getName(){
return "LinearEnsemble";
}
@Override
public void trainSimilarity(List<EnsembleSim> simList) {
if (simList.isEmpty()) {
throw new IllegalArgumentException("no examples to train on!");
}
similarityInterpolator.trainSimilarity(simList);
double[][] X = new double[simList.size()][numMetrics];
double[] Y = new double[simList.size()];
for (int i = 0; i<simList.size(); i++){
Y[i]=simList.get(i).knownSim.similarity;
EnsembleSim es = similarityInterpolator.interpolate(simList.get(i));
for (int j=0; j<numMetrics; j++){
X[i][j]=es.getScores().get(j);
}
}
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.newSampleData(Y, X);
simlarityCoefficients = new TDoubleArrayList(regression.estimateRegressionParameters());
double pearson = Math.sqrt(regression.calculateRSquared());
LOG.info("coefficients are " + simlarityCoefficients.toString());
LOG.info("pearson for multiple regression is " + pearson);
}
@Override
public void trainMostSimilar(List<EnsembleSim> simList) {
if (simList.isEmpty()){
throw new IllegalStateException("no examples to train on!");
}
mostSimilarInterpolator.trainMostSimilar(simList);
// Remove things that have no observed metrics
List<EnsembleSim> pruned = new ArrayList<EnsembleSim>();
for (EnsembleSim es : simList) {
if (es != null && es.getNumMetricsWithScore() > 0) {
pruned.add(es);
}
}
double[][] X = new double[pruned.size()][numMetrics*2];
double[] Y = new double[pruned.size()];
for (int i=0; i<pruned.size(); i++){
Y[i]=pruned.get(i).knownSim.similarity;
EnsembleSim es = mostSimilarInterpolator.interpolate(pruned.get(i));
for (int j=0; j<numMetrics; j++){
X[i][2*j]= es.getScores().get(j);
X[i][2*j+1]= Math.log(es.getRanks().get(j)+1);
}
}
OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
regression.newSampleData(Y,X);
mostSimilarCoefficients = new TDoubleArrayList(regression.estimateRegressionParameters());
double pearson = Math.sqrt(regression.calculateRSquared());
LOG.info("coefficients are "+mostSimilarCoefficients.toString());
LOG.info("pearson for multiple regression is "+pearson);
}
@Override
public SRResult predictSimilarity(List<SRResult> scores) {
if (scores.size()+1!= simlarityCoefficients.size()){
throw new IllegalStateException();
}
double weightedScore = simlarityCoefficients.get(0);
for (int i=0; i<scores.size(); i++){
double s = scores.get(i) == null ? Double.NaN : scores.get(i).getScore();
if (Double.isNaN(s) || Double.isInfinite(s)) {
s = similarityInterpolator.getInterpolatedScore(i);
}
weightedScore+=(s * simlarityCoefficients.get(i+1));
}
return new SRResult(weightedScore);
}
public static boolean debug = false;
@Override
public SRResultList predictMostSimilar(List<SRResultList> scores, int maxResults, TIntSet validIds) {
if (2*scores.size()+1!= mostSimilarCoefficients.size()){
throw new IllegalStateException();
}
TIntSet allIds = new TIntHashSet(); // ids returned by at least one metric
for (SRResultList resultList : scores){
if (resultList != null) {
for (SRResult result : resultList){
allIds.add(result.getId());
}
}
}
TIntDoubleHashMap scoreMap = new TIntDoubleHashMap();
for (int id : allIds.toArray()) {
scoreMap.put(id, mostSimilarCoefficients.get(0));
}
int i =1;
for (SRResultList resultList : scores){
TIntSet unknownIds = new TIntHashSet(allIds);
double c1 = mostSimilarCoefficients.get(i); // score coeff
double c2 = mostSimilarCoefficients.get(i+1); // rank coefficient
if (resultList != null) {
for (int j = 0; j < resultList.numDocs(); j++) {
int rank = j + 1;
// expand or contract ranks proportionately
if (validIds != null) {
double k = 1.0 * numTrainingCandidateArticles / validIds.size();
rank = (int) (rank * k);
}
SRResult result = resultList.get(j);
unknownIds.remove(result.getId());
double value = c1 * result.getScore() + c2 * Math.log(rank);
if (debug) {
System.err.format("%s %d. %.3f (id=%d), computing %.3f * %.3f + %.3f * (log(%d) = %.3f)\n",
"m" + i, j, value, result.getId(),
c1, result.getScore(), c2, rank, Math.log(rank));
}
scoreMap.adjustValue(result.getId(), value);
}
}
// interpolate scores for unknown ids
double value = c1 * mostSimilarInterpolator.getInterpolatedScore(i/2)
+ c2 * Math.log(mostSimilarInterpolator.getInterpolatedRank(i/2));
for (int id : unknownIds.toArray()) {
scoreMap.adjustValue(id, value);
}
i+=2;
}
List<SRResult> resultList = new ArrayList<SRResult>();
for (int id : scoreMap.keys()){
resultList.add(new SRResult(id,scoreMap.get(id)));
}
Collections.sort(resultList);
Collections.reverse(resultList);
int size = maxResults>resultList.size()? resultList.size() : maxResults;
SRResultList result = new SRResultList(size);
for (i=0; i<size;i++){
result.set(i,resultList.get(i));
}
return result;
}
@Override
public void read(String path) throws IOException {
File dir = FileUtils.getFile(path, "ensemble", getName());
if (!dir.isDirectory()) {
return;
}
try {
ObjectInputStream oip = new ObjectInputStream(
new FileInputStream(new File(dir, "similarityCoefficients")));
this.simlarityCoefficients = (TDoubleArrayList)oip.readObject();
oip.close();
oip = new ObjectInputStream(
new FileInputStream(new File(dir, "mostSimilarCoefficients")));
this.mostSimilarCoefficients = (TDoubleArrayList)oip.readObject();
oip.close();
oip = new ObjectInputStream(
new FileInputStream(new File(dir, "similarityInterpolator")));
this.similarityInterpolator = (Interpolator) oip.readObject();
oip.close();
oip = new ObjectInputStream(
new FileInputStream(new File(dir, "mostSimilarInterpolator")));
this.mostSimilarInterpolator = (Interpolator) oip.readObject();
oip.close();
} catch (ClassNotFoundException e){
throw new IOException("Malformed coefficient file(s)",e);
}
}
@Override
public void write(String path) throws IOException{
File dir = FileUtils.getFile(path, "ensemble", getName());
if (!dir.isDirectory()) {
dir.mkdirs();
}
ObjectOutputStream oop = new ObjectOutputStream(
new FileOutputStream(new File(dir, "similarityCoefficients")));
oop.writeObject(simlarityCoefficients);
oop.flush();
oop.close();
oop = new ObjectOutputStream(
new FileOutputStream(new File(dir, "mostSimilarCoefficients")));
oop.writeObject(mostSimilarCoefficients);
oop.flush();
oop.close();
oop = new ObjectOutputStream(
new FileOutputStream(new File(dir, "similarityInterpolator")));
oop.writeObject(similarityInterpolator);
oop.flush();
oop.close();
oop = new ObjectOutputStream(
new FileOutputStream(new File(dir, "mostSimilarInterpolator")));
oop.writeObject(mostSimilarInterpolator);
oop.flush();
oop.close();
}
}