package com.jujutsu.tsne; import static com.jujutsu.utils.MatrixOps.addColumnVector; import static com.jujutsu.utils.MatrixOps.addRowVector; import static com.jujutsu.utils.MatrixOps.assignValuesToRow; import static com.jujutsu.utils.MatrixOps.concatenate; import static com.jujutsu.utils.MatrixOps.equal; import static com.jujutsu.utils.MatrixOps.fillMatrix; import static com.jujutsu.utils.MatrixOps.getValuesFromRow; import static com.jujutsu.utils.MatrixOps.mean; import static com.jujutsu.utils.MatrixOps.negate; import static com.jujutsu.utils.MatrixOps.range; import static com.jujutsu.utils.MatrixOps.scalarInverse; import static com.jujutsu.utils.MatrixOps.scalarMult; import static com.jujutsu.utils.MatrixOps.sqrt; import static com.jujutsu.utils.MatrixOps.square; import static com.jujutsu.utils.MatrixOps.sum; import static com.jujutsu.utils.MatrixOps.times; import org.jblas.DoubleMatrix; import com.jujutsu.tsne.barneshut.TSneConfiguration; import com.jujutsu.utils.BlasOps; import com.jujutsu.utils.MatrixOps; /** * * Author: Leif Jonsson (leif.jonsson@gmail.com) * * This is a port of van der Maaten and Hintons Python implementation of t-sne * */ public class BlasTSne implements TSne { MatrixOps mo = new MatrixOps(); protected volatile boolean abort = false; @Override public double [][] tsne(TSneConfiguration config) { double[][] X = config.getXin(); int no_dims = config.getOutputDims(); int initial_dims = config.getInitialDims(); double perplexity = config.getPerplexity(); int max_iter = config.getMaxIter(); boolean use_pca = config.usePca(); String IMPLEMENTATION_NAME = this.getClass().getSimpleName(); System.out.println("X:Shape is = " + X.length + " x " + X[0].length); System.out.println("Running " + IMPLEMENTATION_NAME + "."); // Initialize variables if(use_pca && X[0].length > initial_dims && initial_dims > 0) { PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis(); X = pca.pca(X, initial_dims); //X = BlasOps.pca(new DoubleMatrix(X), initial_dims).toArray2(); System.out.println("X:Shape after PCA is = " + X.length + " x " + X[0].length); } int n = X.length; double momentum = .5; double initial_momentum = 0.5; double final_momentum = 0.8; int eta = 500; double min_gain = 0.01; DoubleMatrix Y = DoubleMatrix.randn(n,no_dims); DoubleMatrix dY = DoubleMatrix.zeros(n,no_dims); DoubleMatrix iY = DoubleMatrix.zeros(n,no_dims); DoubleMatrix gains = DoubleMatrix.ones(n,no_dims); // Compute P-values double [][] Pt = x2p(X, 1e-5, perplexity).P; DoubleMatrix P = new DoubleMatrix(Pt); P = P.add(P.transpose()); P = P.div(P.sum()); P = P.mul(4); // early exaggeration P = P.max(1e-12); System.out.println("Y:Shape is = " + Y.rows + " x " + Y.columns); // Run iterations for (int iter = 0; iter < max_iter && !abort; iter++) { // Compute pairwise affinities DoubleMatrix sum_Y = BlasOps.square(Y).rowSums().transpose(); DoubleMatrix num = BlasOps.scalarInverse(Y.mmul(Y.transpose()).mul(-2).addRowVector(sum_Y).transpose().addRowVector(sum_Y).add(1)); BlasOps.assignAtIndex(num, range(n), range(n), 0); DoubleMatrix Q = num.div(num.sum()); Q = Q.max(1e-12); // Compute gradient DoubleMatrix L = P.sub(Q).mul(num); dY = DoubleMatrix.diag(L.rowSums()).sub(L).mmul(Y).mul(4); // Perform the update if (iter < 20) momentum = initial_momentum; else momentum = final_momentum; DoubleMatrix gainsSmall = new DoubleMatrix(); gainsSmall.copy(gains); DoubleMatrix gainsBig = new DoubleMatrix(); gainsBig.copy(gains); gainsSmall = gainsSmall.add(0.2); gainsBig = gainsBig.mul(0.8); DoubleMatrix btNeg = BlasOps.abs(negate(equal(BlasOps.biggerThan(dY,0.0),BlasOps.biggerThan(iY,0.0)))); gainsSmall = gainsSmall.mul(btNeg); DoubleMatrix bt = BlasOps.abs( equal(BlasOps.biggerThan(dY,0.0),BlasOps.biggerThan(iY,0.0))); gainsBig = gainsBig.mul(bt); gains = gainsSmall.add(gainsBig); BlasOps.assignAllLessThan(gains, min_gain, min_gain); iY = iY.mul(momentum).sub(gains.mul(dY).mul(eta)); Y = Y.add(iY); Y = Y.sub(BlasOps.tile(Y.columnMeans(), n, 1)); // Compute current value of cost function if (iter % 100 == 0) { DoubleMatrix logdivide = BlasOps.log(P.div(Q)); logdivide = BlasOps.replaceNaN(logdivide,0); double C = P.mul(logdivide).sum(); System.out.println("Iteration " + iter + ": error is " + C); } else if(iter % 10 == 0) { System.out.println("Iteration " + iter); } // Stop lying about P-values if (iter == 100) P = P.div(4); } // Return solution return Y.toArray2(); } public R Hbeta (double [][] D, double beta){ DoubleMatrix Dd = new DoubleMatrix(D); DoubleMatrix P = BlasOps.exp(Dd.mul(-beta)); double sumP = P.sum(); // sumP confirmed scalar double H = Math.log(sumP) + beta * Dd.mul(P).sum() / sumP; P = P.div(sumP); R r = new R(); r.H = H; r.P = P.toArray2(); return r; } public R x2p(double [][] X,double tol, double perplexity){ int n = X.length; double [][] sum_X = sum(square(X), 1); double [][] times = scalarMult(times(X, mo.transpose(X)), -2); double [][] prodSum = addColumnVector(mo.transpose(times), sum_X); double [][] D = addRowVector(prodSum, mo.transpose(sum_X)); // D seems correct at this point compared to Python version double [][] P = fillMatrix(n,n,0.0); double [] beta = fillMatrix(n,n,1.0)[0]; double logU = Math.log(perplexity); System.out.println("Starting x2p..."); for (int i = 0; i < n; i++) { if (i % 500 == 0) System.out.println("Computing P-values for point " + i + " of " + n + "..."); double betamin = Double.NEGATIVE_INFINITY; double betamax = Double.POSITIVE_INFINITY; double [][] Di = getValuesFromRow(D, i,concatenate(range(0,i),range(i+1,n))); R hbeta = Hbeta(Di, beta[i]); double H = hbeta.H; double [][] thisP = hbeta.P; // Evaluate whether the perplexity is within tolerance double Hdiff = H - logU; int tries = 0; while(Math.abs(Hdiff) > tol && tries < 50){ if (Hdiff > 0){ betamin = beta[i]; if (Double.isInfinite(betamax)) beta[i] = beta[i] * 2; else beta[i] = (beta[i] + betamax) / 2; } else{ betamax = beta[i]; if (Double.isInfinite(betamin)) beta[i] = beta[i] / 2; else beta[i] = ( beta[i] + betamin) / 2; } hbeta = Hbeta(Di, beta[i]); H = hbeta.H; thisP = hbeta.P; Hdiff = H - logU; tries = tries + 1; } assignValuesToRow(P, i,concatenate(range(0,i),range(i+1,n)),thisP[0]); } R r = new R(); r.P = P; r.beta = beta; double sigma = mean(sqrt(scalarInverse(beta))); System.out.println("Mean value of sigma: " + sigma); return r; } @Override public void abort() { abort=true; } }