package edu.stanford.rsl.tutorial.weightedtv;
/**
* Here is an example to apply weighted total variation on 3D cone-beam limited angle tomography
* @author Yixing Huang
*/
import java.io.IOException;
import com.jogamp.opencl.CLContext;
import ij.ImageJ;
import edu.stanford.rsl.conrad.data.numeric.Grid2D;
import edu.stanford.rsl.conrad.data.numeric.Grid3D;
import edu.stanford.rsl.conrad.geometry.trajectories.Trajectory;
import edu.stanford.rsl.conrad.opencl.OpenCLUtil;
import edu.stanford.rsl.conrad.phantom.NumericalSheppLogan3D;
import edu.stanford.rsl.conrad.utils.Configuration;
import edu.stanford.rsl.tutorial.cone.ConeBeamProjector;
import edu.stanford.rsl.tutorial.cone.ConeBeamBackprojector;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid2D;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid3D;
public class Perform3DWeightedTV {
private int maxIter=500;
public int maxTVIter=10;
private int showInt=5;
private double eps=1.0e-10;
public double error=1,error2=1;
private int iter=0;
//geometry
public int imgSizeX;
public int imgSizeY;
public int imgSizeZ;
protected Trajectory geo = null;
protected int maxProjs;
protected int width;
protected int height;
protected double spacingX;
protected double spacingY;
protected double spacingZ;
protected double originX;
protected double originY;
protected double originZ;
//
protected CLContext context = null;
//SART
private float beta=0.8f;
public OpenCLGrid3D volCL;//Object
public OpenCLGrid3D reconCL;
private OpenCLGrid3D reconCLpre;//Current reconCL before SART and wTV iteration
private OpenCLGrid3D sinoCL;//measured sinogram
private OpenCLGrid3D updBP;
private OpenCLGrid2D sinoP;
protected OpenCLGrid3D normSino;
protected OpenCLGrid3D normGrid;
//weighted TV
public TVGradient3D TVgrad;
protected double step=0;
private ConeBeamProjector cbp;
private ConeBeamBackprojector cbbp;
public static void main(String[] args) throws Exception {
new ImageJ();
Perform3DWeightedTV tv3D=new Perform3DWeightedTV();
tv3D.initOpenCLDataStructure();
/**
* The ground truth phantom
*/
Grid3D phan = new NumericalSheppLogan3D(tv3D.imgSizeX,tv3D.imgSizeY, tv3D.imgSizeZ).getNumericalSheppLoganPhantom();
//phan.setSpacing(spacingX,spacingY,spacingZ);
tv3D.volCL=new OpenCLGrid3D(phan);
tv3D.getMeasuredSinoCL();
tv3D.initialReconCL();
tv3D.createNormProj();
tv3D.createNormGrid();
tv3D.TVgrad=new TVGradient3D(tv3D.reconCL);
/**
* output the TV value of the ground truth phantom
*/
long t_start=System.currentTimeMillis();
tv3D.TVgrad.computeImageGradientCL(tv3D.volCL);
float pTV=tv3D.volCL.getGridOperator().normL1(tv3D.TVgrad.imgGradientCL);
System.out.println("Perfect TV: "+pTV);
while(tv3D.iter<=tv3D.maxIter&&tv3D.error>tv3D.eps){
tv3D.sartIterate();//SART iteration
tv3D.weightedTVIterate();//wTV iteration
tv3D.outPutError();
tv3D.iter++;
}
long t_end=System.currentTimeMillis();
System.out.println("time is "+(t_end-t_start)/1000.0);
tv3D.reconCL.show("final reconCL");
}
/**
* SART part
* @throws Exception
*/
public void sartIterate() throws Exception{
OpenCLGrid2D tempSino;
reconCLpre=new OpenCLGrid3D(reconCL);
for(int projIndex=0;projIndex<maxProjs;projIndex++){
cbp.fastProjectRayDrivenCL(sinoP, reconCL, projIndex);
tempSino=new OpenCLGrid2D(sinoCL.getSubGrid(projIndex));
sinoP.getGridOperator().subtractBy(sinoP,tempSino);
tempSino.release();
tempSino=new OpenCLGrid2D(normSino.getSubGrid(projIndex));
sinoP.getGridOperator().divideBy(sinoP,tempSino);
tempSino.release();
updBP.getGridOperator().fill(updBP, 0);
cbbp.fastBackprojectPixelDrivenCL(sinoP,updBP, projIndex);
updBP.getGridOperator().multiplyBy(updBP, -beta);
// createNormGrid2(projIndex); //****************************
updBP.getGridOperator().divideBy(updBP, normGrid);
reconCL.getGridOperator().addBy(reconCL, updBP);
}
reconCL.getGridOperator().removeNegative(reconCL);
}
/**
* weighted TV gradient descent part
* @throws IOException
*/
public void weightedTVIterate() throws IOException{
OpenCLGrid3D tv_gradient;
double wTV=TVgrad.getWeightedTVvalueCL(reconCL);
int i=0;
double deltaTV=-1.0;
while(i<maxTVIter){
double preTV=wTV;
tv_gradient=TVgrad.computeWeightedTVGradient(reconCL);
tv_gradient.getGridOperator().divideBy(tv_gradient, tv_gradient.getGridOperator().max(tv_gradient));
backTrackingLineSearch(tv_gradient);
wTV=TVgrad.getWeightedTVvalueCL(reconCL);
deltaTV=wTV-preTV;
//System.out.println("iter="+iter+" i="+i+" L1: "+reconCL.getGridOperator().normL1(TVgrad.imgGradientCL)+" wTV="+wTV+" step="+step);
System.out.println("iter="+iter+" i="+i+" wTV="+wTV+" step="+step);
i++;
}
TVgrad.weightMatrixCLUpdate(reconCL);
}
/**
* using back tracking line search to find the step size for weighted TV gradient descent
* @param tv_grad
*/
private void backTrackingLineSearch( OpenCLGrid3D tvGradient){//weighted TV
double t=1.0,tmin=0.0000001;
double alpha=0.3, beta=0.6;
double delta=1.0f,temp1,temp2;
double TV=TVgrad.getWeightedTVvalueCL(reconCL);
double Gradnorm=alpha*openCLGrid3DNorm(tvGradient);
tvGradient.getGridOperator().multiplyBy(tvGradient,(float) t);
reconCL.getGridOperator().subtractBy(reconCL,tvGradient);
temp1=TVgrad.getWeightedTVvalueCL(reconCL);
temp2=t*Gradnorm;
delta=temp1-TV+temp2;
while(delta>0.0f && t>tmin)
{
t=t*beta;
reconCL.getGridOperator().addBy(reconCL, tvGradient);
tvGradient.getGridOperator().multiplyBy(tvGradient,(float) beta);
reconCL.getGridOperator().subtractBy(reconCL,tvGradient);
temp1=TVgrad.getWeightedTVvalueCL(reconCL);
temp2=t*Gradnorm;
delta=temp1-TV+temp2;
}
//System.out.println("t="+t);
step=t;
}
/**
* L2 norm of OpenCLGrid3D
* @param reconCL
* @return
*/
private double openCLGrid3DNorm(OpenCLGrid3D reconCL) {
OpenCLGrid3D tempReconCL=new OpenCLGrid3D(reconCL);
tempReconCL.getGridOperator().multiplyBy(tempReconCL,tempReconCL);
double n=reconCL.getGridOperator().sum(tempReconCL);
tempReconCL.release();
return n;
}
/**
* mean square error
* @param gridCL1
* @param gridCL2
* @return
*/
private double meanSquareError(OpenCLGrid3D gridCL1, OpenCLGrid3D gridCL2){
OpenCLGrid3D tempCL=new OpenCLGrid3D(gridCL1);
tempCL.getGridOperator().subtractBy(tempCL, gridCL2);
double err=openCLGrid3DNorm(tempCL)/(gridCL1.getSize()[0]*gridCL1.getSize()[1]*gridCL1.getSize()[2]);
tempCL.release();
return err;
}
/**
* initialize the geometry
*/
public void initOpenCLDataStructure(){
Configuration.loadConfiguration();
Configuration conf = Configuration.getGlobalConfiguration();
geo = conf.getGeometry();
width = geo.getDetectorWidth();
height = geo.getDetectorHeight();
maxProjs = geo.getProjectionStackSize();
// create context
context = OpenCLUtil.getStaticContext();
imgSizeX = geo.getReconDimensionX();
imgSizeY = geo.getReconDimensionY();
imgSizeZ = geo.getReconDimensionZ();
spacingX = geo.getVoxelSpacingX();
spacingY = geo.getVoxelSpacingY();
spacingZ = geo.getVoxelSpacingZ();
originX = -geo.getOriginX();
originY = -geo.getOriginY();
originZ = -geo.getOriginZ();
cbp=new ConeBeamProjector();
cbbp=new ConeBeamBackprojector();
sinoP=new OpenCLGrid2D(new Grid2D(width,height));
updBP = new OpenCLGrid3D(new Grid3D(imgSizeX,imgSizeY,imgSizeZ));
}
/**
* get the measured sinogram
* @throws Exception
*/
public void getMeasuredSinoCL() throws Exception
{
sinoCL = new OpenCLGrid3D(new Grid3D(width,height,maxProjs));
sinoCL.getDelegate().prepareForDeviceOperation();
cbp.fastProjectRayDrivenCL(sinoCL,volCL);
sinoCL.show("sinoCL");
}
/**
* initialize the reconstructed image as 0
*/
public void initialReconCL()
{
reconCL=new OpenCLGrid3D(new Grid3D(imgSizeX,imgSizeY,imgSizeZ));
reconCL.setSpacing(spacingX,spacingY,spacingZ);
reconCL.getGridOperator().fill(reconCL, 0);
}
/**
* compute the normalization projections
* @throws Exception
*/
public void createNormProj() throws Exception{
OpenCLGrid3D onesVol = new OpenCLGrid3D(new Grid3D(imgSizeX,imgSizeY,imgSizeZ));
onesVol.getGridOperator().fill(onesVol, 1);
normSino = new OpenCLGrid3D(new Grid3D(width,height,maxProjs));
cbp.fastProjectRayDrivenCL(normSino,onesVol);
normSino.getGridOperator().addBy(normSino,(float) eps);
//if(debug)normSino.show("normSino");
}
/**
* compute the projIndex_th normalization grids for backprojection
* @param projIndex
* @throws Exception
*/
protected void createNormGrid(int projIndex) throws Exception {
OpenCLGrid2D c_sinoCL=new OpenCLGrid2D(new Grid2D(width,height));
c_sinoCL.getGridOperator().fill(c_sinoCL, 1.0f);
normGrid.release();
normGrid=new OpenCLGrid3D(volCL);
normGrid.getGridOperator().fill(normGrid, 0);
cbbp.backprojectPixelDrivenCL(normGrid, c_sinoCL, projIndex);
//normGrid.getGridOperator().fill(normGrid, (float)normGrid.getGridOperator().sum(normGrid)/(normGrid.getNumberOfElements()));
normGrid.getGridOperator().addBy(normGrid, (float)eps);
}
/**
* compute the mean normalization grid for backprojection
* @throws Exception
*/
public void createNormGrid() throws Exception {
OpenCLGrid2D c_sinoCL=new OpenCLGrid2D(new Grid2D(width,height));
c_sinoCL.getGridOperator().fill(c_sinoCL, 1.0f);
normGrid=new OpenCLGrid3D(volCL);
normGrid.getGridOperator().fill(normGrid, 0);
cbbp.backprojectPixelDrivenCL(normGrid, c_sinoCL, 0);
normGrid.getGridOperator().fill(normGrid, (float)normGrid.getGridOperator().sum(normGrid)/(normGrid.getNumberOfElements()));
//normGrid.getGridOperator().addBy(normGrid, (float)eps);
}
/**
* out put the result at each iteration
* @throws IOException
*/
public void outPutError() throws IOException{
if (iter % showInt== 0)
{
reconCL.show(iter+"_th iteration");
}
error =meanSquareError(reconCL, reconCLpre);
error2=meanSquareError(reconCL, volCL);
System.out.println(iter+": error=" + error+" error2= "+error2);
reconCLpre.release();
}
}