package edu.stanford.rsl.tutorial.weightedtv; import java.io.BufferedWriter; import java.io.File; import java.io.FileWriter; import java.io.IOException; import ij.IJ; import ij.ImageJ; import ij.ImagePlus; import edu.stanford.rsl.tutorial.filters.RamLakKernel; import edu.stanford.rsl.conrad.data.numeric.Grid1D; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.NumericPointwiseOperators; import edu.stanford.rsl.conrad.geometry.Rotations; import edu.stanford.rsl.conrad.geometry.transforms.ScaleRotate; import edu.stanford.rsl.conrad.numerics.SimpleMatrix; import edu.stanford.rsl.conrad.utils.ImageUtil; import edu.stanford.rsl.tutorial.phantoms.Phantom; import edu.stanford.rsl.tutorial.phantoms.SheppLogan; import edu.stanford.rsl.tutorial.fan.FanBeamProjector2D; import edu.stanford.rsl.tutorial.fan.FanBeamBackprojector2D; /** *This is an example to apply weighted total variation (wTV) on 2D fan-beam limited angle tomography. *wTV is according to [1]. *[1]Candes. Enhancing Sparsity by Reweighted l1 Minimization * @author Yixing Huang * */ public class Perform2DWeightedTV { static boolean isOpencl=true; private int maxIter = 500; public int maxTVIter = 10; private double epsilon = 1.0e-10; public double error = 1.0; public double error2; private int iter = 0; private double step=0; //SART public Grid2D phan; public Grid2D recon; private Grid2D sinogram; private Grid1D sinoDiff1D; private Grid2D reconPre; private Grid2D normSinogram; private Grid2D normGrids; private Grid2D localImageUpdate; //weighted TV public TVGradient TVGrad; //public Grid2D TVGradient; //Geometry public int imgSizeX; public int imgSizeY; private double gammaM; private double maxT; private double deltaT; private double focalLength; private double maxBeta; private double deltaBeta; private int spacingX=1,spacingY=1; private FanBeamProjector2D projector; private FanBeamBackprojector2D backProj; public static void main(String[] args) throws IOException { new ImageJ(); Perform2DWeightedTV tv2D=new Perform2DWeightedTV(); tv2D.imgSizeX = 256; tv2D.imgSizeY=tv2D.imgSizeX; tv2D.phan = new SheppLogan(tv2D.imgSizeX, false); tv2D.phan.show("The Phantom"); tv2D.initialGeometry(); tv2D.initialSART(); tv2D.createNormProj(); tv2D.createNormGrids(); //initial wTV tv2D.TVGrad = new TVGradient(tv2D.recon); while (tv2D.error > tv2D.epsilon && tv2D.iter <=tv2D.maxIter) { tv2D.sartIterate(); tv2D.weightedTVIterate(); tv2D.TVGrad.weightMatrixUpdate(tv2D.recon); tv2D.outPutError(); tv2D.iter++; } tv2D.recon.show("Reconstructed Image"); } /** * SART iteration part */ public void sartIterate(){ reconPre=new Grid2D(recon); for (int theta = 0; theta < sinogram.getSize()[1]; theta++) { if(isOpencl) sinoDiff1D= projector.projectRayDriven1DCL(recon, theta);//get the projection of the current updated image at angle theta else sinoDiff1D = projector.projectRayDriven1D(recon, theta); sinoDiff1D=(Grid1D)NumericPointwiseOperators.subtractedBy(sinoDiff1D, sinogram.getSubGrid(theta));//compare with the measured sinogram, get the difference NumericPointwiseOperators.divideBy(sinoDiff1D, normSinogram.getSubGrid(theta)); if(isOpencl) localImageUpdate = backProj.backprojectPixelDriven1DCL(sinoDiff1D, theta); else localImageUpdate = backProj.backprojectPixelDriven1D(sinoDiff1D, theta); double stepSize = -1.0; NumericPointwiseOperators.multiplyBy(localImageUpdate,(float) stepSize); NumericPointwiseOperators.divideBy(localImageUpdate, normGrids); NumericPointwiseOperators.addBy(recon, localImageUpdate); } recon.getGridOperator().removeNegative(recon); } /** * weighted TV gradient descent part */ public void weightedTVIterate() { Grid2D TVGradient; double TV=TVGrad.getWeightedTVvalue(recon); int i=0; double deltaTV=-1.0; while(i<maxTVIter&& deltaTV<0){ double preTV=TV; TVGradient=TVGrad.computewTVGradient(recon); NumericPointwiseOperators.divideBy(TVGradient, (float) TVGrad.maxValue); backTrackingLineSearch(TVGradient);//**************** TV=TVGrad.getWeightedTVvalue(recon); deltaTV=TV-preTV; System.out.println(iter+" i="+i+" wTV="+TV+" step="+step); i++; } } /** * Using back tracking line search algorithm to find the step size for weighted TV gradient descent * @param grad */ private void backTrackingLineSearch(Grid2D grad){//weighted TV double t=1.0,tmin=1.0e-7; double alpha=0.3, beta=0.6; double delta=1.0f,temp1,temp2; double TV=TVGrad.getWeightedTVvalue(recon); double Gradnorm=alpha*grid2DNorm(grad); Grid2D temp=new Grid2D(grad); NumericPointwiseOperators.multiplyBy(temp,(float) t); temp1=TVGrad.getWeightedTVvalue((Grid2D)NumericPointwiseOperators.subtractedBy(recon,temp)); temp2=t*Gradnorm; delta=temp1-TV+temp2; while(delta>0.0f&& t>tmin) { t=t*beta; temp=(Grid2D)grad.clone(); NumericPointwiseOperators.multiplyBy(temp,(float) t); temp1=TVGrad.getWeightedTVvalue((Grid2D)NumericPointwiseOperators.subtractedBy(recon,temp)); // temp2=t*Gradnorm; delta=temp1-TV+temp2; } step=t; NumericPointwiseOperators.subtractBy(recon, temp);; } /** * The limited angle tomography geometry */ public void initialGeometry() { gammaM = 10. * Math.PI / 180; // maxT = length of the detector array maxT =(int)(imgSizeX*1.5f); // deltaT = size of one detector element deltaT = 1.0; // focalLength = focal length focalLength = (maxT / 2.0 - 0.5) * deltaT / Math.tan(gammaM); // maxBeta = maximum rotation angle maxBeta =160. * Math.PI/180.0; // deltaBeta = step size between source positions deltaBeta = 1.0 * Math.PI / 180.0; projector = new FanBeamProjector2D( focalLength, maxBeta, deltaBeta, maxT, deltaT); backProj=new FanBeamBackprojector2D(focalLength,deltaT, deltaBeta, imgSizeX, imgSizeY); } /** * Initial SART grids */ public void initialSART(){ if(isOpencl) sinogram = (Grid2D) projector.projectRayDrivenCL(phan); else sinogram = (Grid2D) projector.projectRayDriven(phan); sinogram.show("sinogram"); backProj.initSinogramParams(sinogram); recon = new Grid2D(imgSizeX, imgSizeY); recon.setSpacing(spacingX, spacingY); NumericPointwiseOperators.fill(recon, 0);// initialization sinoDiff1D = new Grid1D(sinogram.getSize()[0]); sinoDiff1D.setSpacing(deltaT); localImageUpdate = new Grid2D(imgSizeX, imgSizeY); localImageUpdate.setSpacing(spacingX, spacingY); localImageUpdate.setOrigin(imgSizeX/2,imgSizeY/2); } /** * compute the normalization projections */ public void createNormProj(){ //projection normalization weights Grid2D C_phan=new Grid2D(imgSizeX,imgSizeY);//Constant grid with all values as 1; C_phan.setSpacing(spacingX,spacingY); NumericPointwiseOperators.fill(C_phan,1.0f); if(isOpencl) normSinogram=(Grid2D)projector.projectRayDrivenCL(C_phan); else normSinogram=(Grid2D)projector.projectRayDriven(C_phan); NumericPointwiseOperators.addBy(normSinogram,(float) epsilon); } /** * compute normalization grids for backprojection */ public void createNormGrids(){ //backprojection normalization weights Grid1D C_sino1D=new Grid1D(sinogram.getSize()[0]); C_sino1D.setSpacing(deltaT); NumericPointwiseOperators.fill(C_sino1D,1.0f); normGrids=new Grid2D(imgSizeX,imgSizeY); if(isOpencl) normGrids=backProj.backprojectPixelDriven1DCL(C_sino1D, 0); else normGrids=backProj.backprojectPixelDriven1D(C_sino1D, 0); NumericPointwiseOperators.addBy(normGrids,(float) epsilon); } /** * output the difference between too iterations and output the difference from the ground truth * @throws IOException */ public void outPutError() throws IOException{ if (iter % 5== 0){ recon.show(iter+"_th iteration"); } error =meanSquareError(recon, reconPre); error2=meanSquareError(recon, phan); System.out.println(iter+": error=" + error+" error2= "+error2); } /** * mean square error * @param imgGrid1 * @param imgGrid2 * @return */ private double meanSquareError(Grid2D imgGrid1, Grid2D imgGrid2) { double err = 0; Grid2D temp = new Grid2D(imgGrid1); NumericPointwiseOperators.subtractBy(temp, imgGrid2); err = grid2DNorm(temp); err=err / (temp.getSize()[0] * temp.getSize()[1]); return err; } /** * L2 norm * @param imgGrid * @return */ private double grid2DNorm(Grid2D imgGrid) { double d = 0; for (int row = 0; row < imgGrid.getSize()[0]; row++) for (int col = 0; col < imgGrid.getSize()[1]; col++) d = d + imgGrid.getAtIndex(row, col) * imgGrid.getAtIndex(row, col); return d; } }