package com.jujutsu.tsne;
import static com.jujutsu.utils.EjmlOps.addRowVector;
import static com.jujutsu.utils.EjmlOps.assignAllLessThan;
import static com.jujutsu.utils.EjmlOps.assignAtIndex;
import static com.jujutsu.utils.EjmlOps.biggerThan;
import static com.jujutsu.utils.EjmlOps.colMean;
import static com.jujutsu.utils.EjmlOps.maximize;
import static com.jujutsu.utils.EjmlOps.replaceNaN;
import static com.jujutsu.utils.EjmlOps.setData;
import static com.jujutsu.utils.EjmlOps.setDiag;
import static com.jujutsu.utils.EjmlOps.tile;
import static com.jujutsu.utils.MatrixOps.abs;
import static com.jujutsu.utils.MatrixOps.equal;
import static com.jujutsu.utils.MatrixOps.fillMatrix;
import static com.jujutsu.utils.MatrixOps.negate;
import static com.jujutsu.utils.MatrixOps.range;
import static com.jujutsu.utils.MatrixOps.rnorm;
import static org.ejml.ops.CommonOps.add;
import static org.ejml.ops.CommonOps.addEquals;
import static org.ejml.ops.CommonOps.divide;
import static org.ejml.ops.CommonOps.elementDiv;
import static org.ejml.ops.CommonOps.elementLog;
import static org.ejml.ops.CommonOps.elementMult;
import static org.ejml.ops.CommonOps.elementPower;
import static org.ejml.ops.CommonOps.elementSum;
import static org.ejml.ops.CommonOps.mult;
import static org.ejml.ops.CommonOps.multAddTransB;
import static org.ejml.ops.CommonOps.scale;
import static org.ejml.ops.CommonOps.subtract;
import static org.ejml.ops.CommonOps.subtractEquals;
import static org.ejml.ops.CommonOps.sumRows;
import static org.ejml.ops.CommonOps.transpose;
import org.ejml.data.DenseMatrix64F;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.MatrixOps;
/**
*
* Author: Leif Jonsson (leif.jonsson@gmail.com)
*
* This is a Java implementation of van der Maaten and Hinton's t-SNE
* dimensionality reduction technique, which is particularly well suited
* for the visualization of high-dimensional datasets.
*
*/
public class MemOptimizedTSne extends FastTSne {

    /**
     * Runs the gradient-descent t-SNE loop on the data held by {@code config}
     * and returns the resulting low-dimensional embedding.
     *
     * <p>Steps, in order: optional PCA pre-reduction (only when PCA is enabled,
     * the input has more columns than {@code initialDims} and
     * {@code initialDims > 0}), computation of the P matrix through
     * {@code x2p} (not defined in this file; presumably inherited from
     * {@link FastTSne}), symmetrisation/normalisation of P, and then
     * {@code maxIter} gradient steps with per-element gains and momentum.
     * P is multiplied by 4 up front ("early exaggeration") and divided by 4
     * again at iteration 100. Progress is printed to {@code System.out}.
     *
     * @param config supplies the input matrix, output dimensionality, initial
     *               (PCA) dimensionality, perplexity, iteration count and the
     *               PCA switch
     * @return the embedding extracted from the internal matrix, one row per
     *         input row and {@code config.getOutputDims()} columns
     * @throws IllegalArgumentException if the input matrix is {@code null} or
     *                                  has no rows
     */
    @Override
    public double [][] tsne(TSneConfiguration config) {
        double[][] X = config.getXin();
        if (X == null || X.length == 0) {
            // The code below dereferences X[0]; fail with a clear message instead.
            throw new IllegalArgumentException("Input matrix X must contain at least one row");
        }
        int no_dims = config.getOutputDims();
        int initial_dims = config.getInitialDims();
        double perplexity = config.getPerplexity();
        int max_iter = config.getMaxIter();
        boolean use_pca = config.usePca();

        String IMPLEMENTATION_NAME = this.getClass().getSimpleName();
        System.out.println("X:Shape is = " + X.length + " x " + X[0].length);
        System.out.println("Running " + IMPLEMENTATION_NAME + ".");

        // Two independent timers: lapStart is reset after every progress line,
        // costStart only after a cost report, so the "50 iterations" figure
        // really spans the time since the previous cost report.
        long lapStart = System.currentTimeMillis();
        long costStart = lapStart;

        // Initialize variables
        if (use_pca && X[0].length > initial_dims && initial_dims > 0) {
            PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis();
            X = pca.pca(X, initial_dims);
            System.out.println("X:Shape after PCA is = " + X.length + " x " + X[0].length);
            System.out.println(MatrixOps.doubleArrayToPrintString(X,10,10));
        }

        int n = X.length;
        double initial_momentum = 0.5;
        double final_momentum = 0.8;
        int eta = 500;
        double min_gain = 0.01;

        DenseMatrix64F Y = new DenseMatrix64F(rnorm(n,no_dims));
        DenseMatrix64F Ysqlmul = new DenseMatrix64F(Y.numRows,Y.numRows); // Ysqlmul = n x n
        DenseMatrix64F dY = new DenseMatrix64F(fillMatrix(n,no_dims,0.0));
        DenseMatrix64F iY = new DenseMatrix64F(fillMatrix(n,no_dims,0.0));
        DenseMatrix64F gains = new DenseMatrix64F(fillMatrix(n,no_dims,1.0));
        DenseMatrix64F btNeg = new DenseMatrix64F(n,no_dims);
        DenseMatrix64F bt = new DenseMatrix64F(n,no_dims);

        // Compute P-values
        DenseMatrix64F P = new DenseMatrix64F(x2p(X, 1e-5, perplexity).P); // P = n x n
        DenseMatrix64F Psized = new DenseMatrix64F(P.numRows,P.numCols);   // scratch, same shape as P
        DenseMatrix64F diag = new DenseMatrix64F(fillMatrix(Psized.numRows,Psized.numCols,0.0));

        // Symmetrise (P + P^T) and normalise so that all entries sum to one.
        transpose(P,Psized);
        addEquals(P,Psized);
        divide(P ,elementSum(P));
        replaceNaN(P,Double.MIN_VALUE);
        scale(4.0,P); // early exaggeration
        maximize(P, 1e-12);

        System.out.println("Using perplexity: " + perplexity);
        System.out.println("Y:Shape is = " + Y.getNumRows() + " x " + Y.getNumCols());

        DenseMatrix64F sqed = new DenseMatrix64F(Y.numRows,Y.numCols); // sqed = n x no_dims
        DenseMatrix64F sum_Y = new DenseMatrix64F(1,Y.numRows);
        DenseMatrix64F Q = new DenseMatrix64F(P.numRows,P.numCols);    // Q = n x n

        // Work buffers that are completely overwritten on every iteration;
        // allocated once here instead of once per iteration.
        DenseMatrix64F rowsum = new DenseMatrix64F(Psized.numRows,1);  // rowsum = n x 1
        double [] rsum = new double[rowsum.numRows];
        DenseMatrix64F gainsdY = new DenseMatrix64F(gains.numRows,dY.numCols);

        for (int iter = 0; iter < max_iter; iter++) {
            // Compute pairwise affinities
            elementPower(Y, 2, sqed);
            sumRows(sqed, sum_Y);
            // multAddTransB ACCUMULATES into its output (c = c + alpha * a * b^T).
            // Ysqlmul still holds the previous iteration's kernel values, so it
            // has to be cleared first or stale data leaks into the distances.
            Ysqlmul.zero();
            multAddTransB(-2.0, Y, Y, Ysqlmul);
            addRowVector(Ysqlmul, sum_Y);
            transpose(Ysqlmul);
            addRowVector(Ysqlmul, sum_Y);

            // Element-wise 1 / (1 + value), with the diagonal forced to zero.
            add(Ysqlmul, 1.0);
            divide(1.0,Ysqlmul);
            assignAtIndex(Ysqlmul, range(n), range(n), 0);
            divide(Ysqlmul , elementSum(Ysqlmul), Q);
            maximize(Q, 1e-12);

            // Compute gradient
            subtract(P, Q, Psized);
            elementMult(Psized, Ysqlmul);
            sumRows(Psized, rowsum);
            for (int i = 0; i < rsum.length; i++) {
                rsum[i] = rowsum.get(i,0);
            }
            setDiag(diag,rsum);
            subtract(diag, Psized, Psized);
            mult(Psized, Y, dY);
            scale(4.0, dY);

            // Perform the update
            double momentum = (iter < 20) ? initial_momentum : final_momentum;

            // Compares the (dY > 0) and (iY > 0) masks element by element;
            // bt / btNeg receive that mask and its negation as numbers.
            boolean [][] boolMtrx = equal(biggerThan(dY,0.0),biggerThan(iY,0.0));
            setData(btNeg, abs(negate(boolMtrx)));
            setData(bt, abs(boolMtrx));

            // gains = (gains + 0.2) .* btNeg + (gains * 0.8) .* bt
            DenseMatrix64F gainsSmall = new DenseMatrix64F(gains);
            DenseMatrix64F gainsBig = new DenseMatrix64F(gains);
            add(gainsSmall,0.2);
            scale(0.8,gainsBig);
            elementMult(gainsSmall, btNeg);
            elementMult(gainsBig, bt);
            add(gainsSmall,gainsBig,gains);
            assignAllLessThan(gains, min_gain, min_gain);

            // iY = momentum * iY - eta * (gains .* dY);  Y = Y + iY
            scale(momentum,iY);
            elementMult(gains , dY, gainsdY);
            scale(eta,gainsdY);
            subtractEquals(iY , gainsdY);
            addEquals(Y , iY);

            // Subtract the column means from Y.
            DenseMatrix64F colMeanY = colMean(Y, 0);
            DenseMatrix64F meanTile = tile(colMeanY, n, 1);
            subtractEquals(Y , meanTile);

            // Compute current value of the cost function
            if (iter % 50 == 0) {
                DenseMatrix64F Pdiv = new DenseMatrix64F(P);
                elementDiv(Pdiv , Q);
                elementLog(Pdiv,Psized);
                replaceNaN(Psized,Double.MIN_VALUE);
                elementMult(Psized,P);
                replaceNaN(Psized,Double.MIN_VALUE);
                double C = elementSum(Psized);
                long now = System.currentTimeMillis();
                System.out.printf("Iteration %d: error is %f (50 iterations in %4.2f seconds)\n", iter, C, (now - costStart) / 1000.0);
                if (C < 0) {
                    System.err.println("Warning: Error is negative, this is usually a very bad sign!");
                }
                costStart = System.currentTimeMillis();
                lapStart = costStart;
            } else if (iter % 10 == 0) {
                long now = System.currentTimeMillis();
                System.out.printf("Iteration %d: (10 iterations in %4.2f seconds)\n", iter, (now - lapStart) / 1000.0);
                lapStart = System.currentTimeMillis();
            }

            // Stop lying about P-values
            if (iter == 100) {
                divide(P , 4);
            }
        }

        // Return solution
        return MatrixOps.extractDoubleArray(Y);
    }
}