package edu.stanford.rsl.tutorial.iterative; import ij.ImageJ; import java.nio.FloatBuffer; import com.jogamp.opencl.CLBuffer; import com.jogamp.opencl.CLContext; import edu.stanford.rsl.conrad.data.numeric.Grid1D; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid1D; import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid2D; import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid3D; import edu.stanford.rsl.conrad.geometry.Projection.CameraAxisDirection; import edu.stanford.rsl.conrad.geometry.trajectories.CircularTrajectory; import edu.stanford.rsl.conrad.geometry.trajectories.Trajectory; import edu.stanford.rsl.conrad.numerics.SimpleVector; import edu.stanford.rsl.conrad.opencl.OpenCLUtil; import edu.stanford.rsl.conrad.phantom.NumericalSheppLogan3D; import edu.stanford.rsl.conrad.utils.Configuration; import edu.stanford.rsl.tutorial.cone.ConeBeamBackprojector; import edu.stanford.rsl.tutorial.cone.ConeBeamProjector; import edu.stanford.rsl.tutorial.fan.FanBeamBackprojector2D; import edu.stanford.rsl.tutorial.fan.FanBeamProjector2D; import edu.stanford.rsl.tutorial.parallel.ParallelBackprojector2D; import edu.stanford.rsl.tutorial.parallel.ParallelProjector2D; import edu.stanford.rsl.tutorial.phantoms.MickeyMouseGrid2D; import edu.stanford.rsl.tutorial.phantoms.Phantom3D; import edu.stanford.rsl.tutorial.phantoms.Sphere3D; /* * TODO: implement parrallel beam reoncstruction */ public class SartCL2D{ //debug boolean debug = false; // controll structures protected CLContext context = null; //geometry protected enum reconGeometry {PARALLEL, FAN}; protected static reconGeometry usedGeometry = reconGeometry.FAN; //geometry protected Trajectory geo = null; protected int width; protected int height; int imgSizeX; int imgSizeY; double spacingX; double spacingY; double originX; double originY; double focalLength; double maxBeta; double deltaBeta; double maxT; double deltaT; int maxBetaIndex; int maxTIndex; long starttime = 0; //grids protected OpenCLGrid2D volCL; private OpenCLGrid2D updBP; private OpenCLGrid2D sinoCL; protected OpenCLGrid2D normSino; protected OpenCLGrid2D normGrid; protected OpenCLGrid2D oProj = null; // memory for fast calculations protected CLBuffer<FloatBuffer> projMatrices = null; protected CLBuffer<FloatBuffer> gInvARmatrix = null; protected CLBuffer<FloatBuffer> gSrcPoint = null; // buffer for 3D volume: protected Grid2D volumeResult; //sart variables protected final float beta; protected float normFactor; //projectors protected FanBeamProjector2D fbp; protected FanBeamBackprojector2D fbbp; protected ParallelProjector2D pbp; protected ParallelBackprojector2D pbbp; float eps = (float)1.0e-10; /** * This constructor takes the following arguments: * @param volDims * @param spacing * @param origin * @param oProj * @param beta * @throws Exception */ public SartCL2D(int[] volDims, double[] spacing, double[] origin, Grid2D oProj, float beta) throws Exception { if (null == oProj) { throw new Exception("SART: No projection data given"); } if (1 > volDims[0] || 1 > volDims[1]) { throw new Exception( "SART: Span of each dimension in the volume has to be a natural number"); } // create initial volume filled with zeros initCLDatastructure(); volCL = new OpenCLGrid2D(new Grid2D(volDims[0], volDims[1])); volCL.setOrigin(origin); volCL.setSpacing(spacing); if (usedGeometry == reconGeometry.FAN){ fbp = new FanBeamProjector2D(focalLength, maxBeta, deltaBeta, maxT, deltaT); fbbp = new FanBeamBackprojector2D(focalLength, deltaT, deltaBeta, imgSizeX, imgSizeY); } else { pbp = new ParallelProjector2D(maxBeta, deltaBeta, maxT, deltaT); //pbbp = new ParallelBackprojector2D(imgSizeX,imgSizeY,pxSzXMM,pxSzYMM); } this.normSino = createNormProj(); this.normGrid = createNormGrid(); this.oProj = new OpenCLGrid2D(oProj); /* calculated once for speedup */ this.normFactor = (float) (this.oProj.getGridOperator().normL1(this.oProj) / normSino.getGridOperator().normL1(normSino)); this.beta = beta; sinoCL = new OpenCLGrid2D(new Grid2D(maxTIndex,maxBetaIndex)); sinoCL.setSpacing(deltaT, deltaBeta); updBP = new OpenCLGrid2D(new Grid2D(imgSizeX,imgSizeY)); updBP.setOrigin(volCL.getOrigin()[0], volCL.getOrigin()[1]); updBP.setSpacing(volCL.getSpacing()[0], volCL.getSpacing()[1]); } public SartCL2D(Grid2D initialVol, Grid2D sino, float beta) throws Exception { if (null == initialVol) { throw new Exception("SART: No initial volume given"); } if (null == sino) { throw new Exception("SART: No projection data given"); } initCLDatastructure(); volCL = new OpenCLGrid2D(initialVol); if (usedGeometry == reconGeometry.FAN){ fbp = new FanBeamProjector2D(focalLength, maxBeta, deltaBeta, maxT, deltaT); fbbp = new FanBeamBackprojector2D(focalLength, deltaT, deltaBeta, imgSizeX, imgSizeY); } else { pbp = new ParallelProjector2D(maxBeta, deltaBeta, maxT, deltaT); //pbbp = new ParallelBackprojector2D(imgSzXMM,imgSzYMM,pxSzXMM,pxSzYMM); } this.normSino = createNormProj(); this.normGrid = createNormGrid(); this.oProj = new OpenCLGrid2D(sino); /* calculated once for speedup*/ this.beta = beta; sinoCL = new OpenCLGrid2D(new Grid2D(maxTIndex,maxBetaIndex)); sinoCL.setSpacing(deltaT, deltaBeta); updBP = new OpenCLGrid2D(new Grid2D(imgSizeX,imgSizeY)); updBP.setOrigin(-originX, -originY); updBP.setSpacing(spacingX, spacingY); } /** * @return Normalized projection data * @throws Exception */ protected OpenCLGrid2D createNormProj() throws Exception { OpenCLGrid2D onesVol = new OpenCLGrid2D(new Grid2D(volCL.getSize()[0],volCL.getSize()[1])); onesVol.getGridOperator().fill(onesVol, 1); OpenCLGrid2D sino = new OpenCLGrid2D(new Grid2D(maxTIndex,maxBetaIndex)); if (usedGeometry == reconGeometry.FAN){ fbp.fastProjectRayDrivenCL(sino,onesVol); } else { //pbp.fastProjectRayDrivenCL(sino,onesVol); } onesVol.release(); // prevent div by zero float min = sino.getGridOperator().min(sino); //sino.show("Sino of norm projection"); if (0 >= min) sino.getGridOperator().addBy(sino,eps); sino.setSpacing(deltaT, deltaBeta); return sino; } protected OpenCLGrid2D createNormGrid() throws Exception { OpenCLGrid2D normGrid = new OpenCLGrid2D(new Grid2D(imgSizeX,imgSizeY)); OpenCLGrid2D sinoCLOnes = new OpenCLGrid2D(new Grid2D(maxTIndex,maxBetaIndex)); sinoCLOnes.getGridOperator().fill(sinoCLOnes, 1.0f); sinoCLOnes.setSpacing(deltaT, deltaBeta); normGrid.getDelegate().prepareForDeviceOperation(); if (usedGeometry == reconGeometry.FAN){ fbbp.fastBackprojectPixelDrivenCL(sinoCLOnes,normGrid); } else { //pbbp.fastBackProjectRayDrivenCL(sinoCLOnes,normGrid, 0); } normGrid.getGridOperator().fill(normGrid, (float)normGrid.getGridOperator().sum(normGrid)/(normGrid.getNumberOfElements())); sinoCLOnes.release(); return normGrid; } public final void iterate() throws Exception{ iterate(1); } public final void iterate(final int iter) throws Exception { for (int i = 0; i < iter; ++i) { if(debug) { System.out.println("Starting Sart Iteration "+i); starttime = System.currentTimeMillis(); } if (usedGeometry == reconGeometry.FAN){ fbp.fastProjectRayDrivenCL(sinoCL, volCL); } else { //pbp.fastProjectRayDrivenCL(sinoCL, volCL, p); } /* update step */ // NOTE: upd = (oProj - sino) ./ normSino sinoCL.getGridOperator().subtractBy(sinoCL, oProj); sinoCL.getGridOperator().divideBySave(sinoCL, normSino); updBP.getGridOperator().fill(updBP, 0); if (usedGeometry == reconGeometry.FAN){ fbbp.fastBackprojectPixelDrivenCL(sinoCL, updBP); } else { //pbbp.fastBackprojectPixelDrivenCL(sinoCL, updBP, p); } // NOTE: vol = vol + updBP * beta updBP.getGridOperator().multiplyBySave(updBP, -beta); updBP.getGridOperator().divideBy(updBP, normGrid); volCL.getGridOperator().addBy(volCL, updBP); volCL.getGridOperator().removeNegative(volCL); } if(debug) { long endtime = System.currentTimeMillis()-starttime; System.out.format("%.3f seconds\n", ((double) endtime) / 1000.0); System.out.println("End of Sart Iteration "); } volCL.getGridOperator().fillInvalidValues(volCL,0); } public Grid2D getVol2D(){ volumeResult = new Grid2D(volCL); volumeResult.setOrigin(-geo.getOriginInPixelsX(),-geo.getOriginInPixelsY()); volumeResult.setSpacing(geo.getVoxelSpacingX(),geo.getVoxelSpacingY()); return volumeResult; } private void initCLDatastructure(){ Configuration conf = Configuration.getGlobalConfiguration(); geo = conf.getGeometry(); width = geo.getDetectorWidth(); // create context context = OpenCLUtil.getStaticContext(); imgSizeX = geo.getReconDimensionX(); imgSizeY = geo.getReconDimensionY(); spacingX = geo.getVoxelSpacingX(); spacingY = geo.getVoxelSpacingY(); originX = -geo.getOriginX(); originY = -geo.getOriginY(); focalLength = geo.getSourceToDetectorDistance(); maxBeta = geo.getAverageAngularIncrement()*geo.getProjectionStackSize()*Math.PI/180.0; deltaBeta = maxBeta/geo.getProjectionStackSize(); maxT = geo.getDetectorWidth(); deltaT = geo.getPixelDimensionX(); maxBetaIndex = (int) (maxBeta / deltaBeta); maxTIndex = (int) (maxT / deltaT); height = maxTIndex; } public static void main(String[] args) { Configuration.loadConfiguration(); Configuration conf = Configuration.getGlobalConfiguration(); CircularTrajectory traj = new CircularTrajectory(conf.getGeometry()); ((CircularTrajectory) traj).setTrajectory( 36, Configuration.getGlobalConfiguration().getGeometry().getSourceToAxisDistance(), 10.0, 0, 0, CameraAxisDirection.DETECTORMOTION_PLUS, CameraAxisDirection.ROTATIONAXIS_PLUS, new SimpleVector(0,0,1)); conf.setGeometry(traj); new ImageJ(); double focalLength = traj.getSourceToDetectorDistance(); double maxBeta = traj.getAverageAngularIncrement()*traj.getProjectionStackSize()*Math.PI/180.0; double deltaBeta = maxBeta/traj.getProjectionStackSize(); double maxT = traj.getDetectorWidth(); double deltaT = traj.getPixelDimensionX(); int maxBetaIndex = (int) (maxBeta / deltaBeta); int maxTIndex = (int) (maxT / deltaT); FanBeamProjector2D fbp = new FanBeamProjector2D(focalLength,maxBeta,deltaBeta,maxT,deltaT); MickeyMouseGrid2D test2D = new MickeyMouseGrid2D(traj.getReconDimensionX(), traj.getReconDimensionY()); OpenCLGrid2D grid = new OpenCLGrid2D(test2D); grid.setSpacing(traj.getVoxelSpacingX(),traj.getVoxelSpacingY()); grid.setOrigin(-(traj.getReconDimensionX() * grid.getSpacing()[0]) / 2, -(traj.getReconDimensionY()* grid.getSpacing()[1]) / 2); System.out.println("GT: "+grid.getGridOperator().normL1(grid)); try { OpenCLGrid2D sino = new OpenCLGrid2D(new Grid2D(maxTIndex,maxBetaIndex)); fbp.fastProjectRayDrivenCL(sino,grid); SartCL2D sart = new SartCL2D(grid.getSize(), grid.getSpacing(), grid.getOrigin(), sino, 0.8f); long start = System.currentTimeMillis(); sart.iterate(100); long ende = System.currentTimeMillis()-start; System.out.format("Time Sart: %.5f seconds\n", ((double) ende) / 1000.0); Grid2D foo = sart.getVol2D(); System.out.println("L1: "+foo.getGridOperator().normL1(foo)); foo.show(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } }