package edu.stanford.rsl.tutorial.iterative; import ij.ImageJ; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.data.numeric.NumericGridOperator; import edu.stanford.rsl.conrad.data.numeric.NumericPointwiseOperators; import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid2D; import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid3D; import edu.stanford.rsl.conrad.filtering.TruncationCorrectionTool; import edu.stanford.rsl.conrad.filtering.redundancy.ParkerWeightingTool; import edu.stanford.rsl.conrad.geometry.trajectories.Trajectory; import edu.stanford.rsl.conrad.phantom.NumericalSheppLogan3D; import edu.stanford.rsl.conrad.utils.Configuration; import edu.stanford.rsl.tutorial.cone.ConeBeamBackprojector; import edu.stanford.rsl.tutorial.cone.ConeBeamCosineFilter; import edu.stanford.rsl.tutorial.cone.ConeBeamProjector; import edu.stanford.rsl.tutorial.fan.FanBeamBackprojector2D; import edu.stanford.rsl.tutorial.fan.FanBeamProjector2D; import edu.stanford.rsl.tutorial.filters.RamLakKernel; import edu.stanford.rsl.tutorial.parallel.ParallelBackprojector2D; import edu.stanford.rsl.tutorial.parallel.ParallelProjector2D; import edu.stanford.rsl.tutorial.phantoms.DotsGrid2D; import edu.stanford.rsl.tutorial.phantoms.MickeyMouseGrid2D; import edu.stanford.rsl.tutorial.phantoms.Phantom; import edu.stanford.rsl.tutorial.phantoms.Phantom3D; import edu.stanford.rsl.tutorial.phantoms.Sphere3D; import edu.stanford.rsl.tutorial.phantoms.TestObject1; import edu.stanford.rsl.tutorial.phantoms.UniformCircleGrid2D; /** * A test class for the computation and display of iterative reconstructions. * * @author Mario Amrehn * */ public class IterativeReconstructionTest { private enum MyPhantom {PSPHERE, PSHEPP}; // ----------------------------------------- private final static boolean USE_CL_FP = true; // GPU acceleration private final static boolean USE_CL_BP = true; // GPU acceleration private final static float sartRelax = 0.8f; private final static int sartIterations = 100; private final static MyPhantom phan = MyPhantom.PSHEPP; // PSHEPP, PSPHERE private final static float omega = 0.8f; private final static int gdIter = 10; private final static float regul = (float) Math.pow(10, -4); private final static float initStepsize = 0.3f; private final static boolean computeFDKReco = true; private final static boolean computeSARTCLReco = true; private final static boolean computeSARTCPUReco = false; private final static boolean computeETVReco = true; private static boolean truncationCorrection = true; // ----------------------------------------- public static void main(String[] args) { new ImageJ(); Configuration.loadConfiguration(); Configuration conf = Configuration.getGlobalConfiguration(); Trajectory geo = conf.getGeometry(); double focalLength = geo.getSourceToDetectorDistance(); // int maxU = geo.getDetectorWidth(); // int maxV = geo.getDetectorHeight(); int maxU_PX = geo.getDetectorWidth(); int maxV_PX = geo.getDetectorHeight(); double deltaU = geo.getPixelDimensionX(); double deltaV = geo.getPixelDimensionY(); double maxU = (maxU_PX) * deltaU; double maxV = (maxV_PX) * deltaV; int imgSizeX = geo.getReconDimensionX(); int imgSizeY = geo.getReconDimensionY(); int imgSizeZ = geo.getReconDimensionZ(); double imgSpacingX = geo.getVoxelSpacingX(); double imgSpacingY = geo.getVoxelSpacingY(); double imgSpacingZ= geo.getVoxelSpacingZ(); double originX = geo.getOriginX(); double originY = geo.getOriginY(); double originZ = geo.getOriginZ(); OpenCLGrid3D grid = new OpenCLGrid3D(getInput(phan, imgSizeX, imgSizeY, imgSizeZ, imgSpacingX, imgSpacingY, imgSpacingZ, originX, originY, originZ)); grid.show("object"); NumericGridOperator gop = grid.getGridOperator(); long timeFBP=1, timeSARTCL=-1, timeSARTCPU=-1, timeETV=-1; Grid3D recImageFBP=null, recImageIterSartCL=null, recImageIterSartCPU=null, recImageIterETV=null; if (computeFDKReco) { timeFBP = System.currentTimeMillis(); recImageFBP = reconstructFBP(grid, focalLength, maxU, maxV, deltaU, deltaV, maxU_PX, maxV_PX, conf, geo); timeFBP = System.currentTimeMillis() - timeFBP; recImageFBP.show("recImageFBP, Min: " + gop.min(recImageFBP) + " Max: " + gop.max(recImageFBP)); } if (computeSARTCLReco) { timeSARTCL = System.currentTimeMillis(); recImageIterSartCL = reconstructSART(grid, sartRelax, sartIterations, true); timeSARTCL = System.currentTimeMillis() - timeSARTCL; recImageIterSartCL.show("recImageIterSART_CL"); } if (computeSARTCPUReco) { timeSARTCPU = System.currentTimeMillis(); recImageIterSartCPU = reconstructSART(grid, sartRelax, sartIterations, false); timeSARTCPU = System.currentTimeMillis() - timeSARTCPU; recImageIterSartCPU.show("recImageIterSART_CPU"); } if (computeETVReco) { timeETV = System.currentTimeMillis(); recImageIterETV = reconstructETV(grid, sartRelax, sartIterations, omega, gdIter, regul, initStepsize); timeETV = System.currentTimeMillis() - timeETV; recImageIterETV.show("recImageIterETV"); } System.out .format("Config (Phantom %s):\n SART: Relax %.2f, Iter %d\n GD: Iter %d, Regul %.2e, Step %.2f\n Omega %.2f\n", phan, sartRelax, sartIterations, gdIter, regul, initStepsize, omega); System.out.println("Errors:"); double rmseFBP = -1, rmseSART_CL = -1, rmseSART_CPU = -1, rmseETV = -1; if (computeFDKReco) { rmseFBP = gop.rmse(recImageFBP, grid); System.out.format(" RMSE(GT,FBP): %.7f\n", rmseFBP); } if (computeSARTCLReco) { rmseSART_CL = gop.rmse(recImageIterSartCL, grid); System.out.format(" RMSE(GT,SART_CL): %.7f\n", rmseSART_CL); } if (computeSARTCPUReco) { rmseSART_CPU = gop.rmse(recImageIterSartCPU, grid); System.out.format(" RMSE(GT,SART_CPU): %.7f\n", rmseSART_CPU); } if (computeETVReco) { rmseETV = gop.rmse(recImageIterETV, grid); System.out.format(" RMSE(GT,eTV): %.7f\n", rmseETV); } System.out.println("Norms:"); double l1GT = -1, l1FBP = -1, l1SART_CL = -1, l1SART_CPU = -1, l1ETV = -1; l1GT = gop.normL1(grid); System.out.format(" l1(GT): %.2f\n", l1GT); if (computeFDKReco) { l1FBP = gop.normL1(recImageFBP); System.out.format(" l1(FBP): %.2f\n", l1FBP); } if (computeSARTCLReco) { l1SART_CL = gop.normL1(recImageIterSartCL); System.out.format(" l1(SART_CL): %.2f\n", l1SART_CL); } if (computeSARTCPUReco) { l1SART_CPU = gop.normL1(recImageIterSartCPU); System.out.format(" l1(SART_CPU): %.2f\n", l1SART_CPU); } if (computeETVReco) { l1ETV = gop.normL1(recImageIterETV); System.out.format(" l1(eTV): %.2f\n", l1ETV); } System.out.println("Time:"); if (computeFDKReco) System.out.format(" time(FBP): %.1f seconds\n", ((double) timeFBP) / 1000.0); if (computeSARTCLReco) System.out.format(" time(SART_CL): %.1f%%\n", ((double) timeSARTCL * 100)/ timeFBP); if (computeSARTCPUReco) System.out.format(" time(SART_CPU): %.1f%%\n", ((double) timeSARTCPU * 100) / timeFBP); if (computeETVReco) System.out.format(" time(eTV): %.1f%%\n", ((double) timeETV * 100) / timeFBP); /* if (computeFDKReco && computeSARTCPUReco && computeSARTCLReco) { System.out .format("Norms:\n l1(GT): %.2f\n l1(FBP): %.2f\n l1(SART_CL): %.2f\n l1(SART:CPU): %.2f\n", GridOp.l1Norm(grid), GridOp.l1Norm(recImageFBP), GridOp.l1Norm(recImageIterSartCL), GridOp.l1Norm(recImageIterSartCPU)); System.out.format( "Recon time: FBP: %.1f seconds. SART_CL: %.1f%%, SART_CPU: %.1f%%", ((double) timeFBP) / 1000.0, ((double) timeSARTCL * 100) / timeFBP, ((double) timeSARTCPU * 100) / timeFBP); } if (computeETVReco && computeFDKReco && computeSARTCPUReco) { System.out .format("Norms:\n l1(GT): %.2f\n l1(FBP): %.2f\n l1(SART): %.2f\n l1(eTV): %.2f\n", GridOp.l1Norm(grid), GridOp.l1Norm(recImageFBP), GridOp.l1Norm(recImageIterSartCPU), GridOp.l1Norm(recImageIterETV)); System.out.format( "Recon time: FBP: %.1f seconds. SART: %.1f%%, eTV: %.1f%%", ((double) timeFBP) / 1000.0, ((double) timeSARTCPU * 100) / timeFBP, ((double) timeETV * 100) / timeFBP); } */ // Config (Phantom PSPHERE): // SART: Relax 0,80, Iter 400 // GD: Iter 40, Regul 1,00e-04, Step 0,30 // Omega 0,80 // Errors: // RMSE(GT,FBP): 0,2315102 // RMSE(GT,SART): 0,3176485 // RMSE(GT,eTV): 0,3176485 // Norms: // l1(GT): 389,00 // l1(FBP): 354,79 // l1(SART): 346,35 // l1(eTV): 346,35 // Recon time: FBP: 3,3 seconds. SART: 13085,2%, eTV: 13409,9% // if (true) // return; } private static Grid3D getInput(MyPhantom phan, int imgSizeX, int imgSizeY, int imgSizeZ, double imgSpacingX, double imgSpacingY, double imgSpacingZ, double originX, double originY, double originZ) { Grid3D res = null; switch (phan) { case PSPHERE: Phantom3D test3D = new Sphere3D(imgSizeX, imgSizeY, imgSizeZ); res = test3D; break; case PSHEPP: NumericalSheppLogan3D shepp3d = new NumericalSheppLogan3D(imgSizeX, imgSizeY, imgSizeZ); res = shepp3d.getNumericalSheppLoganPhantom(); break; default: res = new Sphere3D(imgSizeX, imgSizeY, imgSizeZ); } if(null != res) res.setSpacing(imgSpacingX, imgSpacingY, imgSpacingZ); res.setOrigin(originX, originY, originZ); return res; } private static Grid3D reconstructETV(Grid3D grid, float sartRelax, int eTvIerations, float omega, int gdIter, float regul, float initStepsize) { ConeBeamProjector cbp = new ConeBeamProjector(); Grid3D sino = USE_CL_FP ? cbp.projectRayDrivenCL(grid) : cbp .projectPixelDriven(grid); Etv reconEtv = null; try { reconEtv = new Etv(grid.getSize(), grid.getSpacing(), grid.getOrigin(), new OpenCLGrid3D(sino), sartRelax, omega, gdIter, regul, initStepsize); reconEtv.iterateETV(eTvIerations); } catch (Exception e) { e.printStackTrace(); } if (null == reconEtv) { System.err.println("Error creating eTV reconstruction instance"); return null; } Grid3D recImageIterEtv = reconEtv.getVol(); return recImageIterEtv; } private static Grid3D reconstructSART(Grid3D grid, float sartRelax, int sartIterations, boolean USE_CL_SART) { ConeBeamProjector cbp = new ConeBeamProjector(); Grid3D sino = USE_CL_FP ? cbp.projectRayDrivenCL(grid) : cbp .projectPixelDriven(grid); sino.show("sinoCL-SART"); Sart reconSart = null; try { if(USE_CL_SART) reconSart = new SartCL(grid.getSize(), grid.getSpacing(), grid.getOrigin(), new OpenCLGrid3D(sino), sartRelax); else reconSart = new SartCPU(grid.getSize(), grid.getSpacing(), grid.getOrigin(), sino, sartRelax); reconSart.iterate(sartIterations); } catch (Exception e) { e.printStackTrace(); } if (null == reconSart) { System.err.println("Error creating SART reconstruction instance"); return null; } Grid3D recImageIterSart = reconSart.getVol(); return recImageIterSart; } /* private static Grid3D reconstructFBPNew(Grid3D grid, double focalLength, double maxU, double maxV, double deltaU, double deltaV, int maxU_PX, int maxV_PX, Configuration conf, Trajectory geo) { ConeBeamProjector cbp = new ConeBeamProjector(); Grid3D sino = USE_CL_FP ? cbp.projectRayDrivenCL(grid) : cbp .projectPixelDriven(grid); String sinoType = USE_CL_FP ? "CL" : "CPU"; sino.show("sino" + sinoType); ConeBeamCosineFilter cbFilter = new ConeBeamCosineFilter(focalLength, maxU, maxV, deltaU, deltaV); RamLakKernel ramKRampFilter = new RamLakKernel(maxU_PX, deltaU); int stacksize = conf.getGeometry().getProjectionStackSize(); double D = conf.getGeometry().getSourceToDetectorDistance(); int numProjMatrices = geo.getNumProjectionMatrices(); float factor = (float) (D * D * Math.PI / numProjMatrices); for (int i = 0; i < stacksize; ++i) { cbFilter.applyToGrid(sino.getSubGrid(i)); for (int j = 0; j < maxV_PX; ++j) ramKRampFilter.applyToGrid(sino.getSubGrid(i).getSubGrid(j)); NumericalPointwiseOperators.multiplyBy(sino.getSubGrid(i), factor); } sino.show("FBP:sinoFiltered"); ConeBeamBackprojector cbbp = new ConeBeamBackprojector(); Grid3D recImage = USE_CL_BP ? cbbp.backprojectPixelDrivenCL(sino) : cbbp .backprojectPixelDriven(sino); return recImage; } */ private static Grid3D reconstructFBP(OpenCLGrid3D grid, double focalLength, double maxU, double maxV, double deltaU, double deltaV, int maxU_PX, int maxV_PX, Configuration conf, Trajectory geo) { boolean shortScan = false; if(geo.getAverageAngularIncrement()*geo.getProjectionStackSize() <=200) shortScan = true ; ConeBeamProjector cbp = new ConeBeamProjector(); ConeBeamBackprojector cbbp = new ConeBeamBackprojector(); TruncationCorrectionTool truncCorr = new TruncationCorrectionTool(); ConeBeamCosineFilter cbFilter = new ConeBeamCosineFilter(focalLength,maxV, maxU, deltaV, deltaU); RamLakKernel ramK = new RamLakKernel(maxU_PX, deltaU); ParkerWeightingTool pWeights = new ParkerWeightingTool(geo); OpenCLGrid3D recImage = new OpenCLGrid3D(new Grid3D(geo.getReconDimensionX(),geo.getReconDimensionY(),geo.getReconDimensionZ())); recImage.setSpacing(geo.getVoxelSpacingX(),geo.getVoxelSpacingY(),geo.getVoxelSpacingZ()); recImage.setOrigin(geo.getOriginX(),geo.getOriginY(),geo.getOriginZ()); OpenCLGrid2D sinoCL = new OpenCLGrid2D(new Grid2D((int)maxU_PX ,(int)maxV_PX)); sinoCL.setSpacing(geo.getPixelDimensionX(),geo.getPixelDimensionY()); sinoCL.setOrigin(0,0); if(truncationCorrection) truncCorr.configure(); if(shortScan){ //parker weighting try { pWeights.configure(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } for (int i = 0; i < geo.getProjectionStackSize(); ++i) { //if short scan, do parker weighting cbp.fastProjectRayDrivenCL(sinoCL, grid,i); if(shortScan) { pWeights.setImageIndex(i); sinoCL = new OpenCLGrid2D(pWeights.applyToolToImage(sinoCL)); } cbFilter.applyToGrid(sinoCL); if(truncationCorrection){ sinoCL = truncCorr.applyToolToImage(sinoCL); } if (i == 0) ramK = new RamLakKernel(sinoCL.getWidth(), deltaU); // ramp for (int j = 0; j < maxV_PX; ++j) ramK.applyToGrid(sinoCL.getSubGrid(j)); if(truncationCorrection){ int cutOffPixelSize = ((int)sinoCL.getSize()[0]-(int)maxU_PX)/2; OpenCLGrid2D sinobuf = new OpenCLGrid2D(new Grid2D((int)maxU_PX,(int)maxV_PX)); for(int v= 0; v < maxV_PX; v++){ for(int u= 0; u < maxU_PX; u++){ sinobuf.setAtIndex(u, v, sinoCL.getValue(new int[]{u+cutOffPixelSize-1,v})); } } sinoCL= new OpenCLGrid2D(sinobuf); } cbbp.fastBackprojectPixelDrivenCL(sinoCL,recImage,i); } return recImage; } }