package edu.stanford.rsl.tutorial.iterative;
import ij.ImageJ;
import java.nio.FloatBuffer;
import com.jogamp.opencl.CLBuffer;
import com.jogamp.opencl.CLCommandQueue;
import com.jogamp.opencl.CLContext;
import edu.stanford.rsl.conrad.data.numeric.Grid2D;
import edu.stanford.rsl.conrad.data.numeric.Grid3D;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid2D;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid3D;
import edu.stanford.rsl.conrad.geometry.shapes.simple.PointND;
import edu.stanford.rsl.conrad.geometry.trajectories.HelicalTrajectory;
import edu.stanford.rsl.conrad.geometry.trajectories.Trajectory;
import edu.stanford.rsl.conrad.numerics.SimpleVector;
import edu.stanford.rsl.conrad.opencl.OpenCLUtil;
import edu.stanford.rsl.conrad.phantom.NumericalSheppLogan3D;
import edu.stanford.rsl.conrad.utils.Configuration;
import edu.stanford.rsl.tutorial.cone.ConeBeamBackprojector;
import edu.stanford.rsl.tutorial.cone.ConeBeamProjector;
/**
 * OpenCL implementation of SART (simultaneous algebraic reconstruction technique) for the
 * cone-beam geometry stored in the global CONRAD {@link Configuration}.
 * <p>
 * Each call to {@link #iterate(int)} visits every projection once per iteration. For a
 * projection {@code p} the current volume is forward projected, the measured projection is
 * subtracted, the residual is divided by the forward projection of an all-ones volume,
 * backprojected, scaled by {@code -beta}, divided by a constant voxel normalisation grid and
 * added to the volume; negative voxels are then removed.
 */
public class SartCL implements Sart {
	/** When {@code true}, {@link #iterate(int)} prints timing and the volume's L1 norm. */
	boolean debug = false;

	// control structures
	protected CLContext context = null;

	// geometry (filled from the global configuration by initCLDatastructure)
	protected Trajectory geo = null;
	protected int maxProjs;
	protected int width;
	protected int height;
	protected int[] imgSize;
	protected double[] spacing;
	protected double[] origin;
	private long starttime = 0;

	// grids
	protected OpenCLGrid3D volCL;
	private OpenCLGrid3D updBP;
	private OpenCLGrid2D sinoCL;
	protected OpenCLGrid3D normSino;
	protected OpenCLGrid3D normGrid;
	protected OpenCLGrid3D oProj = null;
	private OpenCLGrid2D[] oProjP;
	private OpenCLGrid2D[] normSinoP;
	protected CLCommandQueue queueIterative;

	// memory for fast calculations (not referenced by this class itself)
	protected CLBuffer<FloatBuffer> projMatrices = null;
	protected CLBuffer<FloatBuffer> gInvARmatrix = null;
	protected CLBuffer<FloatBuffer> gSrcPoint = null;

	// buffer for 3D volume:
	protected Grid3D volumeResult;

	// sart variables
	protected final float beta;
	/** Ratio of the L1 norms of the measured data and the normalisation sinogram. */
	protected float normFactor;

	// projectors
	protected ConeBeamProjector cbp;
	protected ConeBeamBackprojector cbbp;

	/** Offset added to the normalisation sinogram when it contains non-positive values. */
	float eps = (float) 1.0e-10;

	/**
	 * Creates a SART reconstructor whose initial volume is filled with zeros.
	 * <p>
	 * Detector size, number of projections and reconstruction grid size are read from the
	 * global {@link Configuration}.
	 *
	 * @param volDims number of voxels in x, y and z; must have at least three entries, each >= 1
	 * @param spacing voxel spacing assigned to the initial volume
	 * @param origin origin assigned to the initial volume; its negation is used as the origin
	 *        of the internal update grid
	 * @param oProj measured projection stack; one 2D sub-grid per projection index
	 *        {@code 0..maxProjs-1} is taken from it
	 * @param beta relaxation factor applied to every update step
	 * @throws Exception if no projection data is given, {@code volDims} is invalid, or an
	 *         OpenCL operation fails
	 */
	public SartCL(int[] volDims, double[] spacing, double[] origin, OpenCLGrid3D oProj, float beta) throws Exception {
		if (null == oProj) {
			throw new Exception("SART: No projection data given");
		}
		// also reject null / too short arrays with the documented exception instead of NPE/AIOOBE
		if (null == volDims || volDims.length < 3
				|| 1 > volDims[0] || 1 > volDims[1] || 1 > volDims[2]) {
			throw new Exception(
					"SART: Span of each dimension in the volume has to be a natural number");
		}
		initCLDatastructure();
		// create initial volume filled with zeros
		volCL = new OpenCLGrid3D(new Grid3D(volDims[0], volDims[1], volDims[2]));
		volCL.setOrigin(origin);
		volCL.setSpacing(spacing);
		cbp = new ConeBeamProjector();
		cbbp = new ConeBeamBackprojector();
		this.normSino = createNormProj();
		this.normGrid = createNormGrid();
		this.oProj = new OpenCLGrid3D(oProj);
		this.oProj.getDelegate().prepareForDeviceOperation();
		// copy the caller's device-side data into our own grid
		queueIterative.putCopyBuffer(oProj.getDelegate().getCLBuffer(), this.oProj.getDelegate().getCLBuffer()).finish();
		/* calculated once for speedup */
		this.normFactor = (float) (this.oProj.getGridOperator().normL1(this.oProj) / normSino.getGridOperator().normL1(normSino));
		this.beta = beta;
		// update grid uses the caller-supplied origin/spacing; per-projection views come from the caller's grid
		initWorkingGrids(origin, spacing, oProj);
	}

	/**
	 * Creates a SART reconstructor that starts from the given volume.
	 * <p>
	 * Detector size, number of projections and reconstruction grid size are read from the
	 * global {@link Configuration}.
	 *
	 * @param initialVol initial volume estimate; its origin and spacing are copied
	 * @param sino measured projection stack; one 2D sub-grid per projection index
	 *        {@code 0..maxProjs-1} is taken from it
	 * @param beta relaxation factor applied to every update step
	 * @throws Exception if either argument is {@code null} or an OpenCL operation fails
	 */
	public SartCL(Grid3D initialVol, Grid3D sino, float beta) throws Exception {
		if (null == initialVol) {
			throw new Exception("SART: No initial volume given");
		}
		if (null == sino) {
			throw new Exception("SART: No projection data given");
		}
		initCLDatastructure();
		volCL = new OpenCLGrid3D(initialVol);
		volCL.setOrigin(initialVol.getOrigin());
		volCL.setSpacing(initialVol.getSpacing());
		cbp = new ConeBeamProjector();
		cbbp = new ConeBeamBackprojector();
		this.normSino = createNormProj();
		this.normGrid = createNormGrid();
		this.oProj = new OpenCLGrid3D(sino);
		// NOTE: unlike the other constructor, normFactor is not computed here and stays 0.
		this.beta = beta;
		// update grid uses the configuration-derived origin/spacing fields
		initWorkingGrids(this.origin, this.spacing, this.oProj);
	}

	/**
	 * Allocates the per-projection working grids shared by both constructors: the simulated
	 * projection buffer, the backprojection update volume and 2D views of the measured and
	 * normalisation projections.
	 * <p>
	 * NOTE(review): the update grid's origin is the negation of {@code updOrigin}. The first
	 * constructor passes the caller's origin, the second the field {@code origin}
	 * (= -geo.getOriginX/Y/Z()), so the two constructors may end up with different update-grid
	 * origins — confirm which convention the backprojector expects.
	 *
	 * @param updOrigin origin whose negation is assigned to the update grid
	 * @param updSpacing spacing assigned to the update grid
	 * @param projSource projection stack from which the per-projection 2D grids are taken
	 */
	private void initWorkingGrids(double[] updOrigin, double[] updSpacing, OpenCLGrid3D projSource) throws Exception {
		sinoCL = new OpenCLGrid2D(new Grid2D(width, height));
		sinoCL.setSpacing(geo.getPixelDimensionX(), geo.getPixelDimensionY());
		updBP = new OpenCLGrid3D(new Grid3D(imgSize[0], imgSize[1], imgSize[2]));
		updBP.setOrigin(-updOrigin[0], -updOrigin[1], -updOrigin[2]);
		updBP.setSpacing(updSpacing[0], updSpacing[1], updSpacing[2]);
		oProjP = new OpenCLGrid2D[maxProjs];
		normSinoP = new OpenCLGrid2D[maxProjs];
		for (int j = 0; j < maxProjs; ++j) {
			oProjP[j] = new OpenCLGrid2D(projSource.getSubGrid(j));
			normSinoP[j] = new OpenCLGrid2D(normSino.getSubGrid(j));
		}
	}

	/**
	 * Forward projects an all-ones volume of the current volume's size.
	 * If the result has a minimum <= 0, {@link #eps} is added to every value to reduce the
	 * risk of dividing by zero in the update step.
	 *
	 * @return normalisation projection stack of size width x height x maxProjs
	 * @throws Exception if an OpenCL operation fails
	 */
	protected OpenCLGrid3D createNormProj() throws Exception {
		int[] volSize = volCL.getSize();
		OpenCLGrid3D onesVol = new OpenCLGrid3D(new Grid3D(volSize[0], volSize[1], volSize[2]));
		OpenCLGrid3D sino = new OpenCLGrid3D(new Grid3D(width, height, maxProjs));
		try {
			onesVol.getGridOperator().fill(onesVol, 1);
			cbp.fastProjectRayDrivenCL(sino, onesVol);
		} finally {
			// release the temporary volume even if the projection fails
			onesVol.release();
		}
		// prevent div by zero
		float min = sino.getGridOperator().min(sino);
		if (0 >= min) {
			sino.getGridOperator().addBy(sino, eps);
		}
		return sino;
	}

	/**
	 * Builds the voxel normalisation grid: an all-ones projection is backprojected for
	 * projection index 0 only, and the grid is then filled with the mean of that result, so
	 * the returned grid is constant.
	 *
	 * @return constant normalisation volume of size imgSize
	 * @throws Exception if an OpenCL operation fails
	 */
	protected OpenCLGrid3D createNormGrid() throws Exception {
		OpenCLGrid3D normGrid = new OpenCLGrid3D(new Grid3D(imgSize[0], imgSize[1], imgSize[2]));
		OpenCLGrid2D sinoCLOnes = new OpenCLGrid2D(new Grid2D(width, height));
		try {
			sinoCLOnes.getGridOperator().fill(sinoCLOnes, 1.0f);
			normGrid.getDelegate().prepareForDeviceOperation();
			cbbp.fastBackprojectPixelDrivenCL(sinoCLOnes, normGrid, 0);
			// NOTE(review): a zero mean here would make the later divideBy a division by zero.
			normGrid.getGridOperator().fill(normGrid, (float) normGrid.getGridOperator().sum(normGrid) / (normGrid.getNumberOfElements()));
		} finally {
			// release the temporary projection even if the backprojection fails
			sinoCLOnes.release();
		}
		return normGrid;
	}

	/**
	 * Runs a single SART iteration.
	 *
	 * @throws Exception if an OpenCL operation fails
	 */
	public final void iterate() throws Exception {
		iterate(1);
	}

	/**
	 * Runs {@code iter} SART iterations. Each iteration updates the volume once per
	 * projection; consecutive projections are chosen roughly a quarter of the stack apart
	 * (see {@link #nextProjection}). Invalid voxel values are replaced by 0 at the end.
	 *
	 * @param iter number of iterations to run
	 * @throws Exception if an OpenCL operation fails
	 */
	public final void iterate(final int iter) throws Exception {
		for (int i = 0; i < iter; ++i) {
			if (debug) {
				System.out.println("Starting Sart Iteration " + i);
				starttime = System.currentTimeMillis();
			}
			boolean[] projIsUsed = new boolean[maxProjs]; // default: false
			int p = 0; // current projection index, carried over between updates
			for (int n = 0; n < maxProjs; ++n) {
				projIsUsed[p] = true;
				cbp.fastProjectRayDrivenCL(sinoCL, volCL, p);
				/* update step */
				// sinoCL = (sino - oProj) ./ normSino; the sign is flipped by -beta below,
				// so the effective update is beta * (oProj - sino) ./ normSino
				sinoCL.getGridOperator().subtractBy(sinoCL, oProjP[p]);
				sinoCL.getGridOperator().divideBySave(sinoCL, normSinoP[p]);
				updBP.getGridOperator().fill(updBP, 0);
				cbbp.fastBackprojectPixelDrivenCL(sinoCL, updBP, p);
				updBP.getGridOperator().multiplyBySave(updBP, -beta);
				updBP.getGridOperator().divideBy(updBP, normGrid);
				volCL.getGridOperator().addBy(volCL, updBP);
				volCL.getGridOperator().removeNegative(volCL);
				// Don't use projections with a small angle to each other subsequently.
				// (Previously p was reset to n at the top of the loop, which discarded this.)
				p = nextProjection(projIsUsed, p, maxProjs);
			}
			if (debug) {
				long endtime = System.currentTimeMillis() - starttime;
				System.out.print("Iter " + i + ": ");
				System.out.format("%.3f seconds\n", ((double) endtime) / 1000.0);
				System.out.println(volCL.getGridOperator().normL1(volCL));
				System.out.println("End of Sart Iteration ");
			}
		}
		volCL.getGridOperator().fillInvalidValues(volCL, 0);
	}

	/**
	 * Chooses the projection to process after {@code current}: jump {@code count / 4}
	 * projections ahead, then step forward one at a time (wrapping around) until an unused
	 * projection is found or {@code count - 1} steps have been taken.
	 *
	 * @param used flags of projections already processed in this iteration
	 * @param current index of the projection just processed
	 * @param count total number of projections (> 0)
	 * @return index of the next projection
	 */
	private static int nextProjection(boolean[] used, int current, int count) {
		int next = (current + count / 4) % count;
		for (int step = 1; used[next] && step < count; ++step) {
			next = (next + 1) % count;
		}
		return next;
	}

	/**
	 * Builds a new {@link Grid3D} from the current reconstruction and assigns origin and
	 * spacing from the configured geometry. The result is also stored in
	 * {@link #volumeResult}.
	 *
	 * @return the reconstructed volume
	 */
	public Grid3D getVol() {
		volumeResult = new Grid3D(volCL);
		volumeResult.setOrigin(-geo.getOriginInPixelsX(), -geo.getOriginInPixelsY(), -geo.getOriginInPixelsZ());
		volumeResult.setSpacing(geo.getVoxelSpacingX(), geo.getVoxelSpacingY(), geo.getVoxelSpacingZ());
		return volumeResult;
	}

	/**
	 * (Re)loads the global configuration, then reads the detector, projection and
	 * reconstruction-grid parameters from its geometry and creates a command queue on the
	 * static OpenCL context's max-flops device.
	 * <p>
	 * NOTE(review): calling {@code Configuration.loadConfiguration()} here may replace a
	 * geometry set programmatically via {@code conf.setGeometry(...)} before construction
	 * (as {@code main} does in helix mode) — confirm this is intended.
	 */
	private void initCLDatastructure() {
		Configuration.loadConfiguration();
		Configuration conf = Configuration.getGlobalConfiguration();
		geo = conf.getGeometry();
		width = geo.getDetectorWidth();
		height = geo.getDetectorHeight();
		maxProjs = geo.getProjectionStackSize();
		// create context
		context = OpenCLUtil.getStaticContext();
		queueIterative = context.getMaxFlopsDevice().createCommandQueue();
		// image variables
		imgSize = new int[3];
		origin = new double[3];
		spacing = new double[3];
		imgSize[0] = geo.getReconDimensionX();
		imgSize[1] = geo.getReconDimensionY();
		imgSize[2] = geo.getReconDimensionZ();
		spacing[0] = geo.getVoxelSpacingX();
		spacing[1] = geo.getVoxelSpacingY();
		spacing[2] = geo.getVoxelSpacingZ();
		origin[0] = -geo.getOriginX();
		origin[1] = -geo.getOriginY();
		origin[2] = -geo.getOriginZ();
	}

	/**
	 * Demo: projects a numerical Shepp-Logan phantom with the configured geometry (optionally
	 * a helical variant), reconstructs it with SART and prints timing, L1 norm and RMSE
	 * against the phantom before showing the result.
	 *
	 * @param args ignored
	 * @throws Exception if configuration, projection or reconstruction fails
	 */
	public static void main(String[] args) throws Exception {
		boolean helix = false;
		int iterations = 2;
		Configuration.loadConfiguration();
		Configuration conf = Configuration.getGlobalConfiguration();
		Trajectory traj = conf.getGeometry();
		if (helix) {
			traj = new HelicalTrajectory(Configuration.getGlobalConfiguration().getGeometry());
			int stepHel = traj.getNumProjectionMatrices();
			double physicalDetectorHeight = traj.getDetectorHeight() * traj.getPixelDimensionY();
			double stepSize = (physicalDetectorHeight * 0.05 / (((double) stepHel)));
			double volumeZSize = physicalDetectorHeight * 3 * 0.05;
			// set chosen trajectory
			((HelicalTrajectory) traj).setTrajectory(stepHel * 3,
					Configuration.getGlobalConfiguration().getGeometry().getSourceToAxisDistance(),
					traj.getAverageAngularIncrement(),
					traj.getDetectorOffsetU(), traj.getDetectorOffsetV(),
					traj.getDetectorUDirection(), traj.getDetectorVDirection(),
					new SimpleVector(0, 0, 1), new PointND(0, 0, volumeZSize / 2), 0, stepSize);
			conf.setGeometry(traj);
		}
		new ImageJ();
		ConeBeamProjector cbp = new ConeBeamProjector();
		OpenCLGrid3D grid = new OpenCLGrid3D(new NumericalSheppLogan3D(traj.getReconDimensionX(), traj.getReconDimensionY(), traj.getReconDimensionZ()).getNumericalSheppLoganPhantom());
		grid.setOrigin(-traj.getOriginInPixelsX(), -traj.getOriginInPixelsY(), -traj.getOriginInPixelsZ());
		grid.setSpacing(traj.getVoxelSpacingX(), traj.getVoxelSpacingY(), traj.getVoxelSpacingZ());
		System.out.println("GT: " + grid.getGridOperator().normL1(grid));

		OpenCLGrid3D sino = new OpenCLGrid3D(new Grid3D(traj.getDetectorWidth(), traj.getDetectorHeight(), traj.getProjectionStackSize()));
		sino.setOrigin(0, 0, 0);
		sino.setSpacing(1, 1, 1);
		cbp.fastProjectRayDrivenCL(sino, grid);

		// failures propagate (main declares throws Exception) instead of being swallowed
		SartCL sart = new SartCL(grid.getSize(), grid.getSpacing(), grid.getOrigin(), sino, 0.8f);
		long start = System.currentTimeMillis();
		sart.iterate(iterations);
		long ende = System.currentTimeMillis() - start;
		System.out.format("Time Sart: %.5f seconds\n", ((double) ende) / 1000.0);
		Grid3D foo = sart.getVol();
		System.out.println("L1: " + foo.getGridOperator().normL1(foo));
		System.out.println("RMSE: " + foo.getGridOperator().rmse(foo, grid));
		foo.show();
	}
}