package com.jujutsu.tsne.barneshut;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

/**
 * A {@link VpTree} whose searches can be submitted as independent tasks to a
 * {@link ForkJoinPool}, one task per search target.
 *
 * <p>Nodes of this tree are {@link ParallelTreeNode} instances (see
 * {@link #createNode()}), which carry the {@link Callable} used to run a
 * single search.
 *
 * @param <StorageType> type parameter forwarded unchanged to {@link VpTree}
 */
public class ParallelVpTree<StorageType> extends VpTree<StorageType> {

    /** Pool on which the per-target search tasks are executed. */
    private final ForkJoinPool searcherPool;

    /**
     * Creates a tree that uses the given distance and runs its parallel
     * searches on {@code pool}.
     *
     * @param pool     pool used by {@link #searchMultiple}
     * @param distance distance passed to the {@link VpTree} constructor
     */
    public ParallelVpTree(ForkJoinPool pool, Distance distance) {
        super(distance);
        searcherPool = pool;
    }

    /**
     * Creates a tree using the no-argument {@link VpTree} constructor and
     * runs its parallel searches on {@code pool}.
     *
     * @param pool pool used by {@link #searchMultiple}
     */
    public ParallelVpTree(ForkJoinPool pool) {
        searcherPool = pool;
    }

    /**
     * Runs one search per target on this tree's pool and returns the futures
     * of all searches, in the order produced by
     * {@link ForkJoinPool#invokeAll(java.util.Collection)}.
     *
     * <p>Each task searches from the root of {@code tree} and translates the
     * resulting indices through the item array of that same {@code tree}.
     * NOTE(review): callers presumably pass this very instance as
     * {@code tree}; confirm against the call sites.
     *
     * @param tree    tree whose root is searched; its root must be a
     *                {@link ParallelTreeNode}, and a {@code null} root results
     *                in a {@link NullPointerException} as soon as there is at
     *                least one target
     * @param targets points to search for; one task is created per element
     * @param k       number of results requested per target; values below 1
     *                make every task fail (see
     *                {@link ParallelTreeNode.ParallelTreeSearcher#call()})
     * @return one future per target; {@link ParallelTreeNode.TreeSearchResult#getIndex()}
     *         gives the position of the corresponding target in {@code targets}
     */
    public List<Future<ParallelTreeNode.TreeSearchResult>> searchMultiple(
            ParallelVpTree<StorageType> tree, DataPoint[] targets, int k) {
        // The root and the item array are the same for every target, so look
        // them up once instead of once per loop iteration.
        @SuppressWarnings("unchecked")
        ParallelTreeNode root = (ParallelTreeNode) tree.getRoot();
        // Heap indices produced by searching tree's root refer to tree's own
        // items, so take the array from the searched tree rather than from this.
        DataPoint[] treeItems = tree._items;

        List<ParallelTreeNode.ParallelTreeSearcher> searchers = new ArrayList<>(targets.length);
        for (int n = 0; n < targets.length; n++) {
            searchers.add(root.new ParallelTreeSearcher(root, treeItems, targets[n], k, n));
        }
        return searcherPool.invokeAll(searchers);
    }

    /**
     * Supplies {@link ParallelTreeNode} instances wherever the base class asks
     * for a new node.
     */
    @Override
    protected VpTree<StorageType>.Node createNode() {
        return new ParallelTreeNode();
    }

    /** Tree node that additionally hosts the search task and its result type. */
    class ParallelTreeNode extends VpTree<StorageType>.Node {

        /** Outcome of one search: the matched points, their distances and the target's position. */
        class TreeSearchResult {
            int n;
            List<Double> distances;
            List<DataPoint> indices;

            /**
             * @param indices   points found by the search
             * @param distances distance of each found point, parallel to {@code indices}
             * @param n         position of the search target in the caller's target array
             */
            public TreeSearchResult(List<DataPoint> indices, List<Double> distances, int n) {
                this.indices = indices;
                this.distances = distances;
                this.n = n;
            }

            /** @return the points found by the search */
            public List<DataPoint> getIndices() {
                return indices;
            }

            /** @return the distances recorded for the found points, parallel to {@link #getIndices()} */
            public List<Double> getDistances() {
                return distances;
            }

            /** @return the position of the search target this result belongs to */
            public int getIndex() {
                return n;
            }
        }

        /** Task that performs a single search from a given node for a given target. */
        class ParallelTreeSearcher implements Callable<TreeSearchResult> {
            Node node;
            DataPoint target;
            int k;
            int n;
            DataPoint[] items;

            /**
             * @param tree   node from which the search starts
             * @param items  array through which heap indices are translated into points
             * @param target point to search for
             * @param k      number of results requested
             * @param n      position of {@code target} in the caller's target array
             */
            public ParallelTreeSearcher(Node tree, DataPoint[] items, DataPoint target, int k, int n) {
                this.node = tree;
                this.target = target;
                this.k = k;
                this.items = items;
                this.n = n;
            }

            /**
             * Searches from {@link #node} for {@link #target} and collects the
             * heap content into parallel lists of points and distances.
             *
             * @return the collected points and distances together with {@link #n}
             * @throws IllegalArgumentException if {@code k} is less than 1, because
             *         {@code k} is used as the initial capacity of the heap
             */
            @Override
            public TreeSearchResult call() {
                // Heap ordered by the reverse of HeapItem's natural ordering.
                PriorityQueue<HeapItem> heap = new PriorityQueue<HeapItem>(k, new Comparator<HeapItem>() {
                    @Override
                    public int compare(HeapItem o1, HeapItem o2) {
                        return -o1.compareTo(o2);
                    }
                });

                double tau = Double.MAX_VALUE;

                // Perform the search
                node.search(node, target, k, heap, tau);

                // Gather final results
                int found = heap.size();
                List<DataPoint> indices = new ArrayList<>(found);
                List<Double> distances = new ArrayList<>(found);
                HeapItem item;
                while ((item = heap.poll()) != null) {
                    indices.add(items[item.index]);
                    distances.add(item.dist);
                }

                // The heap is drained in reversed order, so flip both lists to
                // restore HeapItem's natural ordering (presumably ascending
                // distance; confirm against HeapItem.compareTo).
                Collections.reverse(indices);
                Collections.reverse(distances);

                return new TreeSearchResult(indices, distances, n);
            }
        }
    }
}