package edu.stanford.rsl.tutorial.iterative; import ij.ImageJ; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.phantom.NumericalSheppLogan3D; public class GDTest { private static int GD_STEPSIZE_MAXITER = 20; private static float STEPSIZE_FACTOR = 0.9f; /** * @param args */ public static void main(String[] args) { int iter = 60; float stepsize = 0.3f; float regul = (float) Math.pow(10, -4); int size = 100; new ImageJ(); // Grid3D in = randGrid(size,size,size); Grid3D in = new NumericalSheppLogan3D(size,size,size).getNumericalSheppLoganPhantom(); in.show("In Volume"); Grid3D out = null; try { out = iterateGradientDescent(in, iter, stepsize, regul); } catch (Exception e) { e.printStackTrace(); } out.show("GD Out Volume"); } private static Grid3D randGrid(int size, int size2, int size3) { Grid3D rand = new Grid3D(size,size,size); for(int x=0; x<size; ++x) for(int y=0; y<size; ++y) for(int z=0; z<size; ++z) rand.setAtIndex(x,y,z, (float) (Math.random() * 10)); return rand; } private static Grid3D iterateGradientDescent(Grid3D myVolIn, int iter, float stepsize, float regul) throws Exception { boolean reduceStepsizeAndUpdate = false; Grid3D myVol = new Grid3D(myVolIn); Grid3D gradX, gradY, gradZ; for (int i = 0; i < iter; ++i) { if (stepsize < Math.pow(10, -9)) break; // G boolean offsetLeft = true; gradX = GridOp.sub(myVol, myVol, -1, 0, 0, offsetLeft); gradY = GridOp.sub(myVol, myVol, 0, -1, 0, offsetLeft); gradZ = GridOp.sub(myVol, myVol, 0, 0, -1, offsetLeft); int numNegVol = GridOp.numNeg(myVol); // gradMagnitude = sqrt(sum(G.^2,4) + regularization.^2) Grid3D gradMag = GridOp.add(GridOp.square(gradX), GridOp.square(gradY), GridOp.square(gradZ), regul * regul); GridOp.sqrtInPlace(gradMag); double tvNorm = GridOp.l1Norm(gradMag); // upd = divergence(G ./ gradMagnitude) * stepsize; // volN = myVol + upd // normalized gradients GridOp.divInPlace(gradX, gradMag); GridOp.divInPlace(gradY, gradMag); GridOp.divInPlace(gradZ, gradMag); offsetLeft = false; // offsetRight Grid3D gradXTmp = GridOp.sub(gradX, gradX, 1, 0, 0, offsetLeft); // fx(1,:,:) = Px(1,:,:); // fx(end,:,:) = -Px(end-1,:,:); //float[][][] b = gradXTmp.getBuffer(); for(int e=0; e<gradXTmp.getSize()[1]; ++e) for(int f=0; f<gradXTmp.getSize()[2]; ++f) gradXTmp.setAtIndex(0,e,f, gradX.getAtIndex(0,e,f)); for(int e=0; e<gradXTmp.getSize()[1]; ++e) for(int f=0; f<gradXTmp.getSize()[2]; ++f) gradXTmp.setAtIndex(gradXTmp.getSize()[0]-1,e,f, -gradX.getAtIndex(gradXTmp.getSize()[0]-2,e,f)); Grid3D gradYTmp = GridOp.sub(gradY, gradY, 0, 1, 0, offsetLeft); // fy(:,1,:) = Py(:,1,:); // fy(:,end,:) = -Py(:,end-1,:); //b = gradYTmp.getBuffer(); for(int e=0; e<gradYTmp.getSize()[0]; ++e) for(int f=0; f<gradYTmp.getSize()[2]; ++f) gradYTmp.setAtIndex(e,0,f, gradY.getAtIndex(e,0,f)); for(int e=0; e<gradYTmp.getSize()[0]; ++e) for(int f=0; f<gradYTmp.getSize()[2]; ++f) gradYTmp.setAtIndex(e,gradYTmp.getSize()[1]-1,f, -gradY.getAtIndex(e,gradYTmp.getSize()[1]-2,f)); Grid3D gradZTmp = GridOp.sub(gradZ, gradZ, 0, 0, 1, offsetLeft); // fz(:,:,1) = Pz(:,:,1); // fz(:,:,end) = -Pz(:,:,end-1); //b = gradZTmp.getBuffer(); for(int e=0; e<gradZTmp.getSize()[0]; ++e) for(int f=0; f<gradZTmp.getSize()[1]; ++f) gradZTmp.setAtIndex(e,f,0, gradZ.getAtIndex(e,f,0)); for(int e=0; e<gradZTmp.getSize()[0]; ++e) for(int f=0; f<gradZTmp.getSize()[1]; ++f) gradZTmp.setAtIndex(e,f,gradZTmp.getSize()[2]-1, -gradZ.getAtIndex(e,f,gradZTmp.getSize()[2]-2)); gradX = gradXTmp; gradY = gradYTmp; gradZ = gradZTmp; Grid3D upd = GridOp.add(gradX, gradY, gradZ); // divergence GridOp.mulInPlace(upd, stepsize); Grid3D volN = GridOp.add(myVol, upd); for(int ii=0; 0 == ii || (reduceStepsizeAndUpdate && ii<GD_STEPSIZE_MAXITER) ; ++ii){ int numNegVolN = GridOp.numNeg(volN); if(numNegVolN > numNegVol){ GridOp.mulInPlace(upd, STEPSIZE_FACTOR ); stepsize *= STEPSIZE_FACTOR; volN = GridOp.add(myVol, upd); reduceStepsizeAndUpdate = true; continue; } // check stepsize // GN offsetLeft = true; gradX = GridOp.sub(volN, volN, -1, 0, 0, offsetLeft); gradY = GridOp.sub(volN, volN, 0, -1, 0, offsetLeft); gradZ = GridOp.sub(volN, volN, 0, 0, -1, offsetLeft); // gradMagnitudeN = sqrt(sum(GN.^2,4) + regularization.^2) gradMag = GridOp.add(GridOp.square(gradX), GridOp.square(gradY), GridOp.square(gradZ), regul * regul); GridOp.sqrtInPlace(gradMag); double tvNormN = GridOp.l1Norm(gradMag); if (tvNormN > tvNorm) { GridOp.mulInPlace(upd, STEPSIZE_FACTOR); stepsize *= STEPSIZE_FACTOR; volN = GridOp.add(myVol, upd); tvNorm = tvNormN; reduceStepsizeAndUpdate = true; continue; } reduceStepsizeAndUpdate = false; } myVol = volN; } return myVol; } }