package com.jujutsu.tsne.barneshut;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.ThreadLocalRandom;
/**
 * Vantage-point tree over an array of {@link DataPoint}s that answers
 * k-nearest-neighbour queries under a pluggable {@link Distance}.
 *
 * <p>Every node refers to one point (its vantage point) by index into the
 * internal item array and stores a threshold: the distance from the vantage
 * point to the median element of the node's range at build time. Points of
 * the range that were ordered before the median by the
 * {@link DistanceComparator} go to the left subtree, the remaining ones to
 * the right subtree.
 *
 * @param <StorageType> type parameter kept for API compatibility; it is not
 *                      referenced by the code in this class
 */
public class VpTree<StorageType> {

    /** Shallow copy of the array passed to {@link #create}; reordered in place while building. */
    DataPoint[] _items;

    /** Root of the tree; {@code null} until {@link #create} has indexed at least one point. */
    Node _root;

    /** Distance function used both for building the tree and for searching it. */
    Distance distance;

    /** Creates a tree that measures distances with an {@link EuclideanDistance}. */
    public VpTree() {
        distance = new EuclideanDistance();
    }

    /**
     * Creates a tree that measures distances with the given function.
     *
     * @param distance distance function used for building and searching
     */
    public VpTree(Distance distance) {
        this.distance = distance;
    }

    /**
     * (Re)builds the tree from the given points.
     *
     * <p>The array is cloned first, so the caller's array keeps its order;
     * the {@code DataPoint} objects themselves are shared, not copied.
     *
     * @param items points to index; an empty array yields an empty tree
     */
    public void create(DataPoint[] items) {
        _items = items.clone();
        _root = buildFromPoints(0, _items.length);
    }

    /**
     * Finds up to {@code k} indexed points closest to {@code target}.
     *
     * <p>Both output lists are cleared and then filled in parallel, nearest
     * point first: {@code distances.get(i)} is the distance from
     * {@code target} to {@code results.get(i)}. When the tree is empty or
     * has not been built yet, both lists are left empty.
     *
     * @param target    query point
     * @param k         maximum number of neighbours to return; must be at least 1
     * @param results   output list receiving the neighbours
     * @param distances output list receiving the matching distances
     * @throws IllegalArgumentException if {@code k} is smaller than 1
     */
    public void search(DataPoint target, int k, List<DataPoint> results, List<Double> distances) {
        if (k < 1) {
            // Previously this surfaced indirectly from the PriorityQueue constructor.
            throw new IllegalArgumentException("k must be at least 1, but was " + k);
        }

        results.clear();
        distances.clear();

        // Nothing indexed: there are no neighbours to report (used to be an NPE).
        if (_root == null) {
            return;
        }

        // Java's PriorityQueue is a min-heap by default; reversing the natural
        // order keeps the farthest candidate on top so it can be evicted first.
        PriorityQueue<HeapItem> heap =
                new PriorityQueue<HeapItem>(k, Collections.<HeapItem>reverseOrder());

        // Start with an unbounded search radius; it shrinks once k candidates are held.
        _root.search(_root, target, k, heap, Double.MAX_VALUE);

        // The heap hands out candidates farthest-first ...
        while (!heap.isEmpty()) {
            HeapItem item = heap.poll();
            results.add(_items[item.index]);
            distances.add(item.dist);
        }

        // ... so flip both lists to get nearest-first order.
        Collections.reverse(results);
        Collections.reverse(distances);
    }

    /**
     * Recursively builds the subtree covering {@code _items[lower .. upper)}.
     *
     * <p>A randomly chosen element of the range is swapped to position
     * {@code lower} and becomes the vantage point of the returned node. The
     * rest of the range is ordered with a {@link DistanceComparator} built
     * from that vantage point, and the range is split at its midpoint.
     *
     * @param lower first index of the range (inclusive)
     * @param upper end index of the range (exclusive)
     * @return root node of the subtree, or {@code null} for an empty range
     */
    public Node buildFromPoints(int lower, int upper) {
        if (upper == lower) { // empty range: nothing to build
            return null;
        }

        // The element at the lower index is the centre of the current node.
        Node node = createNode();
        node.index = lower;

        if (upper - lower > 1) { // more than one point: this is an inner node
            // Pick the vantage point uniformly from the whole range and move it to the front.
            int pivot = ThreadLocalRandom.current().nextInt(lower, upper);
            swap(_items, lower, pivot);

            // Order the remaining points around the midpoint of the range
            // (midpoint written so that the addition cannot overflow).
            int median = lower + (upper - lower) / 2;
            nth_element(_items, lower + 1, median, upper,
                    new DistanceComparator(_items[lower], distance));

            // The threshold of the node is the distance to the median element.
            node.threshold = distance(_items[lower], _items[median]);

            // Recurse: elements before the median go left, the median and the rest go right.
            node.left = buildFromPoints(lower + 1, median);
            node.right = buildFromPoints(median, upper);
        }

        return node;
    }

    /**
     * Factory for tree nodes; subclasses may override it to supply their own node type.
     *
     * @return a fresh, unlinked node
     */
    protected VpTree<StorageType>.Node createNode() {
        return new Node();
    }

    /**
     * @return the root node, or {@code null} if the tree is empty or not built
     */
    public Node getRoot() {
        return _root;
    }

    /**
     * Orders {@code array[low .. high)} with the given comparator.
     *
     * <p>The whole range is sorted, which places the element belonging at
     * {@code mid} there with everything ordered before it on its left.
     * {@code mid} itself is therefore not needed and is ignored.
     *
     * @param array              array to reorder in place
     * @param low                first index of the range (inclusive)
     * @param mid                index of interest; unused because the range is fully sorted
     * @param high               end index of the range (exclusive)
     * @param distanceComparator ordering to apply
     */
    static void nth_element(DataPoint[] array, int low, int mid, int high,
            DistanceComparator distanceComparator) {
        Arrays.sort(array, low, high, distanceComparator);
    }

    /**
     * Orders {@code array[low .. high)} ascending.
     *
     * <p>The whole range is sorted, so {@code mid} is not needed and is ignored.
     *
     * @param array array to reorder in place
     * @param low   first index of the range (inclusive)
     * @param mid   index of interest; unused because the range is fully sorted
     * @param high  end index of the range (exclusive)
     */
    static void nth_element(int[] array, int low, int mid, int high) {
        Arrays.sort(array, low, high);
    }

    /**
     * Measures the distance between two points with this tree's distance function.
     *
     * @param dataPoint1 first point
     * @param dataPoint2 second point
     * @return value returned by {@link Distance#distance} for the two points
     */
    public double distance(DataPoint dataPoint1, DataPoint dataPoint2) {
        return distance.distance(dataPoint1, dataPoint2);
    }

    /** Exchanges the elements at {@code idx1} and {@code idx2} of {@code items}. */
    private void swap(DataPoint[] items, int idx1, int idx2) {
        DataPoint dp = items[idx1];
        items[idx1] = items[idx2];
        items[idx2] = dp;
    }

    /** An entry of the intermediate result queue: an item index plus its distance to the target. */
    static class HeapItem implements Comparable<HeapItem> {
        int index;
        double dist;

        HeapItem(int index, double dist) {
            this.index = index;
            this.dist = dist;
        }

        /** Orders items by ascending distance. */
        @Override
        public int compareTo(HeapItem o) {
            return dist < o.dist ? -1 : (dist > o.dist ? 1 : 0);
        }

        @Override
        public String toString() {
            return "HeapItem (index=" + index + ",dist=" + dist + ")";
        }
    }

    /** A tree node: the index of its vantage point, the split threshold and the two subtrees. */
    class Node {
        /** Index of this node's vantage point in {@code _items}. */
        int index;
        /** Distance from the vantage point to the median element; stays 0 for leaves. */
        double threshold;
        protected Node left;
        protected Node right;

        @Override
        public String toString() {
            return "Node(id=" + index + ")";
        }

        public Node getLeft() {
            return left;
        }

        public Node getRight() {
            return right;
        }

        /**
         * Recursively searches the subtree rooted at {@code node}.
         *
         * <p>Java passes {@code _tau} by value, so the possibly tightened
         * search radius is handed back through the return value.
         *
         * @param node   subtree to search; {@code null} ends the recursion
         * @param target query point
         * @param k      number of neighbours wanted
         * @param heap   candidate queue; its head must be the farthest candidate
         * @param _tau   current search radius (distance to the farthest kept candidate)
         * @return the search radius after visiting this subtree
         */
        double search(Node node, DataPoint target, int k, Queue<HeapItem> heap, double _tau) {
            if (node == null) {
                return _tau; // nothing to visit
            }

            // Distance between the target and this node's vantage point.
            double dist = distance(_items[node.index], target);

            // The vantage point itself is a candidate if it lies within the radius.
            if (dist < _tau) {
                if (heap.size() == k) {
                    heap.remove(); // already k candidates: drop the farthest one
                }
                heap.add(new HeapItem(node.index, dist));
                if (heap.size() == k) {
                    _tau = heap.peek().dist; // radius = farthest candidate still kept
                }
            }

            // Leaf: no children to descend into.
            if (node.left == null && node.right == null) {
                return _tau;
            }

            if (dist < node.threshold) {
                // Target lies inside the ball: visit the inside (left) child first.
                if (dist - _tau <= node.threshold) { // neighbours may still be inside the ball
                    _tau = search(node.left, target, k, heap, _tau);
                }
                if (dist + _tau >= node.threshold) { // neighbours may still be outside the ball
                    _tau = search(node.right, target, k, heap, _tau);
                }
            } else {
                // Target lies outside the ball: visit the outside (right) child first.
                if (dist + _tau >= node.threshold) { // neighbours may still be outside the ball
                    _tau = search(node.right, target, k, heap, _tau);
                }
                if (dist - _tau <= node.threshold) { // neighbours may still be inside the ball
                    _tau = search(node.left, target, k, heap, _tau);
                }
            }
            return _tau;
        }
    }
}