package edu.stanford.rsl.tutorial.iterative; import java.util.Arrays; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.data.numeric.NumericGrid; import edu.stanford.rsl.conrad.data.numeric.NumericGridOperator; import edu.stanford.rsl.tutorial.cone.ConeBeamBackprojector; import edu.stanford.rsl.tutorial.cone.ConeBeamProjector; /** * SART reconstruction * * @author Mario Amrehn * */ public class SartCPU implements Sart { protected final Grid3D vol; protected final Grid3D normSino; protected Grid3D oProj = null; protected final float beta; protected float normFactor; protected NumericGridOperator gop = NumericGridOperator.getInstance(); // ----------------------------------------- protected boolean verbose = false; protected boolean debug = false; protected final static boolean USE_CL_FP = true; // GPU acceleration protected final static boolean USE_CL_BP = true; // GPU acceleration // ----------------------------------------- public SartCPU(int[] volDims, double[] spacing, double[] origin, Grid3D oProj, float beta) throws Exception { if (null == oProj) { throw new Exception("SART: No projection data given"); } if (1 > volDims[0] || 1 > volDims[1] || 1 > volDims[2]) { throw new Exception( "SART: Span of each dimension in the volume has to be a natural number"); } // create initial volume filled with zeros vol = new Grid3D(volDims[0], volDims[1], volDims[2]); // vol.setBoundary(boundarySize); vol.setOrigin(origin); vol.setSpacing(spacing); this.normSino = createNormProj(); this.oProj = oProj; /* calculated once for speedup */ this.normFactor = (float) (gop.normL1(oProj) / gop.normL1(normSino)); this.beta = beta; } public SartCPU(Grid3D initialVol, Grid3D sino, float beta) throws Exception { if (null == initialVol) { throw new Exception("SART: No initial volume given"); } if (null == sino) { throw new Exception("SART: No projection data given"); } vol = initialVol; normSino = createNormProj(); this.oProj = sino; /* calculated once for speedup*/ normFactor = (float) (gop.normL1(oProj) / gop.normL1(normSino)); this.beta = beta; } /** * @return Normalized projection data */ protected Grid3D createNormProj() { if (verbose) System.out.println("Create normalized projections"); final Grid3D onesVol = new Grid3D(vol); gop.fill(onesVol, 1.0f); ConeBeamProjector cbp = new ConeBeamProjector(); Grid3D sino = USE_CL_FP ? cbp.projectRayDrivenCL(onesVol) : cbp .projectPixelDriven(onesVol); if(verbose) sino.show("sinoOfNormProjs"); // prevent div by zero float min = gop.min(sino); //sino.show("Sino of norm projection"); if (0 >= min) gop.addBy(sino, -min + 0.1f); return sino; } public void iterate() throws Exception { this.iterate(1); } public void iterate(final int iter) throws Exception { ConeBeamProjector cbp = new ConeBeamProjector(); ConeBeamBackprojector cbbp = new ConeBeamBackprojector(); int numProjs = cbp.getMaxProjections(); for (int i = 0; i < iter; ++i) { boolean[] projIsUsed = new boolean[numProjs]; // default: false int p = 0; // current projection index for (int n = 0; n < numProjs; ++n) { /* edit/init data structures */ projIsUsed[p] = true; Grid2D sino = USE_CL_FP ? cbp.projectRayDrivenCL(vol, p) : cbp .projectPixelDriven(vol, p); sino = gop.transpose(sino); if(debug && 0 < gop.normL1(vol)){ Grid3D sinoTest = USE_CL_FP ? cbp.projectRayDrivenCL(vol) : cbp .projectPixelDriven(vol); Grid2D sinoTestP = sinoTest.getSubGrid(p); sinoTest.show("sinoCL-Test"); sinoTestP.show("sinoCL-Test-Proj:" + p); Grid2D s = new Grid2D(sinoTestP); gop.subtractBySave(s, sino); s.show("sinoCL-Test-Proj-Diff"); System.out.println("Diff L1: " + gop.normL1(s)); } /* if(verbose && 0 != gop.normL1(sino)){ System.out.println("Some actual data found in sino!"); if(0==p%5) sino.show("sino of it:" + i + " proj:" + p); } */ gop.fillInvalidValues(sino, 0); if (verbose) System.out.println(gop.min(sino) + ":" + gop.max(sino)); // TEST //oProj.show("oProj"); Grid2D oProjP = new Grid2D(oProj.getSubGrid(p)); Grid2D normSinoP = new Grid2D(normSino.getSubGrid(p)); // used read-only, cloning not necessary but save gop.multiplyBy(normSinoP, normFactor); /* update step */ // NOTE: upd = (oProj - sino) ./ normSino if (verbose) reportInvalidValues(oProjP, "oProjP"); // Just in case.. if (verbose) reportInvalidValues(sino, "sino"); // Just in case.. should not happen after fillInvalidValues() gop.subtractBy(oProjP, sino); if (verbose) reportInvalidValues(oProjP, "oProjP"); // Just in case.. if (verbose) reportInvalidValues(normSinoP, "normSinoP"); // Just in case.. gop.divideBySave(oProjP, normSinoP); Grid2D upd = oProjP; if (verbose) reportInvalidValues(upd, "for projection " + p); // NOTE: vol = vol + updBP * beta // upd.setOrigin(oProj.getOrigin()[0], oProj.getOrigin()[1]); // // needed after update? // upd.setSpacing(oProj.getSpacing()[0], oProj.getSpacing()[1]); // // needed after update? Grid3D updBP = USE_CL_BP ? cbbp.backprojectPixelDrivenCL(upd, p) : cbbp.backprojectPixelDriven(upd, p); if (verbose) reportInvalidValues(updBP, "updBP"); if(debug){ Grid3D updBPTest = (!USE_CL_BP) ? cbbp.backprojectPixelDrivenCL(upd, p) : cbbp.backprojectPixelDriven(upd, p); if (verbose) reportInvalidValues(updBPTest, "updBPTest"); gop.multiplyBySave(updBPTest, 1); updBPTest.show("updBPTest"); System.out.println("vol: " + Arrays.toString(vol.getOrigin())); System.out.println("updBP: " + Arrays.toString(updBP.getOrigin())); System.out.println("updBPTest: " + Arrays.toString(updBPTest.getOrigin())); } // GridOp.addInPlace(vol, GridOp.mulInPlace(updBP, beta)); gop.multiplyBySave(updBP, beta); if (verbose) reportInvalidValues(updBP, "updBP after mult"); gop.addBy(vol, updBP); if (verbose) reportInvalidValues(vol, "vol after " + i + " SART iterations"); /* * Don't use projections with a small angle to each other * subsequently */ p = (p + numProjs / 3) % numProjs; for (int ii = 1; projIsUsed[p] && ii < numProjs; ++ii) p = (p + 1) % numProjs; } } } protected void reportInvalidValues(NumericGrid upd, String msg) { if (null != msg && msg.length()>0) msg = " " + msg; int invalCount = gop.countInvalidElements(upd); if (0 < invalCount) System.out.println("Invalid values" + msg + ": " + invalCount); float min = gop.min(upd); float max = gop.max(upd); float normL1 = (float) gop.normL1(upd); System.out.println("Info\t" + msg + "\tmin=" + min + "\tmax=" + max + "\tnormL1=" + normL1); } public Grid3D getVol() { gop.fillInvalidValues(vol, 0); return new Grid3D(vol); } }