package com.jujutsu.tsne.barneshut; import static java.lang.Math.exp; import static java.lang.Math.log; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; import java.util.concurrent.RecursiveAction; import com.jujutsu.utils.MatrixOps; public class ParallelBHTsne extends BHTSne { private ForkJoinPool gradientPool; private ExecutorService gradientCalculationPool; class RecursiveGradientCalculator extends RecursiveAction { final static long serialVersionUID = 1L; int startRow = -1; int endRow = -1; int limit = 100; SPTree tree; double[][] neg_f; double theta; AtomicDouble sum_Q; public RecursiveGradientCalculator(SPTree tree, double [][] neg_f , double theta, AtomicDouble sum_Q, int startRow, int endRow, int ll) { this.limit = ll; this.startRow = startRow; this.endRow = endRow; this.tree = tree; this.neg_f = neg_f; this.theta = theta; this.sum_Q = sum_Q; } @Override protected void compute() { if ( (endRow-startRow) <= limit ) { for (int row = startRow; row < endRow; row++) { tree.computeNonEdgeForces(row, theta, neg_f[row], sum_Q); } } else { int range = (endRow-startRow); int startDoc1 = startRow; int endDoc1 = startRow + (range / 2); int startDoc2 = endDoc1; int endDoc2 = endRow; invokeAll(new RecursiveGradientCalculator(tree,neg_f, theta, sum_Q, startDoc1, endDoc1, limit), new RecursiveGradientCalculator(tree,neg_f, theta, sum_Q, startDoc2, endDoc2, limit)); } } } class ParallelGradientCalculator implements Callable<Double> { final static long serialVersionUID = 1L; int row = -1; int limit = 100; ParallelSPTree tree; double[][] neg_f; double theta; public ParallelGradientCalculator(ParallelSPTree tree, double [][] neg_f , double theta, int row, int ll) { this.limit = ll; this.row = row; this.tree = tree; this.neg_f = neg_f; this.theta = theta; } @Override public Double call() { return tree.computeNonEdgeForces(row, theta, neg_f[row], 0.0); } } @Override double[][] run(TSneConfiguration config) { gradientPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors()); gradientCalculationPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); double [][] Y = super.run(config); gradientPool.shutdown(); gradientCalculationPool.shutdown(); return Y; } class RecursiveGradientUpdater extends RecursiveAction { final static long serialVersionUID = 1L; int startIdx = -1; int endIdx = -1; int limit = 100; int N; int no_dims; double[] Y; double momentum; double eta; double[] dY; double[] uY; double[] gains; public RecursiveGradientUpdater(int n, int no_dims, double[] Y, double momentum, double eta, double[] dY, double[] uY, double[] gains, int startIdx, int endIdx, int limit) { super(); this.startIdx = startIdx; this.endIdx = endIdx; this.limit = limit; N = n; this.no_dims = no_dims; this.Y = Y; this.momentum = momentum; this.eta = eta; this.dY = dY; this.uY = uY; this.gains = gains; } @Override protected void compute() { if ( (endIdx-startIdx) <= limit ) { for (int n = startIdx; n < endIdx; n++) { // Update gains gains[n] = (sign_tsne(dY[n]) != sign_tsne(uY[n])) ? (gains[n] + .2) : (gains[n] * .8); if(gains[n] < .01) gains[n] = .01; // Perform gradient update (with momentum and gains) Y[n] = Y[n] + uY[n]; uY[n] = momentum * uY[n] - eta * gains[n] * dY[n]; } } else { int range = (endIdx-startIdx); int startIdx1 = startIdx; int endIdx1 = startIdx + (range / 2); int endIdx2 = endIdx; invokeAll(new RecursiveGradientUpdater(N, no_dims, Y, momentum, eta, dY, uY, gains, startIdx1, endIdx1, limit), new RecursiveGradientUpdater(N, no_dims, Y, momentum, eta, dY, uY, gains, endIdx1, endIdx2, limit)); } } } @Override void updateGradient(int N, int no_dims, double[] Y, double momentum, double eta, double[] dY, double[] uY, double[] gains) { RecursiveGradientUpdater dslr = new RecursiveGradientUpdater(N, no_dims, Y, momentum, eta, dY, uY, gains,0,N * no_dims,N/(Runtime.getRuntime().availableProcessors()*10)); gradientPool.invoke(dslr); } // Compute gradient of the t-SNE cost function (using Barnes-Hut algorithm) @Override void computeGradient(double [] P, int [] inp_row_P, int [] inp_col_P, double [] inp_val_P, double [] Y, int N, int D, double [] dC, double theta) { // Construct space-partitioning tree on current map ParallelSPTree tree = new ParallelSPTree(D, Y, N); // Compute all terms required for t-SNE gradient double [] pos_f = new double[N * D]; double [][] neg_f = new double[N][D]; tree.computeEdgeForces(inp_row_P, inp_col_P, inp_val_P, N, pos_f); //AtomicDouble Asum_Q = new AtomicDouble(); double sum_Q = 0; //RecursiveGradientCalculator dslr = new RecursiveGradientCalculator(tree, neg_f, theta, sum_Q, 0, N, 20); //gradientPool.invoke(dslr); List<ParallelGradientCalculator> calculators = new ArrayList<>(); for(int n = 0; n < N; n++) { calculators.add(new ParallelGradientCalculator(tree, neg_f, theta, n, 20)); } List<Future<Double>> results; try { results = gradientCalculationPool.invokeAll(calculators); for (Future<Double> result : results) { double tmp = result.get(); sum_Q += tmp; } } catch (InterruptedException e) { e.printStackTrace(); System.exit(-1); } catch (ExecutionException e) { e.printStackTrace(); System.exit(-1); } //for(int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, neg_f[n], sum_Q); // Compute final t-SNE gradient for(int i = 0; i < N; i++) { for(int j = 0; j < D; j++) { dC[i*D+j] = pos_f[i*D+j] - (neg_f[i][j] / sum_Q); } } } @Override // Compute input similarities with a fixed perplexity using ball trees void computeGaussianPerplexity(double [] X, int N, int D, int [] _row_P, int [] _col_P, double [] _val_P, double perplexity, int K) { if(perplexity > K) System.out.println("Perplexity should be lower than K!"); // Allocate the memory we need /**_row_P = (int*) malloc((N + 1) * sizeof(int)); *_col_P = (int*) calloc(N * K, sizeof(int)); *_val_P = (double*) calloc(N * K, sizeof(double)); if(*_row_P == null || *_col_P == null || *_val_P == null) { Rcpp::stop("Memory allocation failed!\n"); }*/ int [] row_P = _row_P; int [] col_P = _col_P; double [] val_P = _val_P; double [] cur_P = new double[N - 1]; row_P[0] = 0; for(int n = 0; n < N; n++) row_P[n + 1] = row_P[n] + K; // Build ball tree on data set ParallelVpTree<DataPoint> tree = new ParallelVpTree<DataPoint>(gradientPool,distance); final DataPoint [] obj_X = new DataPoint [N]; for(int n = 0; n < N; n++) { double [] row = MatrixOps.extractRowFromFlatMatrix(X,n,D); obj_X[n] = new DataPoint(D, n, row); } tree.create(obj_X); // VERIFIED THAT TREES LOOK THE SAME //System.out.println("Created Tree is: "); // AdditionalInfoProvider pp = new AdditionalInfoProvider() { // @Override // public String provideInfo(Node node) { // return "" + obj_X[node.index].index(); // } // }; // TreePrinter printer = new TreePrinter(pp); // printer.printTreeHorizontal(tree.getRoot()); // Loop over all points to find nearest neighbors List<Future<ParallelVpTree<DataPoint>.ParallelTreeNode.TreeSearchResult>> results = tree.searchMultiple(tree, obj_X, K+1); for (Future<ParallelVpTree<DataPoint>.ParallelTreeNode.TreeSearchResult> result : results) { ParallelVpTree<DataPoint>.ParallelTreeNode.TreeSearchResult res = null; List<Double> distances = null; List<DataPoint> indices = null; int n = -1; try { res = result.get(); distances = res.getDistances(); indices = res.getIndices(); n = res.getIndex(); } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); } // Initialize some variables for binary search boolean found = false; double beta = 1.0; double min_beta = -Double.MAX_VALUE; double max_beta = Double.MAX_VALUE; double tol = 1e-5; // Iterate until we found a good perplexity int iter = 0; double sum_P = 0.; while(!found && iter < 200) { // Compute Gaussian kernel row and entropy of current row sum_P = Double.MIN_VALUE; double H = .0; for(int m = 0; m < K; m++) { cur_P[m] = exp(-beta * distances.get(m + 1)); sum_P += cur_P[m]; H += beta * (distances.get(m + 1) * cur_P[m]); } H = (H / sum_P) + log(sum_P); // Evaluate whether the entropy is within the tolerance level double Hdiff = H - log(perplexity); if(Hdiff < tol && -Hdiff < tol) { found = true; } else { if(Hdiff > 0) { min_beta = beta; if(max_beta == Double.MAX_VALUE || max_beta == -Double.MAX_VALUE) beta *= 2.0; else beta = (beta + max_beta) / 2.0; } else { max_beta = beta; if(min_beta == -Double.MAX_VALUE || min_beta == Double.MAX_VALUE) beta /= 2.0; else beta = (beta + min_beta) / 2.0; } } // Update iteration counter iter++; } // Row-normalize current row of P and store in matrix for(int m = 0; m < K; m++) { cur_P[m] /= sum_P; col_P[row_P[n] + m] = indices.get(m + 1).index(); val_P[row_P[n] + m] = cur_P[m]; } } } }