package edu.stanford.rsl.tutorial.iterative;
import ij.ImageJ;
import edu.stanford.rsl.conrad.data.numeric.Grid3D;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid3D;
import edu.stanford.rsl.conrad.geometry.trajectories.Trajectory;
import edu.stanford.rsl.conrad.phantom.NumericalSheppLogan3D;
import edu.stanford.rsl.conrad.utils.Configuration;
import edu.stanford.rsl.tutorial.cone.ConeBeamProjector;
/**
 * iTV and eTV reconstruction (OpenCL implementation).
 *
 * <p>Each outer iteration of {@link #iterateETV(int)}:
 * <ol>
 * <li>copies the current volume into a gradient-descent work volume ({@code gdVol}),</li>
 * <li>runs one SART iteration of the parent class ({@code super.iterate()}),</li>
 * <li>runs {@code gdIter} total-variation gradient-descent steps on the work volume,</li>
 * <li>computes a blend weight lambda from the projection errors of both volumes, repeating the
 *     gradient descent (at most {@code DYN_GD_MAX_ITER} times) while lambda exceeds
 *     {@code DYN_GD_MAX_LAMBDA},</li>
 * <li>replaces the volume by {@code (1 - lambda) * vol + lambda * gdVol}, with lambda clamped to
 *     at most 1. If lambda is NaN or infinite, the SART result is kept unchanged.</li>
 * </ol>
 *
 * <p>All OpenCL work grids and the command queue inherited from {@link SartCL} are released at the
 * end of {@link #iterateETV(int)}. After that call the instance cannot be iterated again.
 *
 * @author Mario Amrehn
 * GPU Implementation: Daniel Stromer
 *
 */
public class Etv extends SartCL {

    // -----------------------------------------
    /** When true, prints the duration and the L1 norm of the volume after every eTV iteration. */
    protected boolean debugEtv = false;
    // -----------------------------------------

    /**
     * {@code omega}: weight of the gradient-descent projection error in the target error used by
     * {@link #getLinCombination}. {@code regul}: regularization whose square is added under the
     * square root of the gradient magnitude.
     */
    private float omega, regul;
    /** Number of gradient-descent steps per call of {@link #iterateGradientDescent}. */
    private int gdIter;
    /**
     * Current gradient-descent step size. It is multiplied by {@code STEPSIZE_FACTOR} on every
     * rejected step and is never reset, so it carries over between calls.
     */
    private float stepsize;
    /** Factor applied to the step size after a rejected gradient-descent step. */
    private final float STEPSIZE_FACTOR = 0.9f;
    /** Maximum number of candidate updates examined per gradient-descent step. */
    private int GD_STEPSIZE_MAXITER = 20;
    /** Gradient descent is repeated while the blend weight lambda exceeds this value. */
    private double DYN_GD_MAX_LAMBDA = 1.2;//Double.MAX_VALUE;
    /** Maximum number of additional gradient-descent rounds per eTV iteration. */
    private int DYN_GD_MAX_ITER = 5;
    // timing of one eTV iteration, only used when debugEtv is set
    private long timeStart;
    private long timeEnd;

    // grids for gradient descent iterations (volume-sized)
    private OpenCLGrid3D gdVol = null;
    private OpenCLGrid3D gradX;
    private OpenCLGrid3D gradY;
    private OpenCLGrid3D gradZ;
    private OpenCLGrid3D gradMag;
    private OpenCLGrid3D gradMagY;
    private OpenCLGrid3D gradMagZ;
    private OpenCLGrid3D divX;
    private OpenCLGrid3D divY;
    private OpenCLGrid3D divZ;
    private OpenCLGrid3D upd;
    private OpenCLGrid3D volN;
    // grids for linear combinations (sinogram-sized)
    private OpenCLGrid3D sartSino;
    private OpenCLGrid3D gdSino;
    private OpenCLGrid3D diffSino;
    private OpenCLGrid3D varGrid;
    private OpenCLGrid3D sartSino2;
    private OpenCLGrid3D projError;

    /**
     * Creates an eTV reconstructor with a custom threshold for the repeated gradient descent.
     *
     * @param volDims      volume dimensions, forwarded to {@link SartCL}
     * @param spacing      voxel spacing, forwarded to {@link SartCL}
     * @param origin       volume origin, forwarded to {@link SartCL}
     * @param oProj        projection stack, forwarded to {@link SartCL}
     * @param beta         forwarded to {@link SartCL}
     * @param omega        weight of the gradient-descent projection error in the line-search target
     * @param gdIter       gradient-descent steps per call
     * @param regul        gradient-magnitude regularization
     * @param initStepsize initial gradient-descent step size
     * @param dynGD        value used for {@code DYN_GD_MAX_LAMBDA}
     * @throws Exception   if the parent constructor fails
     */
    public Etv(int[] volDims, double[] spacing, double[] origin, OpenCLGrid3D oProj, float beta, float omega, int gdIter, float regul, float initStepsize, double dynGD) throws Exception {
        this(volDims, spacing, origin, oProj, beta, omega, gdIter, regul, initStepsize);
        DYN_GD_MAX_LAMBDA = dynGD;
    }

    /**
     * Creates an eTV reconstructor with the default {@code DYN_GD_MAX_LAMBDA} of 1.2.
     *
     * @param volDims      volume dimensions, forwarded to {@link SartCL}
     * @param spacing      voxel spacing, forwarded to {@link SartCL}
     * @param origin       volume origin, forwarded to {@link SartCL}
     * @param oProj        projection stack, forwarded to {@link SartCL}
     * @param beta         forwarded to {@link SartCL}
     * @param omega        weight of the gradient-descent projection error in the line-search target
     * @param gdIter       gradient-descent steps per call
     * @param regul        gradient-magnitude regularization
     * @param initStepsize initial gradient-descent step size
     * @throws Exception   if the parent constructor fails
     */
    public Etv(int[] volDims, double[] spacing, double[] origin, OpenCLGrid3D oProj, float beta, float omega, int gdIter, float regul, float initStepsize) throws Exception {
        super(volDims, spacing, origin, oProj, beta);
        this.omega = omega;
        this.gdIter = gdIter;
        this.regul = regul;
        this.stepsize = initStepsize;
    }

    /**
     * Creates an eTV reconstructor from an initial volume and a sinogram.
     *
     * @param initialVol   initial volume, forwarded to {@link SartCL}
     * @param sino         projection stack, forwarded to {@link SartCL}
     * @param beta         forwarded to {@link SartCL}
     * @param omega        weight of the gradient-descent projection error in the line-search target
     * @param gdIter       gradient-descent steps per call
     * @param regul        gradient-magnitude regularization
     * @param initStepsize initial gradient-descent step size
     * @throws Exception   if the parent constructor fails
     */
    public Etv(Grid3D initialVol, Grid3D sino, float beta, float omega, int gdIter, float regul, float initStepsize) throws Exception {
        super(initialVol, sino, beta);
        this.omega = omega;
        this.gdIter = gdIter;
        this.regul = regul;
        this.stepsize = initStepsize;
    }

    /**
     * Runs a single eTV iteration. Equivalent to {@code iterateETV(1)}.
     *
     * @throws Exception if any OpenCL or projection operation fails
     */
    public final void iterateETV() throws Exception{
        iterateETV(1);
    }

    /** Allocates an OpenCL grid with the volume dimensions ({@code imgSize}). */
    private OpenCLGrid3D newVolumeGrid() {
        return new OpenCLGrid3D(new Grid3D(imgSize[0], imgSize[1], imgSize[2]));
    }

    /** Allocates an OpenCL grid with the sinogram dimensions ({@code width x height x maxProjs}). */
    private OpenCLGrid3D newSinogramGrid() {
        return new OpenCLGrid3D(new Grid3D(width, height, maxProjs));
    }

    /**
     * Allocates all work grids used by the gradient descent and the line search, and calls
     * {@code prepareForDeviceOperation()} on them as well as on {@code volCL} and {@code oProj}.
     */
    private void configure(){
        // gradient descent variables
        gdVol = newVolumeGrid();
        gradX = newVolumeGrid();
        gradY = newVolumeGrid();
        gradZ = newVolumeGrid();
        gradMag = newVolumeGrid();
        gradMagY = newVolumeGrid();
        gradMagZ = newVolumeGrid();
        divX = newVolumeGrid();
        divY = newVolumeGrid();
        divZ = newVolumeGrid();
        volN = newVolumeGrid();
        upd = newVolumeGrid();
        // linear combination
        sartSino = newSinogramGrid();
        gdSino = newSinogramGrid();
        diffSino = newSinogramGrid();
        varGrid = newSinogramGrid();
        sartSino2 = newSinogramGrid();
        projError = newSinogramGrid();
        // prepare for gpu calculations
        volCL.getDelegate().prepareForDeviceOperation();
        gdVol.getDelegate().prepareForDeviceOperation();
        gradX.getDelegate().prepareForDeviceOperation();
        gradY.getDelegate().prepareForDeviceOperation();
        gradZ.getDelegate().prepareForDeviceOperation();
        gradMag.getDelegate().prepareForDeviceOperation();
        gradMagY.getDelegate().prepareForDeviceOperation();
        gradMagZ.getDelegate().prepareForDeviceOperation();
        divX.getDelegate().prepareForDeviceOperation();
        divY.getDelegate().prepareForDeviceOperation();
        divZ.getDelegate().prepareForDeviceOperation();
        upd.getDelegate().prepareForDeviceOperation();
        volN.getDelegate().prepareForDeviceOperation();
        sartSino.getDelegate().prepareForDeviceOperation();
        gdSino.getDelegate().prepareForDeviceOperation();
        oProj.getDelegate().prepareForDeviceOperation();
        varGrid.getDelegate().prepareForDeviceOperation();
        sartSino2.getDelegate().prepareForDeviceOperation();
        diffSino.getDelegate().prepareForDeviceOperation();
        projError.getDelegate().prepareForDeviceOperation();
    }

    /**
     * Runs {@code iter} eTV iterations (see the class documentation) and stores the result, with
     * origin and spacing taken from {@code geo}, in {@code volumeResult}.
     *
     * <p>All work grids and the inherited command queue are released when this method returns,
     * including when it throws, so this method can be called only once per instance.
     *
     * <p>NOTE(review): the work volume is copied from {@code volCL} <em>before</em>
     * {@code super.iterate()}, so the gradient descent starts from the pre-SART volume while the
     * blend uses the post-SART volume. Confirm this is the intended order.
     *
     * @param iter number of eTV iterations
     * @throws Exception if any OpenCL or projection operation fails
     */
    public void iterateETV(int iter) throws Exception {
        try {
            configure();
            for(int i=0; i<iter; ++i){
                if(debugEtv) timeStart = System.currentTimeMillis();
                queueIterative.putCopyBuffer(volCL.getDelegate().getCLBuffer(), gdVol.getDelegate().getCLBuffer()).finish();
                gdVol.getDelegate().notifyDeviceChange();
                //SART iteration
                super.iterate();
                iterateGradientDescent(gdVol,gdIter);
                float lambda = (float)getLinCombination(volCL, gdVol);
                // keep descending while the blend weight is above the threshold
                for(int ii=0; lambda>DYN_GD_MAX_LAMBDA && ii<DYN_GD_MAX_ITER; ++ii){
                    iterateGradientDescent(gdVol,gdIter);
                    lambda = (float)getLinCombination(volCL, gdVol);
                }
                if(Double.isNaN(lambda) || Double.isInfinite(lambda)){
                    continue; // if vol is empty use SART only
                } else if(1 < lambda) {
                    lambda = 1;
                }
                /* volume update */
                // NOTE: (1-L)*vol + L*gdVol
                volCL.getGridOperator().multiplyBy(volCL, (float)(1.0-lambda));
                gdVol.getGridOperator().multiplyBy(gdVol, (lambda));
                volCL.getGridOperator().addBySave(volCL, gdVol);
                if(debugEtv) {
                    timeEnd = System.currentTimeMillis() - timeStart;
                    System.out.format("Time gradient descent: %.5f seconds\n", ((double) timeEnd) / 1000.0);
                    System.out.println(volCL.getGridOperator().normL1(volCL));
                }
            }
            volumeResult = new Grid3D(volCL);
            volumeResult.setOrigin(-geo.getOriginInPixelsX(),-geo.getOriginInPixelsY(),-geo.getOriginInPixelsZ());
            volumeResult.setSpacing(geo.getVoxelSpacingX(),geo.getVoxelSpacingY(),geo.getVoxelSpacingZ());
        } finally {
            // release GPU memory on success and on failure alike
            unloadEtv();
        }
    }

    /**
     * Computes the blend weight lambda for {@code (1 - lambda) * volCL + lambda * gdVol}.
     *
     * <p>With P = {@code oProj}, S = forward projection of {@code volCL} and G = forward projection of
     * {@code gdVol} (invalid values filled), the code forms
     * a = sum((P - S)(G - S)), b = -sum((G - S)^2), c = -sum((P - S)^2) and the target error
     * E = (1 - omega) * sum((P - S)^2) + omega * sum((P - G)^2). It returns the root
     * (sqrt(a^2 - b(E + c)) + a) / (-b) of b*lambda^2 + 2a*lambda + (E + c) = 0. Assuming a linear
     * forward projector, this is the lambda at which the squared projection error of the blend equals E.
     *
     * <p>The result may be NaN or infinite, e.g. for a negative discriminant or b == 0. The caller
     * checks for this.
     *
     * @param volCL volume after the SART step (shadows the inherited field of the same name)
     * @param gdVol volume after gradient descent (shadows the field of the same name)
     * @return the blend weight lambda
     * @throws Exception if projection or a grid operation fails
     */
    private double getLinCombination(final OpenCLGrid3D volCL, final OpenCLGrid3D gdVol) throws Exception {
        // creating forward projections of the volumes that have to be compared
        cbp.fastProjectRayDrivenCL(sartSino,volCL);
        cbp.fastProjectRayDrivenCL(gdSino, gdVol);
        sartSino.getGridOperator().fillInvalidValues(sartSino);
        gdSino.getGridOperator().fillInvalidValues(gdSino);
        double definedError = (1-omega)*getProjError(sartSino) + omega*getProjError(gdSino);
        // a = sum((oProj - sartSino) .* (gdSino - sartSino))
        queueIterative.putCopyBuffer(gdSino.getDelegate().getCLBuffer(), diffSino.getDelegate().getCLBuffer())
            .putCopyBuffer(oProj.getDelegate().getCLBuffer(), varGrid.getDelegate().getCLBuffer())
            .finish();
        diffSino.getDelegate().notifyDeviceChange();
        varGrid.getDelegate().notifyDeviceChange();
        diffSino.getGridOperator().subtractBy(diffSino, sartSino);
        varGrid.getGridOperator().subtractBy(varGrid, sartSino);
        varGrid.getGridOperator().multiplyBy(varGrid, diffSino);
        double a = varGrid.getGridOperator().sum(varGrid);
        // b = sum((2*sartSino - gdSino) .* gdSino - sartSino.^2)
        queueIterative.putCopyBuffer(sartSino.getDelegate().getCLBuffer(), varGrid.getDelegate().getCLBuffer())
            .putCopyBuffer(sartSino.getDelegate().getCLBuffer(), sartSino2.getDelegate().getCLBuffer()).finish();
        varGrid.getDelegate().notifyDeviceChange();
        sartSino2.getDelegate().notifyDeviceChange();
        varGrid.getGridOperator().multiplyBy(varGrid, 2);
        varGrid.getGridOperator().subtractBy(varGrid, gdSino);
        varGrid.getGridOperator().multiplyBy(varGrid, gdSino);
        sartSino2.getGridOperator().pow(sartSino2, 2);
        varGrid.getGridOperator().subtractBy(varGrid, sartSino2);
        double b = varGrid.getGridOperator().sum(varGrid);
        // c = sum((2*sartSino - oProj) .* oProj - sartSino.^2), reusing sartSino2 = sartSino.^2
        queueIterative.putCopyBuffer(sartSino.getDelegate().getCLBuffer(), varGrid.getDelegate().getCLBuffer()).finish();
        varGrid.getDelegate().notifyDeviceChange();
        varGrid.getGridOperator().multiplyBy(varGrid, 2);
        varGrid.getGridOperator().subtractBy(varGrid, oProj);
        varGrid.getGridOperator().multiplyBySave(varGrid, oProj);
        varGrid.getGridOperator().subtractBy(varGrid, sartSino2);
        double c = varGrid.getGridOperator().sum(varGrid);
        return (Math.sqrt(a*a -b*(definedError + c)) + a) / (-b);
    }

    /**
     * Returns {@code sum((oProj - sino)^2)}, using {@code projError} as scratch space.
     *
     * @param sino forward projection to compare against {@code oProj}
     * @return squared projection error
     * @throws Exception if a grid operation fails
     */
    private double getProjError(OpenCLGrid3D sino) throws Exception {
        queueIterative.putCopyBuffer(oProj.getDelegate().getCLBuffer(), projError.getDelegate().getCLBuffer()).finish();
        projError.getDelegate().notifyDeviceChange();
        projError.getGridOperator().subtractBySave(projError, sino);
        projError.getGridOperator().pow(projError, 2);
        return projError.getGridOperator().sum(projError);
    }

    /**
     * Performs up to {@code iter} TV gradient-descent steps on {@code gdVol}, in place.
     *
     * <p>Each step does the following:
     * <ul>
     * <li>computes the finite-difference gradient G of the volume with the grid operator's
     *     {@code gradX/Y/Z};</li>
     * <li>computes the reference TV norm {@code ||sqrt(|G|^2 + regul^2)||_1};</li>
     * <li>forms the candidate {@code volN = gdVol + stepsize * div(G / sqrt(|G|^2 + regul^2))}.</li>
     * </ul>
     * A candidate is rejected when it has more negative voxels than {@code gdVol}, or when its TV
     * norm exceeds the reference. On rejection the step size is multiplied by
     * {@code STEPSIZE_FACTOR} and the candidate is recomputed. At most {@code GD_STEPSIZE_MAXITER}
     * candidates are examined per step. The method stops early once the step size falls below 1e-9.
     *
     * <p>NOTE(review): a TV rejection also replaces the reference TV norm with the rejected
     * candidate's norm. A candidate produced by the last allowed rejection is accepted without being
     * checked.
     *
     * @param gdVol volume to update (shadows the field of the same name)
     * @param iter  number of gradient-descent steps
     * @throws Exception if a grid operation fails
     */
    private void iterateGradientDescent(OpenCLGrid3D gdVol, int iter) throws Exception {
        boolean reduceStepsizeAndUpdate = false;
        for (int i = 0; i < iter; ++i) {
            if (stepsize < Math.pow(10, -9))
                break;
            // G
            boolean offsetleft = true;
            queueIterative.putCopyBuffer(gdVol.getDelegate().getCLBuffer(), gradX.getDelegate().getCLBuffer())
                .putCopyBuffer(gdVol.getDelegate().getCLBuffer(), gradY.getDelegate().getCLBuffer())
                .putCopyBuffer(gdVol.getDelegate().getCLBuffer(), gradZ.getDelegate().getCLBuffer()).finish();
            gradX.getDelegate().notifyDeviceChange();
            gradY.getDelegate().notifyDeviceChange();
            gradZ.getDelegate().notifyDeviceChange();
            gradX.getGridOperator().gradX(gradX, gdVol,-1,offsetleft);
            gradY.getGridOperator().gradY(gradY, gdVol,-1,offsetleft);
            gradZ.getGridOperator().gradZ(gradZ, gdVol,-1,offsetleft);
            //not yet implemented in CL
            int numNegVol = gdVol.getGridOperator().countNegativeElements(gdVol);
            // gradMagnitude = sqrt(sum(G.^2,4) + regularization.^2)
            queueIterative.putCopyBuffer(gradX.getDelegate().getCLBuffer(), gradMag.getDelegate().getCLBuffer())
                .putCopyBuffer(gradY.getDelegate().getCLBuffer(), gradMagY.getDelegate().getCLBuffer())
                .putCopyBuffer(gradZ.getDelegate().getCLBuffer(), gradMagZ.getDelegate().getCLBuffer()).finish();
            gradMag.getDelegate().notifyDeviceChange();
            gradMagY.getDelegate().notifyDeviceChange();
            gradMagZ.getDelegate().notifyDeviceChange();
            gradMag.getGridOperator().pow(gradMag, 2);
            gradMagY.getGridOperator().pow(gradMagY, 2);
            gradMagZ.getGridOperator().pow(gradMagZ, 2);
            gradMag.getGridOperator().addBy(gradMag, gradMagY);
            gradMag.getGridOperator().addBy(gradMag, gradMagZ);
            gradMag.getGridOperator().addBy(gradMag, regul * regul);
            gradMag.getGridOperator().pow(gradMag, 0.5);
            double tvNorm = gradMag.getGridOperator().normL1(gradMag);
            // upd = divergence(G ./ gradMagnitude) * stepsize;
            // volN = myVol + upd
            // normalized gradients
            gradX.getGridOperator().divideBySave(gradX, gradMag);
            gradY.getGridOperator().divideBySave(gradY, gradMag);
            gradZ.getGridOperator().divideBySave(gradZ, gradMag);
            offsetleft = false; // offsetRight
            int offsetvalue = 1;
            // divergence of the normalized gradient field
            queueIterative.putCopyBuffer(gradX.getDelegate().getCLBuffer(), divX.getDelegate().getCLBuffer())
                .putCopyBuffer(gradY.getDelegate().getCLBuffer(), divY.getDelegate().getCLBuffer())
                .putCopyBuffer(gradZ.getDelegate().getCLBuffer(), divZ.getDelegate().getCLBuffer()).finish();
            divX.getDelegate().notifyDeviceChange();
            divY.getDelegate().notifyDeviceChange();
            divZ.getDelegate().notifyDeviceChange();
            gradX.getGridOperator().divergence(gradX,divX,offsetvalue,0,0,offsetleft);
            gradY.getGridOperator().divergence(gradY,divY,0,offsetvalue,0,offsetleft);
            gradZ.getGridOperator().divergence(gradZ,divZ,0,0,offsetvalue,offsetleft);
            queueIterative.putCopyBuffer(gradX.getDelegate().getCLBuffer(), upd.getDelegate().getCLBuffer()).finish();
            // upd is the copy destination, so it (not gradX) is the grid whose device data changed
            upd.getDelegate().notifyDeviceChange();
            upd.getGridOperator().addBySave(upd, gradY);
            upd.getGridOperator().addBySave(upd, gradZ);
            upd.getGridOperator().multiplyBy(upd, stepsize);
            queueIterative.putCopyBuffer(gdVol.getDelegate().getCLBuffer(), volN.getDelegate().getCLBuffer()).finish();
            volN.getDelegate().notifyDeviceChange();
            volN.getGridOperator().addBy(volN, upd);
            // step-size control: the first candidate is always examined, later ones only after a rejection
            for(int ii=0; 0 == ii || (reduceStepsizeAndUpdate && ii<GD_STEPSIZE_MAXITER) ; ++ii){
                int numNegVolN = volN.getGridOperator().countNegativeElements(volN);
                if(numNegVolN > numNegVol){
                    // candidate introduces negative voxels: shrink the step and retry
                    upd.getGridOperator().multiplyBy(upd, STEPSIZE_FACTOR);
                    stepsize *= STEPSIZE_FACTOR;
                    queueIterative.putCopyBuffer(gdVol.getDelegate().getCLBuffer(), volN.getDelegate().getCLBuffer()).finish();
                    volN.getDelegate().notifyDeviceChange();
                    volN.getGridOperator().addBy(volN, upd);
                    reduceStepsizeAndUpdate = true;
                    continue;
                }
                // check stepsize
                // GN
                offsetleft = true;
                queueIterative.putCopyBuffer(volN.getDelegate().getCLBuffer(), gradX.getDelegate().getCLBuffer())
                    .putCopyBuffer(volN.getDelegate().getCLBuffer(), gradY.getDelegate().getCLBuffer())
                    .putCopyBuffer(volN.getDelegate().getCLBuffer(), gradZ.getDelegate().getCLBuffer()).finish();
                gradX.getDelegate().notifyDeviceChange();
                gradY.getDelegate().notifyDeviceChange();
                gradZ.getDelegate().notifyDeviceChange();
                gradX.getGridOperator().gradX(gradX, volN,-1,offsetleft);
                gradY.getGridOperator().gradY(gradY, volN,-1,offsetleft);
                gradZ.getGridOperator().gradZ(gradZ, volN,-1,offsetleft);
                // gradMagnitudeN = sqrt(sum(GN.^2,4) + regularization.^2)
                queueIterative.putCopyBuffer(gradX.getDelegate().getCLBuffer(), gradMag.getDelegate().getCLBuffer())
                    .putCopyBuffer(gradY.getDelegate().getCLBuffer(), gradMagY.getDelegate().getCLBuffer())
                    .putCopyBuffer(gradZ.getDelegate().getCLBuffer(), gradMagZ.getDelegate().getCLBuffer()).finish();
                gradMag.getDelegate().notifyDeviceChange();
                gradMagY.getDelegate().notifyDeviceChange();
                gradMagZ.getDelegate().notifyDeviceChange();
                gradMag.getGridOperator().pow(gradMag, 2);
                gradMagY.getGridOperator().pow(gradMagY, 2);
                gradMagZ.getGridOperator().pow(gradMagZ, 2);
                gradMag.getGridOperator().addBySave(gradMag, gradMagY);
                gradMag.getGridOperator().addBySave(gradMag, gradMagZ);
                gradMag.getGridOperator().addBy(gradMag, regul * regul);
                gradMag.getGridOperator().pow(gradMag, 0.5);
                double tvNormN = gradMag.getGridOperator().normL1(gradMag);
                if (tvNormN > tvNorm) {
                    // candidate increases the TV norm: shrink the step and retry
                    upd.getGridOperator().multiplyBySave(upd, STEPSIZE_FACTOR);
                    stepsize *= STEPSIZE_FACTOR;
                    queueIterative.putCopyBuffer(gdVol.getDelegate().getCLBuffer(), volN.getDelegate().getCLBuffer()).finish();
                    volN.getDelegate().notifyDeviceChange();
                    volN.getGridOperator().addBySave(volN, upd);
                    tvNorm = tvNormN;
                    reduceStepsizeAndUpdate = true;
                    continue;
                }
                reduceStepsizeAndUpdate = false;
            }
            // accept the candidate
            queueIterative.putCopyBuffer(volN.getDelegate().getCLBuffer(), gdVol.getDelegate().getCLBuffer()).finish();
            gdVol.getDelegate().notifyDeviceChange();
        }
    }

    /**
     * Releases the inherited command queue (if not already released) and every work grid allocated
     * in {@link #configure()}. Each field is cleared afterwards, so calling this method again does
     * not release a grid twice.
     */
    private void unloadEtv(){
        if(queueIterative != null && !queueIterative.isReleased()) queueIterative.release();
        // grids for gradient descent iterations
        gdVol = releaseGrid(gdVol);
        gradX = releaseGrid(gradX);
        gradY = releaseGrid(gradY);
        gradZ = releaseGrid(gradZ);
        gradMag = releaseGrid(gradMag);
        gradMagY = releaseGrid(gradMagY);
        gradMagZ = releaseGrid(gradMagZ);
        divX = releaseGrid(divX);
        divY = releaseGrid(divY);
        divZ = releaseGrid(divZ);
        upd = releaseGrid(upd);
        volN = releaseGrid(volN);
        // grids for linear combinations
        sartSino = releaseGrid(sartSino);
        gdSino = releaseGrid(gdSino);
        diffSino = releaseGrid(diffSino);
        varGrid = releaseGrid(varGrid);
        sartSino2 = releaseGrid(sartSino2);
        projError = releaseGrid(projError);
    }

    /**
     * Releases {@code grid} if it is non-null.
     *
     * @param grid grid to release, may be null
     * @return always {@code null}, so a field can be released and cleared in one assignment
     */
    private static OpenCLGrid3D releaseGrid(OpenCLGrid3D grid) {
        if (grid != null) grid.release();
        return null;
    }

    /**
     * Demo: forward-projects a numerical Shepp-Logan phantom with the globally configured
     * trajectory, reconstructs it with two eTV iterations and shows the result in ImageJ.
     */
    public static void main(String[] args){
        Configuration.loadConfiguration();
        Configuration conf = Configuration.getGlobalConfiguration();
        Trajectory traj = conf.getGeometry();
        new ImageJ();
        ConeBeamProjector cbp = new ConeBeamProjector();
        OpenCLGrid3D grid = new OpenCLGrid3D(new NumericalSheppLogan3D(traj.getReconDimensionX(),traj.getReconDimensionY(), traj.getReconDimensionZ()).getNumericalSheppLoganPhantom());
        grid.setOrigin(-traj.getOriginInPixelsX(),-traj.getOriginInPixelsY(),-traj.getOriginInPixelsZ());
        grid.setSpacing(traj.getVoxelSpacingX(),traj.getVoxelSpacingY(),traj.getVoxelSpacingZ());
        try {
            final float omega = 0.3f;
            final int gdIter = 25;
            final float regul = (float) Math.pow(10, -4);
            final float initStepsize = 0.3f;
            final int eTVIterations = 2;
            final float beta = 0.8f;
            OpenCLGrid3D sino = new OpenCLGrid3D(new Grid3D(traj.getDetectorWidth(),traj.getDetectorHeight(),traj.getProjectionStackSize()));
            sino.setOrigin(0,0,0);
            sino.setSpacing(1,1,1);
            cbp.fastProjectRayDrivenCL(sino,grid);
            // L1 norm of the ground-truth phantom
            System.out.println("GT: "+sino.getGridOperator().normL1(grid));
            Etv etvtest = new Etv(grid.getSize(),grid.getSpacing(),grid.getOrigin(), sino, beta, omega, gdIter, regul,initStepsize);
            long start = System.currentTimeMillis();
            etvtest.iterateETV(eTVIterations);
            long ende = System.currentTimeMillis()-start;
            System.out.format("Time iTV: %.5f seconds\n", ((double) ende) / 1000.0);
            Grid3D foo = etvtest.getVol();
            System.out.println("L1: "+foo.getGridOperator().normL1(foo));
            System.out.println("RMSE: "+foo.getGridOperator().rmse(foo,grid));
            foo.show();
        } catch (Exception e) {
            // demo entry point: report the failure and return
            e.printStackTrace();
        }
    }
}