package org.wikibrain.matrix.knn;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;
import java.io.File;
import java.io.IOException;
import java.util.*;
/**
 * A fast, approximate neighborhood finder for dense vectors.
 *
 * <p>{@link #build()} builds a hierarchical k-means tree using cosine similarity:
 * <ul>
 *   <li>The tree is built over a random sample of at most {@code sampleSize} matrix rows.
 *       Each internal node is split into {@code branchingFactor} clusters. Splitting stops
 *       once a node has at most {@code maxLeaf} sampled members.</li>
 *   <li>Every row of the matrix is then routed down the tree to a leaf by repeatedly
 *       choosing the child whose delegate row is most similar.</li>
 * </ul>
 *
 * <p>{@link #query} does a best-first traversal: it always expands the queued node whose
 * delegate is most similar to the query vector.
 *
 * @author Shilad Sen
 */
public class KmeansKNNFinder implements KNNFinder {
    private final DenseMatrix matrix;
    private int sampleSize = 50000;
    private int maxLeaf = 20;
    private int branchingFactor = 5;

    /** Root of the cluster tree; {@code null} until {@link #build()} has completed. */
    private Node root;

    /**
     * Creates a finder over the rows of {@code matrix}. {@link #build()} must be called
     * before {@link #query}.
     *
     * @param matrix the dense matrix whose rows are indexed
     */
    public KmeansKNNFinder(DenseMatrix matrix) {
        this.matrix = matrix;
    }

    /**
     * Builds the cluster tree from a random sample of rows. It then places every row of the
     * matrix into the leaf reached by following the most similar child at each level.
     *
     * @throws IOException if reading rows from the matrix fails
     */
    @Override
    public void build() throws IOException {
        // Build into a local so a failed build never leaves a half-populated tree installed.
        Node newRoot = new Node("R");
        newRoot.members.addAll(getSample());
        newRoot.build();
        for (DenseMatrixRow row : matrix) {
            newRoot.place(row);
        }
        root = newRoot;
    }

    /** A tree node waiting to be traversed, ordered by its delegate's similarity to the query. */
    private static class Candidate implements Comparable<Candidate> {
        final Node n;
        final double score;

        public Candidate(Node n, double score) {
            this.n = n;
            this.score = score;
        }

        @Override
        public int compareTo(Candidate o) {
            return Double.compare(score, o.score);
        }
    }

    /**
     * Finds approximate nearest neighbors of {@code vector} by cosine similarity.
     *
     * <p>Nodes are expanded best-first. The rows stored at each expanded node are scored and
     * fed to the accumulator. The search stops when the queue is empty or at least
     * {@code maxTraversal} rows have been scored. That check happens after each node, so the
     * count may exceed {@code maxTraversal} by up to one node's rows.
     *
     * @param vector       the query vector
     * @param k            number of neighbors to keep
     * @param maxTraversal approximate upper bound on the number of rows scored
     * @param validIds     must be {@code null}; filtering is not supported
     * @return the neighborhood collected by the accumulator
     * @throws UnsupportedOperationException if {@code validIds} is non-null
     * @throws IllegalStateException if {@link #build()} has not been called, or a row
     *                               cannot be read
     */
    @Override
    public Neighborhood query(float[] vector, int k, int maxTraversal, TIntSet validIds) {
        if (validIds != null) {
            throw new UnsupportedOperationException();
        }
        if (root == null) {
            throw new IllegalStateException("build() must be called before query()");
        }
        NeighborhoodAccumulator accum = new NeighborhoodAccumulator(k);

        // A max-heap on score. A TreeSet must not be used here. Candidate.compareTo only
        // looks at the score, so a TreeSet would treat nodes with tied scores as duplicates
        // and silently drop them, along with every row beneath them. Ties are common:
        // children with a null delegate all score 0.0.
        PriorityQueue<Candidate> work =
                new PriorityQueue<Candidate>(11, Collections.<Candidate>reverseOrder());
        work.add(new Candidate(root, -1.0));
        int traversed = 0;
        while (!work.isEmpty()) {
            Node n = work.poll().n;
            for (int rowId : n.memberIds.toArray()) {
                DenseMatrixRow row;
                try {
                    row = matrix.getRow(rowId);
                } catch (IOException e) {
                    throw new IllegalStateException(e);
                }
                accum.visit(row.getRowIndex(), cosine(vector, row));
                traversed++;
            }
            if (traversed >= maxTraversal) {
                break;
            }
            if (n.children != null) {
                for (Node c : n.children) {
                    work.add(new Candidate(c, cosine(vector, c.delegate)));
                }
            }
        }
        return accum.get();
    }

    /** Not supported by this finder. */
    @Override
    public void save(File path) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this finder. */
    @Override
    public boolean load(File path) throws IOException {
        throw new UnsupportedOperationException();
    }

    /**
     * Sets the maximum number of randomly sampled rows used to build the cluster tree.
     *
     * @param sampleSize sample size cap
     */
    public void setSampleSize(int sampleSize) {
        this.sampleSize = sampleSize;
    }

    /**
     * Sets the sampled-member count at or below which a node stops splitting and becomes
     * a leaf.
     *
     * @param maxLeaf leaf size threshold
     */
    public void setMaxLeaf(int maxLeaf) {
        this.maxLeaf = maxLeaf;
    }

    /**
     * Sets the number of k-means clusters each internal node is split into.
     *
     * @param branchingFactor clusters per split; must be at least 2
     * @throws IllegalArgumentException if {@code branchingFactor < 2}
     */
    public void setBranchingFactor(int branchingFactor) {
        if (branchingFactor < 2) {
            throw new IllegalArgumentException(
                    "branchingFactor must be at least 2, got " + branchingFactor);
        }
        this.branchingFactor = branchingFactor;
    }

    /**
     * Returns at most {@code sampleSize} rows of the matrix, chosen uniformly at random
     * without replacement.
     */
    private List<DenseMatrixRow> getSample() throws IOException {
        List<Integer> ids = new ArrayList<Integer>();
        for (int id : matrix.getRowIds()) {
            ids.add(id);
        }
        Collections.shuffle(ids);
        if (ids.size() > sampleSize) {
            ids = ids.subList(0, sampleSize);
        }
        List<DenseMatrixRow> sample = new ArrayList<DenseMatrixRow>(ids.size());
        for (int id : ids) {
            sample.add(matrix.getRow(id));
        }
        return sample;
    }

    /**
     * A node in the k-means tree.
     *
     * <p>While {@link #build()} runs, a node holds its sampled rows in {@link #members}.
     * Afterwards {@code members} is released ({@code null}). {@link #memberIds} then
     * collects the ids of matrix rows that {@link #place} routes to this node. Only leaves
     * receive ids.
     */
    class Node {
        /** Debug label: "R" for the root, with the child index appended at each level. */
        String path;
        /** Member row most similar to the cluster mean; {@code null} for an empty cluster. */
        DenseMatrixRow delegate;
        /** Child clusters, or {@code null} for a leaf. */
        Node[] children = null;
        /** Ids of rows placed at this node; created by {@link #endBuild()}. */
        TIntList memberIds;
        /** Sampled rows assigned to this node during building; {@code null} afterwards. */
        List<DenseMatrixRow> members = new ArrayList<DenseMatrixRow>();

        Node(String path) { this.path = path; }

        /**
         * Recursively clusters {@link #members} into {@code branchingFactor} children.
         * Stops at nodes with at most {@code maxLeaf} members.
         */
        void build() {
            if (members.size() <= maxLeaf) {
                endBuild();
                return;
            }
            initializeRandomly();
            for (Node n : children) {
                n.updateCenter();
            }
            // Up to 5 k-means rounds. Stop early once the mean member-to-delegate similarity
            // improves by less than 0.1% relative to the previous round.
            double prevScore = 0.000000001;
            for (int i = 0; i < 5; i++) {
                double score = reallocateMembers();
                if (score / prevScore - 1.0 < 0.001) {
                    break;
                }
                for (Node n : children) {
                    n.updateCenter();
                }
                prevScore = score;
            }
            // If clustering failed to split the members, recursing would rebuild the same
            // member set forever. This happens, for example, with more than maxLeaf identical
            // or all-zero rows: ties always go to the first child. Become an oversized leaf
            // instead.
            if (countNonEmptyChildren() < 2) {
                children = null;
                endBuild();
                return;
            }
            endBuild();
            for (Node n : children) {
                n.build();
            }
        }

        /** Routes {@code row} to the leaf reached by following the most similar child. */
        void place(DenseMatrixRow row) {
            if (children == null) {
                memberIds.add(row.getRowIndex());
                return;
            }
            findClosestChild(row).place(row);
        }

        /** Releases the sampled members and prepares {@link #memberIds} for placement. */
        private void endBuild() {
            members = null;
            memberIds = new TIntArrayList();
        }

        /** Creates the children and deals the shuffled members to them round-robin. */
        private void initializeRandomly() {
            children = new Node[branchingFactor];
            for (int i = 0; i < children.length; i++) {
                children[i] = new Node(path + i);
            }
            Collections.shuffle(members);
            for (int i = 0; i < members.size(); i++) {
                children[i % branchingFactor].members.add(members.get(i));
            }
        }

        /** Returns how many children hold at least one member after the last reallocation. */
        private int countNonEmptyChildren() {
            int count = 0;
            for (Node n : children) {
                if (!n.members.isEmpty()) {
                    count++;
                }
            }
            return count;
        }

        /**
         * Computes the arithmetic mean of the members. Sets {@link #delegate} to the member
         * with the highest cosine similarity to that mean, or to {@code null} if there are
         * no members.
         *
         * @return mean cosine similarity of the members to the mean, or 0.0 if empty
         */
        private double updateCenter() {
            if (members.isEmpty()) {
                delegate = null;
                return 0.0;
            }
            double[] center = new double[members.get(0).getNumCols()];
            for (DenseMatrixRow m : members) {
                for (int i = 0; i < center.length; i++) {
                    center[i] += m.getColValue(i);
                }
            }
            for (int i = 0; i < center.length; i++) {
                center[i] /= members.size();
            }
            double compactness = 0.0;
            double mostSimilar = -10;
            for (DenseMatrixRow m : members) {
                double s = cosine(center, m);
                compactness += s;
                if (s > mostSimilar) {
                    mostSimilar = s;
                    delegate = m;
                }
            }
            return compactness / members.size();
        }

        /**
         * Clears every child's members and reassigns each of this node's members to the
         * child with the most similar delegate.
         *
         * @return mean similarity of members to their assigned child's delegate
         */
        private double reallocateMembers() {
            for (Node n : children) {
                n.members.clear();
            }
            double score = 0.0;
            for (DenseMatrixRow m : members) {
                Node best = findClosestChild(m);
                score += best.similarity(m);
                best.members.add(m);
            }
            return score / members.size();
        }

        /**
         * Returns the child whose delegate is most similar to {@code row}. On ties, the
         * first such child in array order wins.
         */
        private Node findClosestChild(DenseMatrixRow row) {
            double bestSim = -10;
            Node best = null;
            for (Node n : children) {
                double s = n.similarity(row);
                if (s > bestSim) {
                    best = n;
                    bestSim = s;
                }
            }
            if (best == null) {
                throw new IllegalStateException();
            }
            return best;
        }

        /** Cosine similarity between this node's delegate and {@code row} (0 if no delegate). */
        private double similarity(DenseMatrixRow row) {
            return cosine(delegate, row);
        }

        /** Cosine similarity between {@code v} and this node's delegate (0 if no delegate). */
        private double similarity(float [] v) {
            return cosine(v, delegate);
        }
    }

    /**
     * Cosine similarity of two rows over {@code X.getNumCols()} columns.
     *
     * @return the similarity, or 0 if either argument is {@code null} or has zero norm
     */
    static double cosine(DenseMatrixRow X, DenseMatrixRow Y) {
        if (X == null || Y == null) {
            return 0;
        }
        double xDotX = 0.0;
        double yDotY = 0.0;
        double xDotY = 0.0;
        for (int i = 0; i < X.getNumCols(); i++) {
            double x = X.getColValue(i);
            double y = Y.getColValue(i);
            xDotX += x * x;
            yDotY += y * y;
            xDotY += x * y;
        }
        if (xDotX * yDotY == 0) {
            return 0.0;
        }
        return xDotY / Math.sqrt(xDotX * yDotY);
    }

    /**
     * Cosine similarity of a {@code double} vector and a row over {@code X.length} columns.
     *
     * @return the similarity, or 0 if either argument is {@code null} or has zero norm
     */
    static double cosine(double [] X, DenseMatrixRow Y) {
        if (X == null || Y == null) {
            return 0;
        }
        double xDotX = 0.0;
        double yDotY = 0.0;
        double xDotY = 0.0;
        for (int i = 0; i < X.length; i++) {
            double x = X[i];
            double y = Y.getColValue(i);
            xDotX += x * x;
            yDotY += y * y;
            xDotY += x * y;
        }
        if (xDotX * yDotY == 0) {
            return 0.0;
        }
        return xDotY / Math.sqrt(xDotX * yDotY);
    }

    /**
     * Cosine similarity of a {@code float} vector and a row over {@code X.length} columns.
     *
     * @return the similarity, or 0 if either argument is {@code null} or has zero norm
     */
    static double cosine(float [] X, DenseMatrixRow Y) {
        if (X == null || Y == null) {
            return 0;
        }
        double xDotX = 0.0;
        double yDotY = 0.0;
        double xDotY = 0.0;
        for (int i = 0; i < X.length; i++) {
            double x = X[i];
            double y = Y.getColValue(i);
            xDotX += x * x;
            yDotY += y * y;
            xDotY += x * y;
        }
        if (xDotX * yDotY == 0) {
            return 0.0;
        }
        return xDotY / Math.sqrt(xDotX * yDotY);
    }
}