package edu.stanford.rsl.tutorial.weightedtv;
import ij.IJ;
import ij.ImagePlus;
import edu.stanford.rsl.conrad.data.numeric.Grid1D;
import edu.stanford.rsl.conrad.data.numeric.Grid2D;
import edu.stanford.rsl.conrad.data.numeric.Grid3D;
import edu.stanford.rsl.conrad.data.numeric.NumericGrid;
import edu.stanford.rsl.conrad.data.numeric.NumericPointwiseOperators;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid2D;
import edu.stanford.rsl.conrad.data.numeric.opencl.OpenCLGrid3D;
import edu.stanford.rsl.conrad.utils.ImageUtil;
/**
 * Total-variation (TV) and weighted-TV operators for 3D volumes.
 * <p>
 * Offers the image-gradient magnitude, the (weighted) TV value and the (weighted) TV gradient,
 * with a CPU path working on {@link Grid3D} and an OpenCL path working on {@link OpenCLGrid3D}.
 * Which fields are initialized depends on the constructor: {@link #TVGradient3D(Grid3D)}
 * prepares only the CPU fields, while {@link #TVGradient3D(OpenCLGrid3D)} and
 * {@link #TVGradient3D(int[])} prepare only the OpenCL fields.
 * @author Yixing Huang
 */
public class TVGradient3D {
    /** Constant added under every square root of the CPU TV gradient; keeps the denominators positive. */
    public float eps = 0.1f;
    /** Constant added to the gradient magnitude before it is inverted by the weight updates. */
    public float weps = 0.001f;
    /** CPU: per-voxel image-gradient magnitude, filled by {@link #computeImageGradient(Grid3D)}. */
    public Grid3D imgGradient;
    /** CPU: TV gradient written by {@link #computeTVgradient} / {@link #computeWeightedTVGradient(Grid3D)}. */
    public Grid3D tvGradient;
    /** Largest absolute value written by the last CPU TV-gradient computation. */
    public double maxValue = 0.0;
    /** CPU: weights for weighted TV. */
    public Grid3D weightMatrix;

    // ---- OpenCL ----
    /** OpenCL: image gradient, (re)allocated by {@link #computeImageGradientCL(OpenCLGrid3D)}. */
    public OpenCLGrid3D imgGradientCL;
    /**
     * OpenCL: TV-gradient output of the compute*Gradient methods. Note that
     * {@link #adaptiveWeightMatrixCLUpdate(OpenCLGrid3D)} also passes it as the image-gradient target.
     */
    public OpenCLGrid3D tvGradientCL;
    /** OpenCL: weights for weighted TV. */
    public OpenCLGrid3D weightMatrixCL;
    TVOpenCLGridOperators tvOperators;
    /** CPU: grid of ones, the numerator of the weight update {@code 1 / (|grad| + weps)}. */
    private Grid3D onesTemp;

    /**
     * Creates a CPU-path operator sized like {@code img}, with all weights set to 1.
     * @param img volume whose size defines the gradient and weight grids
     */
    public TVGradient3D(Grid3D img) {
        this.imgGradient = new Grid3D(img);
        this.tvGradient = new Grid3D(img);
        // weightMatrix must be created before onesTemp is copied from it; copying first
        // dereferenced a still-null weightMatrix.
        initialWeightMatrix();
        onesTemp = new Grid3D(weightMatrix);
        NumericPointwiseOperators.fill(onesTemp, 1.0f);
    }

    /**
     * Creates an OpenCL-path operator sized like {@code imgCL}, with all weights set to 1.
     * @param imgCL volume whose size defines the OpenCL gradient and weight grids
     */
    public TVGradient3D(OpenCLGrid3D imgCL) {
        this.tvGradientCL = new OpenCLGrid3D(imgCL);
        initOpenCL();
    }

    /**
     * Creates an OpenCL-path operator for a volume of the given size, with all weights set to 1.
     * @param size volume size; elements 0, 1 and 2 are used as the x, y and z extents
     */
    public TVGradient3D(int[] size) {
        this.tvGradientCL = new OpenCLGrid3D(new Grid3D(size[0], size[1], size[2]));
        initOpenCL();
    }

    /** Zeroes tvGradientCL, allocates weightMatrixCL with its size and sets every weight to 1. */
    private void initOpenCL() {
        tvGradientCL.getGridOperator().fill(tvGradientCL, 0);
        weightMatrixCL = new OpenCLGrid3D(tvGradientCL);
        tvOperators = new TVOpenCLGridOperators();
        initialWeightMatrixCL();
    }

    /**
     * Re-creates the CPU weight matrix with the size of tvGradient and sets every weight to 1.
     */
    public void initialWeightMatrix() {
        weightMatrix = new Grid3D(tvGradient);
        NumericPointwiseOperators.fill(weightMatrix, 1.0f);
    }

    /**
     * Sets every weight of the OpenCL weight matrix to 1.
     */
    public void initialWeightMatrixCL() {
        weightMatrixCL.getGridOperator().fill(weightMatrixCL, 1.0f);
    }

    /**
     * Computes the per-voxel gradient magnitude of {@code img} into {@link #imgGradient}.
     * <p>
     * Uses backward differences in x, y and z; at the lower border of an axis the missing
     * neighbour is replaced by the voxel itself, so that difference is zero.
     * @param img input volume
     */
    public void computeImageGradient(Grid3D img) {
        this.imgGradient = new Grid3D(img);
        final int[] size = imgGradient.getSize();
        for (int i = 0; i < size[0]; i++) {
            for (int j = 0; j < size[1]; j++) {
                for (int k = 0; k < size[2]; k++) {
                    double fijk = img.getAtIndex(i, j, k);
                    double fl = (i > 0) ? img.getAtIndex(i - 1, j, k) : fijk;
                    double fu = (j > 0) ? img.getAtIndex(i, j - 1, k) : fijk;
                    double ft = (k > 0) ? img.getAtIndex(i, j, k - 1) : fijk;
                    double hDiff = fijk - fl;
                    double vDiff = fijk - fu;
                    double zDiff = fijk - ft;
                    this.imgGradient.setAtIndex(i, j, k,
                            (float) Math.sqrt(hDiff * hDiff + vDiff * vDiff + zDiff * zDiff));
                }
            }
        }
    }

    /**
     * Computes the image gradient of {@code imgCL} into {@link #imgGradientCL} with OpenCL.
     * <p>
     * A fresh buffer (a copy of tvGradientCL) is allocated on every call; the previously
     * allocated buffer is released first so repeated calls do not leak device memory.
     * Callers must not keep using a reference to an older imgGradientCL after this call.
     * @param imgCL input volume on the device
     */
    public void computeImageGradientCL(OpenCLGrid3D imgCL) {
        if (this.imgGradientCL != null) {
            this.imgGradientCL.release();
        }
        this.imgGradientCL = new OpenCLGrid3D(tvGradientCL);
        tvOperators.computeImageGradient(imgCL, imgGradientCL);
    }

    /**
     * Updates the CPU weights as {@code 1 / (imgGradient + weps)}, voxel-wise.
     * Requires {@link #computeImageGradient(Grid3D)} to have been called.
     */
    public void weightMatrixUpdate() {
        Grid3D gradientPlusEps = new Grid3D(this.imgGradient);
        NumericPointwiseOperators.addBy(gradientPlusEps, this.weps);
        weightMatrix = (Grid3D) NumericPointwiseOperators.dividedBy(onesTemp, gradientPlusEps);
    }

    /**
     * Updates the OpenCL weight matrix from {@code imgCL}; delegates to
     * TVOpenCLGridOperators#computeWeightMatrixUpdate with {@link #weps}.
     * @param imgCL current volume on the device
     */
    public void weightMatrixCLUpdate(OpenCLGrid3D imgCL) {
        TVOpenCLGridOperators.getInstance().computeWeightMatrixUpdate(imgCL, weightMatrixCL, weps);
    }

    /**
     * Updates the OpenCL weight matrix for adaptive weighted TV.
     * <p>
     * First computes the image gradient of {@code imgCL} with {@link #tvGradientCL} as the target
     * (so its previous contents are overwritten), then derives the weights from it.
     * @param imgCL current volume on the device
     */
    public void adaptiveWeightMatrixCLUpdate(OpenCLGrid3D imgCL) {
        TVOpenCLGridOperators.getInstance().computeImageGradient(imgCL, this.tvGradientCL);
        TVOpenCLGridOperators.getInstance().computeAdaptiveWeightMatrixUpdate(tvGradientCL, this.weightMatrixCL, weps);
    }

    /**
     * Updates the weight matrix for anisotropic weighted TV (AwTV) with OpenCL.
     * @param imgCL current volume on the device
     */
    public void anisotropicWeightedTVWeightMatrixCLUpdate(OpenCLGrid3D imgCL) {
        TVOpenCLGridOperators.getInstance().computeDirectionalWeightedTVWeightMatrixUpdate(imgCL, weightMatrixCL, weps);
    }

    /**
     * Updates the OpenCL weight matrix using the gradient within each XY slice only
     * (no Z-direction term), i.e. for slice-wise TV.
     * @param imgCL current volume on the device
     */
    public void weightMatrixCLUpdate2(OpenCLGrid3D imgCL) {
        TVOpenCLGridOperators.getInstance().computeWeightMatrixUpdate2(imgCL, weightMatrixCL, weps);
    }

    /**
     * Returns the unweighted TV value, i.e. the sum of {@link #imgGradient}.
     * Requires {@link #computeImageGradient(Grid3D)} to have been called.
     * @return sum of the gradient magnitudes
     */
    public double getTVvalue() {
        return NumericPointwiseOperators.sum(imgGradient);
    }

    /**
     * Returns the weighted TV value, i.e. the sum of {@code imgGradient * weightMatrix}.
     * Requires {@link #computeImageGradient(Grid3D)} to have been called.
     * @return weighted sum of the gradient magnitudes
     */
    public double getWeightedTVvalue() {
        return NumericPointwiseOperators.sum(NumericPointwiseOperators.multipliedBy(imgGradient, weightMatrix));
    }

    /**
     * Allocates a zero-filled nx-by-ny-by-1 device buffer for the TV-value kernels.
     * The kernels presumably accumulate per-(x,y) partial sums into it (TODO confirm).
     */
    private OpenCLGrid3D newZeroedZSum(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = new OpenCLGrid3D(new Grid3D(imgCL.getSize()[0], imgCL.getSize()[1], 1));
        tempZSum.getGridOperator().fill(tempZSum, 0);
        return tempZSum;
    }

    /**
     * Returns the weighted TV value of {@code imgCL} computed with OpenCL.
     * @param imgCL volume on the device
     * @return weighted TV value
     */
    public double getWeightedTVvalueCL(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = newZeroedZSum(imgCL);
        try {
            tvOperators.getWeightedTV(imgCL, weightMatrixCL, tempZSum);
            return tempZSum.getGridOperator().sum(tempZSum);
        } finally {
            tempZSum.release(); // released even if the kernel call throws
        }
    }

    /**
     * Returns the TV value used by adaptive weighted TV, computed with OpenCL.
     * @param imgCL volume on the device
     * @return adaptive weighted TV value
     */
    public double getWeightedTVvalueCLAdaptive(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = newZeroedZSum(imgCL);
        try {
            tvOperators.getAdaptiveWeightedTV(imgCL, weightMatrixCL, tempZSum);
            return tempZSum.getGridOperator().sum(tempZSum);
        } finally {
            tempZSum.release();
        }
    }

    /**
     * Returns the directional weighted TV value computed with OpenCL. In this scheme the
     * Y-direction gradient carries a large weight (B=100, for instance).
     * @param imgCL volume on the device
     * @return directional weighted TV value
     */
    public double getDirectionalWeightedTVvalueCL(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = newZeroedZSum(imgCL);
        try {
            tvOperators.getDirectionalWeightedTV(imgCL, weightMatrixCL, tempZSum);
            return tempZSum.getGridOperator().sum(tempZSum);
        } finally {
            tempZSum.release();
        }
    }

    /**
     * Returns the weighted TV value computed with OpenCL, using the gradient within each
     * XY slice only.
     * @param imgCL volume on the device
     * @return slice-wise weighted TV value
     */
    public double getWeightedTVvalueCL2(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = newZeroedZSum(imgCL);
        try {
            tvOperators.getWeightedTV2(imgCL, weightMatrixCL, tempZSum);
            return tempZSum.getGridOperator().sum(tempZSum);
        } finally {
            tempZSum.release();
        }
    }

    /**
     * Computes the (unweighted) TV gradient of {@code img} into {@link #tvGradient} and resets
     * and updates {@link #maxValue}.
     * <p>
     * Follows: Sidky, Kao, Pan, "Accurate image reconstruction from few-views and limited-angle
     * data in divergent-beam CT".
     * @param img input volume (same size as tvGradient)
     * @return {@link #tvGradient}
     */
    public Grid3D computeTVgradient(Grid3D img) {
        this.maxValue = 0.0;
        fillTVGradient(img, false);
        return tvGradient;
    }

    /**
     * Computes the weighted TV gradient of {@code img} (weights from {@link #weightMatrix}) into
     * {@link #tvGradient} and resets and updates {@link #maxValue}.
     * <p>
     * {@link #maxValue} is now reset before each call, as in {@link #computeTVgradient(Grid3D)}.
     * Previously it kept growing across calls.
     * @param img input volume (same size as tvGradient and weightMatrix)
     * @return {@link #tvGradient}
     */
    public Grid3D computeWeightedTVGradient(Grid3D img) {
        this.maxValue = 0.0;
        fillTVGradient(img, true);
        return tvGradient;
    }

    /** Writes the per-voxel TV gradient of img into tvGradient and tracks the max |value|. */
    private void fillTVGradient(Grid3D img, boolean weighted) {
        final int[] size = tvGradient.getSize();
        for (int i = 0; i < size[0]; i++) {
            for (int j = 0; j < size[1]; j++) {
                for (int k = 0; k < size[2]; k++) {
                    double v = voxelTVGradient(img, weighted, size, i, j, k);
                    if (Math.abs(v) > maxValue) {
                        maxValue = Math.abs(v);
                    }
                    tvGradient.setAtIndex(i, j, k, (float) v);
                }
            }
        }
    }

    /**
     * Returns the (weighted) TV gradient at voxel (i, j, k).
     * <p>
     * Neighbours outside the volume are replaced by the centre voxel. In the unweighted case
     * every weight is 1; in the weighted case the weight of a missing forward neighbour is 0.
     */
    private double voxelTVGradient(Grid3D img, boolean weighted, int[] size, int i, int j, int k) {
        final boolean hasL = i > 0;
        final boolean hasR = i < size[0] - 1;
        final boolean hasU = j > 0;
        final boolean hasD = j < size[1] - 1;
        final boolean hasT = k > 0;
        final boolean hasB = k < size[2] - 1;

        final double fijk = img.getAtIndex(i, j, k);
        final double fl = hasL ? img.getAtIndex(i - 1, j, k) : fijk;
        final double fr = hasR ? img.getAtIndex(i + 1, j, k) : fijk;
        final double fu = hasU ? img.getAtIndex(i, j - 1, k) : fijk;
        final double fd = hasD ? img.getAtIndex(i, j + 1, k) : fijk;
        final double ft = hasT ? img.getAtIndex(i, j, k - 1) : fijk;
        final double fb = hasB ? img.getAtIndex(i, j, k + 1) : fijk;
        // Diagonal neighbours needed for the backward differences of the forward neighbours.
        final double fld = (hasL && hasD) ? img.getAtIndex(i - 1, j + 1, k) : fijk;
        final double fru = (hasR && hasU) ? img.getAtIndex(i + 1, j - 1, k) : fijk;
        final double frt = (hasR && hasT) ? img.getAtIndex(i + 1, j, k - 1) : fijk;
        final double flb = (hasL && hasB) ? img.getAtIndex(i - 1, j, k + 1) : fijk;
        final double fdt = (hasD && hasT) ? img.getAtIndex(i, j + 1, k - 1) : fijk;
        final double fub = (hasU && hasB) ? img.getAtIndex(i, j - 1, k + 1) : fijk;

        final double w;
        final double wr;
        final double wd;
        final double wb;
        if (weighted) {
            w = weightMatrix.getAtIndex(i, j, k);
            wr = hasR ? weightMatrix.getAtIndex(i + 1, j, k) : 0;
            wd = hasD ? weightMatrix.getAtIndex(i, j + 1, k) : 0;
            wb = hasB ? weightMatrix.getAtIndex(i, j, k + 1) : 0;
        } else {
            w = 1.0;
            wr = 1.0;
            wd = 1.0;
            wb = 1.0;
        }

        return w * (3 * fijk - fl - fu - ft)
                / Math.sqrt(eps + (fijk - fl) * (fijk - fl) + (fijk - fu) * (fijk - fu) + (fijk - ft) * (fijk - ft))
                - wr * (fr - fijk)
                / Math.sqrt(eps + (fr - fijk) * (fr - fijk) + (fr - fru) * (fr - fru) + (fr - frt) * (fr - frt))
                - wd * (fd - fijk)
                / Math.sqrt(eps + (fd - fijk) * (fd - fijk) + (fd - fld) * (fd - fld) + (fd - fdt) * (fd - fdt))
                - wb * (fb - fijk)
                / Math.sqrt(eps + (fb - flb) * (fb - flb) + (fb - fub) * (fb - fub) + (fb - fijk) * (fb - fijk));
    }

    /**
     * Computes the weighted TV gradient with OpenCL into {@link #tvGradientCL}.
     * @param imgCL volume on the device
     * @return {@link #tvGradientCL}
     */
    public OpenCLGrid3D computeWeightedTVGradient(OpenCLGrid3D imgCL) {
        tvOperators.computeWeightedTVGradient(imgCL, weightMatrixCL, tvGradientCL);
        return this.tvGradientCL;
    }

    /**
     * Computes the adaptive weighted TV gradient with OpenCL into {@link #tvGradientCL}.
     * @param imgCL volume on the device
     * @return {@link #tvGradientCL}
     */
    public OpenCLGrid3D computeAdaptiveWeightedTVGradient(OpenCLGrid3D imgCL) {
        tvOperators.computeAdaptiveWeightedTVGradient(imgCL, weightMatrixCL, tvGradientCL);
        return this.tvGradientCL;
    }

    /**
     * Computes the anisotropic weighted TV (AwTV) gradient with OpenCL into {@link #tvGradientCL}.
     * @param imgCL volume on the device
     * @return {@link #tvGradientCL}
     */
    public OpenCLGrid3D computeAnisotropicWeightedTVGradient(OpenCLGrid3D imgCL) {
        tvOperators.computeDirectionalWeightedTVGradient(imgCL, weightMatrixCL, tvGradientCL);
        return this.tvGradientCL;
    }

    /**
     * Computes the weighted TV gradient with OpenCL using the gradient within each XY slice only.
     * @param imgCL volume on the device
     * @return {@link #tvGradientCL}
     */
    public OpenCLGrid3D computeWeightedTVGradient2(OpenCLGrid3D imgCL) {
        tvOperators.computeWeightedTVGradient2(imgCL, weightMatrixCL, tvGradientCL);
        return this.tvGradientCL;
    }

    /**
     * Updates the weight matrix for anisotropic weighted TV (AwTV) mainly along the Y direction.
     * @param imgCL current volume on the device
     */
    public void weightMatrixCLUpdateY(OpenCLGrid3D imgCL) {
        TVOpenCLGridOperators.getInstance().computeWeightMatrixUpdateY(imgCL, weightMatrixCL, weps);
    }

    /**
     * Returns the weighted TV value for anisotropic weighted TV (AwTV) along the Y direction.
     * @param imgCL volume on the device
     * @return weighted TV value
     */
    public double getWeightedTVvalueCLY(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = newZeroedZSum(imgCL);
        try {
            tvOperators.getWeightedTVY(imgCL, weightMatrixCL, tempZSum);
            return tempZSum.getGridOperator().sum(tempZSum);
        } finally {
            tempZSum.release();
        }
    }

    /**
     * Computes the weighted TV gradient for anisotropic weighted TV (AwTV) along the Y direction.
     * @param imgCL volume on the device
     * @return {@link #tvGradientCL}
     */
    public OpenCLGrid3D computeWeightedTVGradientY(OpenCLGrid3D imgCL) {
        tvOperators.computeWeightedTVGradientY(imgCL, weightMatrixCL, tvGradientCL);
        return this.tvGradientCL;
    }

    /**
     * Updates the weight matrix for anisotropic weighted TV (AwTV) mainly along the X direction.
     * @param imgCL current volume on the device
     */
    public void weightMatrixCLUpdateX(OpenCLGrid3D imgCL) {
        TVOpenCLGridOperators.getInstance().computeWeightMatrixUpdateX(imgCL, weightMatrixCL, weps);
    }

    /**
     * Returns the weighted TV value for anisotropic weighted TV (AwTV) along the X direction.
     * @param imgCL volume on the device
     * @return weighted TV value
     */
    public double getWeightedTVvalueCLX(OpenCLGrid3D imgCL) {
        OpenCLGrid3D tempZSum = newZeroedZSum(imgCL);
        try {
            // NOTE(review): this X variant calls the Y kernel (getWeightedTVY), which looks like a
            // copy-paste slip; switch to an X kernel if TVOpenCLGridOperators provides one.
            tvOperators.getWeightedTVY(imgCL, weightMatrixCL, tempZSum);
            return tempZSum.getGridOperator().sum(tempZSum);
        } finally {
            tempZSum.release();
        }
    }

    /**
     * Computes the weighted TV gradient for anisotropic weighted TV (AwTV) along the X direction.
     * @param imgCL volume on the device
     * @return {@link #tvGradientCL}
     */
    public OpenCLGrid3D computeWeightedTVGradientX(OpenCLGrid3D imgCL) {
        tvOperators.computeWeightedTVGradientX(imgCL, weightMatrixCL, tvGradientCL);
        return this.tvGradientCL;
    }
}