package org.wikibrain.matrix.knn; import gnu.trove.set.TIntSet; import org.wikibrain.matrix.DenseMatrix; import org.wikibrain.matrix.DenseMatrixRow; import java.io.File; import java.io.IOException; import java.util.Arrays; import java.util.Random; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author Shilad Sen */ public class LSHForestKNNFinder implements KNNFinder { private static final Logger LOG = LoggerFactory.getLogger(LSHForestKNNFinder.class); private static final int NUM_BITS = 16; // Number of bits in each mask (size of short) private int numTrees = 5; private short [][] bits; private final DenseMatrix matrix; private final int dimensions; private final int [] ids; private double [][][] vectors; // Sampled mean and standard deviation private double [] means; private double[] devs; public LSHForestKNNFinder(DenseMatrix matrix) throws IOException { this.matrix = matrix; this.ids = matrix.getRowIds(); this.dimensions = matrix.getRow(ids[0]).getNumCols(); } @Override public synchronized void build() throws IOException { analyzeSample(); bits = new short[numTrees][]; vectors = new double[numTrees][][]; for (int i = 0; i < numTrees; i++) { buildTree(i); } } private void analyzeSample() throws IOException { // Sample the mean of each dimension means = new double[dimensions]; int n = Math.min(5000, ids.length); for (int i = 0; i < n; i++) { float vals[] = matrix.getRow(ids[i]).getValues(); if (vals.length != dimensions) throw new IllegalStateException(); for (int d = 0; d < vals.length; d++) { means[d] += vals[d]; } } for (int j= 0; j < dimensions; j++) { means[j] /= n; } // Sample the standard deviation of each dimension devs = new double[dimensions]; Arrays.fill(devs, 0.0001); // avoid divide by zero in normalization procedure for (int i = 0; i < n; i++) { float vals[] = matrix.getRow(ids[i]).getValues(); for (int d = 0; d < vals.length; d++) { devs[d] += (vals[d] - means[d]) * (vals[d] - means[d]); } } for (int d= 0; d < dimensions; d++) { devs[d] = Math.sqrt(devs[d] / n); LOG.info("dimension " + d + " has mean " + means[d] + " and std-dev " + devs[d]); } } private void buildTree(int treeNum) throws IOException { double [][] V = new double[NUM_BITS][dimensions]; vectors[treeNum] = V; // Make vectors for this tree Random random = new Random(); double norms[] = new double[NUM_BITS]; for (int d = 0; d < dimensions; d++) { for (int i = 0; i < V.length; i++) { V[i][d] = random.nextGaussian() / 2; norms[i] += V[i][d] * V[i][d]; } } for (int i = 0; i < V.length; i++) { for (int d = 0; d < dimensions; d++) { V[i][d] /= (norms[i] + 0.000001); } } short [] B = new short[ids.length]; bits[treeNum] = B; for (int i = 0; i < ids.length; i++) { B[i] = project(treeNum, matrix.getRow(ids[i]).getValues()); } } private short project(int treeNum, float [] v) { if (v.length != dimensions) { throw new IllegalArgumentException("Expected " + dimensions + " dimensions, found " + v.length); } double [][] V = vectors[treeNum]; double[] v2 = new double[dimensions]; for (int d = 0; d < dimensions; d++) { v2[d] = (v[d] - means[d]) / devs[d]; } short result = 0; for (int i = 0; i < V.length; i++) { double s = dot(V[i], v2); if (s > 0) { result |= (((short)1) << i); } } return result; } private double dot(double [] v1, double [] v2) { double sum = 0.0; for (int i = 0; i < v1.length; i++) { sum += v1[i] * v2[i]; } return sum; } @Override public Neighborhood query(float[] vector, int k, int maxTraversal, TIntSet validIds) { short [] P = new short[numTrees]; // projections for (int i = 0; i < numTrees; i++) { P[i] = project(i, vector); } Random rand = new Random(); byte [] idMatchLens = new byte[ids.length]; Arrays.fill(idMatchLens, (byte) -1); int hist[] = new int[NUM_BITS+1]; for (int i = 0; i < ids.length; i++) { if (validIds != null && !validIds.contains(ids[i])) continue; int maxMatch = -1; // max bitwise prefix match for (int t = 0; t < numTrees; t++) { int m = Integer.numberOfLeadingZeros((P[t] ^ bits[t][i]) & 0xffff) - (32 - NUM_BITS); // System.out.format("Leading zeros for %d and %d with %s and %s with XOR %s is %d\n", // P[t], bits[t][i], // paddedShortBinary(P[t]), // paddedShortBinary(bits[t][i]), // paddedShortBinary((P[t] ^ bits[t][i])), m); maxMatch = Math.max(m, maxMatch); } // maxMatch = rand.nextInt(17); if (maxMatch < 0 || maxMatch > Byte.MAX_VALUE) throw new IllegalStateException(); idMatchLens[i] = (byte) maxMatch; hist[maxMatch] += 1; } // Pick the threshold we need to consider. int threshold; int count = 0; for (threshold = NUM_BITS; threshold > 0; threshold --) { count += hist[threshold]; if (count >= maxTraversal) { break; } } System.out.println("threshold is " + threshold + " for " + Arrays.toString(hist)); // Find all candidates within the threshold NeighborhoodAccumulator accum = new NeighborhoodAccumulator(k); for (int i = 0; i < ids.length; i++) { if (idMatchLens[i] >= threshold) { try { DenseMatrixRow row = matrix.getRow(ids[i]); double sim = KmeansKNNFinder.cosine(vector, row); accum.visit(ids[i], sim); } catch (IOException e) { throw new IllegalStateException(e); } } } return accum.get(); } @Override public void save(File path) throws IOException { throw new UnsupportedOperationException(); } @Override public boolean load(File path) throws IOException { throw new UnsupportedOperationException(); } private String paddedShortBinary(int s) { return String.format("%16s", Integer.toBinaryString(s & 0xffff)).replace(' ', '0'); } }