package org.wikibrain.matrix.knn;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;
import java.io.*;
import java.util.Arrays;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Approximation cache for nearest neighbors that uses a random-projection scheme.
* This implementation keeps two longs (16 bytes) per row in the dense matrix.
* It constructs one random vector for each of the 128 bits in the two longs.
* Each bit in a particular row's two longs indicates the sign (-1 = 0, +1 = 1) of the
* dot product.
*
* To find neighbors, the algorithm counts how many of the 128 bits agree between
* a query and a candidate.
*
* @author Shilad Sen
*/
public class RandomProjectionKNNFinder implements KNNFinder {
    private static final Logger LOG = LoggerFactory.getLogger(RandomProjectionKNNFinder.class);

    /** Number of random projections, and therefore signature bits, kept for every row. */
    public static final int NUM_BITS = 128;

    /** Upper bound on the number of leading rows sampled to estimate per-dimension statistics. */
    private static final int MAX_SAMPLE_ROWS = 5000;

    private final DenseMatrix matrix;
    private final int dimensions;
    private final int[] ids;

    /** Bit signatures, two longs per row: row i occupies bits[2*i] and bits[2*i+1]. */
    private long[] bits;

    /** One random projection vector per signature bit: NUM_BITS x dimensions. */
    private double[][] vectors;

    /** Sampled per-dimension mean and standard deviation used to standardize vectors. */
    private double[] means;
    private double[] devs;

    /**
     * Creates a finder over the given matrix. The model is not usable for queries
     * until {@link #build()} or a successful {@link #load(File)} has been called.
     *
     * @param matrix the dense matrix whose rows will be indexed
     * @throws IOException if the first row cannot be read from the matrix
     * @throws IllegalArgumentException if the matrix has no rows
     */
    public RandomProjectionKNNFinder(DenseMatrix matrix) throws IOException {
        this.matrix = matrix;
        this.ids = matrix.getRowIds();
        if (ids.length == 0) {
            // The number of dimensions is taken from the first row, so one must exist.
            throw new IllegalArgumentException("Cannot create a KNN finder for a matrix with no rows");
        }
        this.dimensions = matrix.getRow(ids[0]).getNumCols();
    }

    /**
     * Generates fresh random projection vectors and computes the bit signature
     * of every row in the matrix.
     *
     * @throws IOException if a row cannot be read from the matrix
     */
    @Override
    public void build() throws IOException {
        makeVectors();
        long[] signatures = new long[ids.length * 2];
        long[] rowBits = new long[2];
        for (int i = 0; i < ids.length; i++) {
            project(matrix.getRow(ids[i]).getValues(), rowBits);
            signatures[2 * i] = rowBits[0];
            signatures[2 * i + 1] = rowBits[1];
        }
        // Publish only once every row has been projected successfully.
        bits = signatures;
    }

    /**
     * Estimates the mean and standard deviation of each dimension from the
     * first rows of the matrix, then draws one Gaussian random vector per bit.
     */
    private void makeVectors() throws IOException {
        int n = Math.min(MAX_SAMPLE_ROWS, ids.length);

        // Sample the mean of each dimension.
        means = new double[dimensions];
        for (int i = 0; i < n; i++) {
            float[] vals = matrix.getRow(ids[i]).getValues();
            if (vals.length != dimensions) {
                throw new IllegalStateException(
                        "Row " + ids[i] + " has " + vals.length + " columns, expected " + dimensions);
            }
            for (int d = 0; d < vals.length; d++) {
                means[d] += vals[d];
            }
        }
        for (int d = 0; d < dimensions; d++) {
            means[d] /= n;
        }

        // Sample the standard deviation of each dimension.
        devs = new double[dimensions];
        Arrays.fill(devs, 0.0001); // avoid divide by zero in normalization procedure
        for (int i = 0; i < n; i++) {
            float[] vals = matrix.getRow(ids[i]).getValues();
            for (int d = 0; d < vals.length; d++) {
                double delta = vals[d] - means[d];
                devs[d] += delta * delta;
            }
        }
        boolean debug = LOG.isDebugEnabled();
        for (int d = 0; d < dimensions; d++) {
            devs[d] = Math.sqrt(devs[d] / n);
            if (debug) {
                // Guarded so the message is not assembled for every dimension when debug is off.
                LOG.debug("dimension " + d + " has mean " + means[d] + " and std-dev " + devs[d]);
            }
        }

        // Draw the random projection vectors.
        Random random = new Random();
        vectors = new double[NUM_BITS][dimensions];
        double[] norms = new double[NUM_BITS];
        for (int d = 0; d < dimensions; d++) {
            for (int i = 0; i < vectors.length; i++) {
                vectors[i][d] = random.nextGaussian() / 2;
                norms[i] += vectors[i][d] * vectors[i][d];
            }
        }
        // NOTE: norms[i] holds the sum of squares (the squared length), so this scales each
        // vector by a positive constant rather than to unit length. Only the sign of each
        // projection is used (see project()), so the scale is not otherwise consulted.
        for (int i = 0; i < vectors.length; i++) {
            for (int d = 0; d < dimensions; d++) {
                vectors[i][d] /= (norms[i] + 0.000001);
            }
        }
    }

    /**
     * Computes the bit signature of a vector. The vector is standardized with the
     * sampled means and deviations; bit i is set when its dot product with random
     * vector i is strictly positive.
     *
     * @param v      the vector to project; must have exactly {@code dimensions} entries
     * @param result receives the signature: projections of the first half of the random
     *               vectors in {@code result[0]}, the second half in {@code result[1]}
     * @throws IllegalArgumentException if {@code v} has the wrong number of dimensions
     */
    private void project(float[] v, long[] result) {
        if (v.length != dimensions) {
            throw new IllegalArgumentException("Expected " + dimensions + " dimensions, found " + v.length);
        }
        double[] standardized = new double[dimensions];
        for (int d = 0; d < dimensions; d++) {
            standardized[d] = (v[d] - means[d]) / devs[d];
        }

        int half = vectors.length / 2;
        long low = 0;
        for (int i = 0; i < half; i++) {
            if (dot(vectors[i], standardized) > 0) {
                low |= (1L << i);
            }
        }
        long high = 0;
        for (int i = half; i < vectors.length; i++) {
            if (dot(vectors[i], standardized) > 0) {
                // Explicit offset into the second long. With 128 vectors this is the same bit
                // that "1L << i" selects, since Java masks a long shift distance to six bits.
                high |= (1L << (i - half));
            }
        }
        result[0] = low;
        result[1] = high;
    }

    /** Returns the dot product of two vectors, iterating over the length of {@code v1}. */
    private static double dot(double[] v1, double[] v2) {
        double sum = 0.0;
        for (int i = 0; i < v1.length; i++) {
            sum += v1[i] * v2[i];
        }
        return sum;
    }

    /** Returns how many of the NUM_BITS signature bits of the given row agree with (p0, p1). */
    private int matchingBits(int row, long p0, long p1) {
        return NUM_BITS - Long.bitCount(bits[2 * row] ^ p0) - Long.bitCount(bits[2 * row + 1] ^ p1);
    }

    /**
     * Finds approximate nearest neighbors of a vector. Rows are first ranked by how many
     * signature bits they share with the query; the agreement threshold is lowered from
     * NUM_BITS until at least {@code maxTraversal} rows qualify (or every row qualifies),
     * and only those rows are scored with the exact cosine similarity.
     *
     * @param vector       the query vector; must have the same dimensionality as the matrix
     * @param k            size passed to the neighborhood accumulator (presumably the maximum
     *                     number of neighbors returned)
     * @param maxTraversal minimum number of candidate rows to score exactly, when available
     * @param validIds     if non-null, only rows whose id is in this set are considered
     * @return the neighborhood collected from the scored candidates
     * @throws IllegalStateException if the model has not been built or loaded, or if a
     *                               candidate row cannot be read from the matrix
     * @throws IllegalArgumentException if {@code vector} has the wrong number of dimensions
     */
    @Override
    public Neighborhood query(float[] vector, int k, int maxTraversal, TIntSet validIds) {
        if (bits == null || vectors == null || means == null || devs == null) {
            throw new IllegalStateException("KNN model is not initialized; call build() or load() first");
        }

        // Hack: Speed up query by reducing number of collisions
        if (validIds != null) {
            TIntSet original = validIds;
            validIds = new TIntHashSet(original.size() * 4);
            validIds.addAll(original);
        }

        long[] queryBits = new long[2];
        project(vector, queryBits);
        long p0 = queryBits[0];
        long p1 = queryBits[1];

        // Pass 1: histogram of how many eligible rows share each number of bits with the query.
        int[] numHits = new int[NUM_BITS + 1];
        for (int i = 0; i < ids.length; i++) {
            if (validIds != null && !validIds.contains(ids[i])) {
                continue;
            }
            numHits[matchingBits(i, p0, p1)]++;
        }

        // Pick the highest threshold that still admits at least maxTraversal rows.
        // If no threshold does, the loop ends at 0 and every eligible row is scored.
        int threshold;
        int count = 0;
        for (threshold = NUM_BITS; threshold > 0; threshold--) {
            count += numHits[threshold];
            if (count >= maxTraversal) {
                break;
            }
        }

        // Pass 2: exact cosine similarity for rows at or above the threshold.
        NeighborhoodAccumulator accum = new NeighborhoodAccumulator(k);
        for (int i = 0; i < ids.length; i++) {
            if (validIds != null && !validIds.contains(ids[i])) {
                continue;
            }
            if (matchingBits(i, p0, p1) >= threshold) {
                try {
                    DenseMatrixRow row = matrix.getRow(ids[i]);
                    double sim = KmeansKNNFinder.cosine(vector, row);
                    accum.visit(ids[i], sim);
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
        return accum.get();
    }

    /**
     * Serializes the projection vectors, row signatures, means and deviations to a file,
     * creating the parent directory if necessary.
     *
     * @param path destination file
     * @throws IOException if the file cannot be written
     */
    @Override
    public void save(File path) throws IOException {
        File parent = path.getParentFile();
        if (parent != null) {
            // getParentFile() is null for a bare file name; there is nothing to create then.
            parent.mkdirs();
        }
        OutputStream fileOut = new FileOutputStream(path);
        try {
            ObjectOutputStream oop = new ObjectOutputStream(fileOut);
            oop.writeObject(new Object[] { vectors, bits, means, devs });
            oop.flush();
        } finally {
            // Release the file handle even when serialization fails.
            fileOut.close();
        }
    }

    /**
     * Loads a model previously written by {@link #save(File)}. The model is rejected
     * (and the current state left untouched) if the file is missing, is older than the
     * matrix file, or does not match the shape of the current matrix.
     *
     * The file is read with Java native deserialization, so it should only be a file
     * this application wrote itself.
     *
     * @param path file to read
     * @return true if the model was loaded, false if it was rejected
     * @throws IOException if the file cannot be read or contains an unknown class
     */
    @Override
    public boolean load(File path) throws IOException {
        if (!path.isFile()) {
            LOG.warn("Not loading knn model. File doesn't exist: " + path);
            return false;
        } else if (path.lastModified() < matrix.getPath().lastModified()) {
            LOG.warn("Not loading knn model. File " + path + " older than matrix: " + matrix.getPath());
            return false;
        }
        InputStream fileIn = new FileInputStream(path);
        try {
            ObjectInputStream in = new ObjectInputStream(fileIn);
            Object[] obj = (Object[]) in.readObject();
            double[][] newVectors = (double[][]) obj[0];
            long[] newBits = (long[]) obj[1];
            double[] newMeans = (double[]) obj[2];
            double[] newDevs = (double[]) obj[3];
            if (newBits.length != ids.length * 2) {
                LOG.warn("Not loading knn model. Expected " + 2*ids.length + " longs, found " + newBits.length);
                return false;
            }
            if (!hasShape(newVectors, NUM_BITS, dimensions)) {
                LOG.warn("Not loading knn model. Invalid vectors dimensions.");
                return false;
            }
            if (newMeans.length != dimensions || newDevs.length != dimensions) {
                LOG.warn("Not loading knn model. Invalid mean or devs dimensions.");
                return false;
            }
            this.vectors = newVectors;
            this.bits = newBits;
            this.means = newMeans;
            this.devs = newDevs;
            return true;
        } catch (ClassNotFoundException e) {
            throw new IOException(e);
        } finally {
            // Always release the file handle: on success, on rejection and on failure.
            fileIn.close();
        }
    }

    /** Returns true if the array has exactly {@code rows} non-null rows of {@code cols} entries. */
    private static boolean hasShape(double[][] m, int rows, int cols) {
        if (m.length != rows) {
            return false;
        }
        for (int i = 0; i < m.length; i++) {
            if (m[i] == null || m[i].length != cols) {
                return false;
            }
        }
        return true;
    }
}