package org.wikibrain.matrix.knn; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.junit.Test; import org.wikibrain.matrix.DenseMatrix; import org.wikibrain.matrix.DenseMatrixRow; import org.wikibrain.matrix.DenseMatrixWriter; import org.wikibrain.matrix.ValueConf; import java.io.File; import java.io.IOException; import java.util.Random; import static junit.framework.Assert.assertEquals; /** * @author Shilad Sen */ public class TestKMeansNNFinder { DenseMatrix matrix; KNNFinder finder; private void makeSmall() throws IOException { matrix = createMatrix(1000, 20); KmeansKNNFinder km = new KmeansKNNFinder(matrix); km.setMaxLeaf(10); km.setSampleSize(200); km.setBranchingFactor(5); km.build(); finder = km; } private void makeBig() throws IOException { matrix = createMatrix(100000, 20); KmeansKNNFinder km = new KmeansKNNFinder(matrix); km.setMaxLeaf(10); km.setSampleSize(10000); km.setBranchingFactor(5); km.build(); finder = km; } @Test public void testBuild() throws IOException { makeSmall(); } @Test public void testQueryCoverage() throws IOException { makeSmall(); int hits = 0; int iters = 10; for (int i = 0; i < iters; i++) { System.out.println("doing " + i); float[] v = randomVector(20); Neighborhood estimated = finder.query(v, 10, 1000, null); Neighborhood actual = actualNeighbors(v, matrix, 10); hits += overlap(estimated, actual); } assertEquals(iters * 10, hits); } @Test public void testQuery() throws IOException { makeBig(); int hits = 0; long elapsedTree = 0; long elapsedBruteForce = 0; int iters = 10; for (int i = 0; i < iters; i++) { System.out.println("doing " + i); float[] v = randomVector(20); long t1 = System.currentTimeMillis(); Neighborhood estimated = finder.query(v, 10, 2000, null); long t2 = System.currentTimeMillis(); Neighborhood actual = actualNeighbors(v, matrix, 10); long t3 = System.currentTimeMillis(); elapsedBruteForce += (t3 - t2); elapsedTree += (t2 - t1); hits += overlap(estimated, actual); } System.out.println("Mean overlap is " + (1.0 * hits / iters)); System.out.println("Mean bruteforce millis is " + (1.0 * elapsedBruteForce / iters)); System.out.println("Mean optimized millis is " + (1.0 * elapsedTree / iters)); } private int overlap(Neighborhood n1, Neighborhood n2) { TIntSet ids = new TIntHashSet(); for (int i = 0; i < n1.size(); i++) { ids.add(n1.getId(i)); } int result = 0; for (int i = 0; i < n2.size(); i++) { if (ids.contains(n2.getId(i))) result++; } return result; } private static Neighborhood actualNeighbors(float [] v, DenseMatrix matrix, int n) { NeighborhoodAccumulator accum = new NeighborhoodAccumulator(n); for (DenseMatrixRow row : matrix) { double sim = KmeansKNNFinder.cosine(v, row); accum.visit(row.getRowIndex(), sim); } return accum.get(); } private static DenseMatrix createMatrix(int rows, int cols) throws IOException { File tmp = File.createTempFile("knnfinder", ".matrix"); tmp.delete(); ValueConf vconf = new ValueConf(); int [] colIds = new int[cols]; for (int i= 0 ; i < cols; i++) { colIds[i] = i; } DenseMatrixWriter writer = new DenseMatrixWriter(tmp, vconf); for (int i = 0; i < rows; i++) { writer.writeRow(new DenseMatrixRow(vconf, i, colIds, randomVector(cols))); } writer.finish(); tmp.deleteOnExit(); return new DenseMatrix(tmp); } private static float[] randomVector(int cols) { Random rand = new Random(); float [] vals = new float[cols]; for (int j = 0; j < cols; j++) { vals[j] = rand.nextFloat(); } return vals; } }