package org.wikibrain.sr.vector;
import gnu.trove.set.TIntSet;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.knn.KDTreeKNN;
import org.wikibrain.matrix.knn.KNNFinder;
import org.wikibrain.matrix.knn.KmeansKNNFinder;
import org.wikibrain.matrix.knn.RandomProjectionKNNFinder;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResultList;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
/**
 * Command-line benchmark that compares accelerated nearest-neighbor lookups
 * against the unaccelerated lookup of a {@link DenseVectorSRMetric}.
 *
 * <p>For each {@link KNNFinder} under test, rows are sampled at random from the
 * metric's feature matrix, {@code mostSimilar} is run once with the finder
 * installed as the accelerator and once with no accelerator, and the overlap
 * between the two result lists and the time taken by each are printed to
 * standard output.
 *
 * @author Shilad Sen
 */
public class CompareDenseKnnAccelerators {
    /** Accelerator multipliers swept by {@link #evaluate(KNNFinder)}. */
    private static final int[] MULTIPLIERS = {1, 5, 10, 20, 50, 100, 1000};

    /** Result-list sizes (k) swept by {@link #evaluate(KNNFinder)}. */
    private static final int[] NEIGHBOR_COUNTS = {1, 10, 20, 50, 100, 1000};

    private final DenseVectorSRMetric sr;
    private final Env env;
    private final DenseMatrix matrix;

    /** Number of randomly chosen rows queried per (k, multiplier) combination. */
    private int sampleSize = 10;

    /**
     * Looks up the SR metric configured under the name "word2vec" for the
     * environment's default language and keeps a handle on its feature matrix.
     *
     * @param env the environment supplying the configurator and default language
     * @throws ConfigurationException if the metric cannot be obtained from the configurator
     */
    public CompareDenseKnnAccelerators(Env env) throws ConfigurationException {
        this.env = env;
        this.sr = (DenseVectorSRMetric) env.getConfigurator().get(
                SRMetric.class, "word2vec",
                "language", env.getDefaultLanguage().getLangCode());
        this.matrix = sr.getGenerator().getFeatureMatrix();
    }

    /**
     * Builds a {@link RandomProjectionKNNFinder} over the feature matrix and evaluates it.
     *
     * @throws IOException if building the finder or running a query fails
     */
    public void evaluateRandomProjections() throws IOException {
        RandomProjectionKNNFinder finder = new RandomProjectionKNNFinder(matrix);
        finder.build();
        evaluate(finder);
    }

    /**
     * Builds a {@link KmeansKNNFinder} over the feature matrix and evaluates it.
     *
     * @throws IOException if building the finder or running a query fails
     */
    public void evaluateKMeansTree() throws IOException {
        KmeansKNNFinder finder = new KmeansKNNFinder(matrix);
        finder.build();
        evaluate(finder);
    }

    /**
     * Builds a {@link KDTreeKNN} over the feature matrix and evaluates it.
     *
     * @throws IOException if building the finder or running a query fails
     */
    public void evaluateKDTree() throws IOException {
        KDTreeKNN finder = new KDTreeKNN(matrix);
        finder.build();
        evaluate(finder);
    }

    /**
     * Evaluates the finder for every combination of the built-in multiplier
     * and k values, printing one result line per combination.
     *
     * @param finder an already-built finder to install as the metric's accelerator
     * @throws IOException if a query fails
     */
    public void evaluate(KNNFinder finder) throws IOException {
        for (int multiplier : MULTIPLIERS) {
            for (int k : NEIGHBOR_COUNTS) {
                evaluate(finder, k, multiplier);
            }
        }
    }

    /**
     * Runs {@code sampleSize} random queries with and without the accelerator
     * and prints a one-line summary to standard output.
     *
     * <p>The printed "Precision" is the number of ids shared by the accelerated
     * and unaccelerated result lists divided by the total size of the
     * unaccelerated lists. Samples for which either lookup returns {@code null}
     * are excluded from both the overlap counts and the timing averages.
     *
     * <p>The metric's accelerator is left unset ({@code null}) when this method
     * returns normally.
     *
     * @param finder     finder to install as the accelerator for the accelerated run
     * @param k          number of results requested from each lookup
     * @param multiplier value passed to the metric's {@code setAcceleratorMultiplier}
     * @throws IOException if reading a row or running a query fails
     */
    public void evaluate(KNNFinder finder, int k, int multiplier) throws IOException {
        sr.setAcceleratorMultiplier(multiplier);
        Random rand = new Random();

        long acceleratedNanos = 0;
        long naiveNanos = 0;
        int total = 0;
        int hits = 0;
        int counted = 0;    // samples that produced both result lists

        for (int i = 0; i < sampleSize; i++) {
            int id = matrix.getRowIds()[rand.nextInt(matrix.getNumRows())];
            float[] vec = matrix.getRow(id).getValues();

            // nanoTime is monotonic and fine-grained; currentTimeMillis can
            // report 0 elapsed for fast queries and is subject to clock changes.
            sr.setAccelerator(finder);
            long acceleratedStart = System.nanoTime();
            SRResultList estimated = sr.mostSimilar(vec, k, null);
            long acceleratedEnd = System.nanoTime();

            sr.setAccelerator(null);
            long naiveStart = System.nanoTime();
            SRResultList actual = sr.mostSimilar(vec, k, null);
            long naiveEnd = System.nanoTime();

            if (actual == null || estimated == null) {
                continue;
            }

            TIntSet overlap = actual.asTroveMap().keySet();
            overlap.retainAll(estimated.asTroveMap().keySet());
            total += actual.numDocs();
            hits += overlap.size();
            acceleratedNanos += (acceleratedEnd - acceleratedStart);
            naiveNanos += (naiveEnd - naiveStart);
            counted++;
        }

        if (counted == 0) {
            // Nothing to average; avoid printing a row of NaN values.
            System.out.format(
                    "Results for k=%d with multiplier=%d: no usable samples\n",
                    k, multiplier);
            return;
        }

        final double nanosPerMilli = 1000000.0;
        System.out.format(
                "Results for k=%d with multiplier=%d: Precision %.3f, accel millis=%.3f naive millis=%.3f ratio=%.3f\n",
                k, multiplier, 1.0 * hits / total,
                acceleratedNanos / nanosPerMilli / counted,
                naiveNanos / nanosPerMilli / counted,
                1.0 * naiveNanos / acceleratedNanos
        );
    }

    /**
     * Builds the environment from the command-line arguments and runs the
     * KD-tree and random-projection evaluations.
     *
     * @param args command-line arguments passed to {@link EnvBuilder#envFromArgs}
     * @throws ConfigurationException if the environment or metric cannot be configured
     * @throws IOException if building a finder or running a query fails
     */
    public static void main(String[] args) throws ConfigurationException, IOException {
        Env env = EnvBuilder.envFromArgs(args);
        CompareDenseKnnAccelerators cmp = new CompareDenseKnnAccelerators(env);
        cmp.evaluateKDTree();
        cmp.evaluateRandomProjections();
        // The k-means tree evaluation is currently disabled:
        // cmp.evaluateKMeansTree();
    }
}