package org.wikibrain.matrix.knn;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;
import java.io.*;
import java.util.Arrays;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Brute force implementation of knn for a dense matrix.
 *
 * Every candidate row is scored against the query vector with
 * {@link KmeansKNNFinder#cosine}; no index structure is built, so
 * {@link #build()}, {@link #save(File)} and {@link #load(File)} do nothing.
 *
 * @author Shilad Sen
 */
public class BruteForceKNNFinder implements KNNFinder {
    private static final Logger LOG = LoggerFactory.getLogger(BruteForceKNNFinder.class);

    /** Matrix whose rows are the candidate neighbors. */
    private final DenseMatrix matrix;

    /**
     * Creates a finder that scans the given matrix on every query.
     *
     * @param matrix the dense matrix to search
     * @throws IOException declared for callers; this constructor only stores the reference
     */
    public BruteForceKNNFinder(DenseMatrix matrix) throws IOException {
        this.matrix = matrix;
    }

    /**
     * Scores candidate rows against {@code vector} and returns the neighborhood
     * produced by a {@link NeighborhoodAccumulator} created with {@code k}.
     *
     * @param vector       the query vector
     * @param k            passed to the {@link NeighborhoodAccumulator}
     * @param maxTraversal ignored by this implementation
     * @param validIds     if non-null, only these row ids are scored; if null,
     *                     every row of the matrix is scored
     * @return the accumulator's resulting neighborhood
     * @throws RuntimeException wrapping an {@link IOException} raised while
     *                          fetching a row for an id in {@code validIds}
     */
    @Override
    public Neighborhood query(final float[] vector, int k, int maxTraversal, TIntSet validIds) {
        final NeighborhoodAccumulator accum = new NeighborhoodAccumulator(k);
        if (validIds == null) {
            for (DenseMatrixRow row : matrix) {
                // Use the same similarity as the filtered branch below so that
                // results do not depend on whether a candidate filter is given.
                double sim = KmeansKNNFinder.cosine(vector, row);
                accum.visit(row.getRowIndex(), sim);
            }
        } else {
            validIds.forEach(new TIntProcedure() {
                @Override
                public boolean execute(int id) {
                    DenseMatrixRow row;
                    try {
                        row = matrix.getRow(id);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                    // NOTE(review): getRow presumably returns null for ids with no
                    // stored row -- confirm. Such ids are skipped instead of being
                    // handed to cosine().
                    if (row == null) {
                        return true;
                    }
                    double sim = KmeansKNNFinder.cosine(vector, row);
                    accum.visit(id, sim);
                    return true;    // keep iterating over the remaining ids
                }
            });
        }
        return accum.get();
    }

    /** No-op: brute force search needs no index. */
    @Override
    public void build() throws IOException {}

    /** No-op: there is no index state to persist. */
    @Override
    public void save(File path) throws IOException {}

    /**
     * No-op: there is no index state to read.
     *
     * @return always {@code true}
     */
    @Override
    public boolean load(File path) throws IOException { return true; }
}