package edu.stanford.rsl.tutorial.dmip; import java.util.ArrayList; import edu.stanford.rsl.conrad.data.generic.datatypes.Complex; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.Grid2DComplex; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.data.numeric.NumericGridOperator; import edu.stanford.rsl.conrad.filtering.MedianFilteringTool; import edu.stanford.rsl.conrad.utils.ImageUtil; import ij.IJ; import ij.ImageJ; /** * * Exercise 4 of Diagnostic Medical Image Processing (DMIP) * @author Marco Boegel * */ public class DefectPixelInterpolation { public DefectPixelInterpolation() { } /** * Spectral deconvolution for defect pixel interpolation * "Defect Interpolation in Digital Radiography - How Object-Oriented Transform Coding Helps", * T. Aach, V. Metzler, SPIE Vol. 4322: Medical Imaging, February 2001 * @param image corrupted image with defect pixels * @param mask binary image mask of the defect pixels (0: defect, 1: fine) * @param maxIter maximum number of iterations * @param zeroPadding enables zero padding for the input images. Images are enlarged to size of next power of 2 and filled with zeros * @return image, where defect pixels are interpolated */ public Grid2D interpolateSpectral(Grid2D image, Grid2D mask, int maxIter, boolean zeroPadding) { //padding //TODO //TODO //fourier transform Grid2DComplex G = new Grid2DComplex(1,1);//TODO Grid2DComplex W = new Grid2DComplex(1,1);//TODO //TODO //TODO int[] dim = G.getSize(); int[] halfDim = {dim[0]/2, dim[1]/2}; double maxDeltaE_G_Ratio = Double.POSITIVE_INFINITY; double maxDeltaE_G_Thresh = 1.0e-6; Grid2DComplex FHat = new Grid2DComplex(dim[0], dim[1], false); Grid2DComplex FHatNext = new Grid2DComplex(dim[0], dim[1], false); //Setup visualization double lastEredVal = 0; //Visualize every 100th iteration Grid3D visualize = new Grid3D(image.getWidth(), image.getHeight(), maxIter/100); for(int i = 0; i < maxIter; i++) { //Check for convergence if(maxDeltaE_G_Ratio <= maxDeltaE_G_Thresh) { System.out.println("maxDeltaE_G_Ratio = " + maxDeltaE_G_Ratio); break; } //In the i-th iteration select line pair s1,t1 //which maximizes the energy reduction [Paragraph after Eq. (16) in the paper] double maxDeltaE_G = Double.NEGATIVE_INFINITY; //create arraylist to store lines (in case multiple maxima are found) //TODO ArrayList<Integer[]> sj1 = null; //HINT for(int x = 0; x < dim[0]; x++) { for(int y = 0; y < dim[1]; y++) { float val = G.getAtIndex(x, y) ; if( val > maxDeltaE_G) { //TODO //TODO //TODO }else if(val == maxDeltaE_G) { //TODO } } } //if there were more indices than one with the same max_value, pick a random one of these int idx = (int) Math.floor(Math.random() * sj1.size()); int s1 = sj1.get(idx)[0]; int t1 = sj1.get(idx)[1]; //Calculate the ratio of energy reduction in comparison to the last iteration if(i > 0) { maxDeltaE_G_Ratio = Math.abs(maxDeltaE_G - lastEredVal/maxDeltaE_G); } lastEredVal = maxDeltaE_G; //Compute the corresponding linepair s2, t2: //mirror the positions at halfDim int s2 = 0;//TODO int t2 = 0;//TODO //[Paragraph after Eq. (17) in the paper] int twice_s1 = (2*s1) % dim[0]; int twice_t1 = (2*t1) % dim[1]; //Estimate FHat //4 special cases, where only a single line can be selected: //(0,0), (0, halfHeight), (halfWidth,0), (halfWidth, halfHeight) boolean specialCase = false; if( (s1 == 0 && t1 == 0 ) || (s1 == halfDim[0] && t1 == 0) || (s1 == 0 && t1 == halfDim[1]) || (s1 == halfDim[0] && t1 == halfDim[1])) { System.out.println("Special Case"); specialCase = true; //Eq. 15 //FHat = N*(G(s,t)/W(0,0)) //TODO compute FHatNext, use Complex class //TODO //TODO //TODO Complex res = new Complex();//HINT FHatNext.setRealAtIndex(s1, t1, (float) res.getReal()); FHatNext.setImagAtIndex(s1, t1, (float) res.getImag()); } else { //General case //Compute FHatNext for the general case Eq.9 //TODO //TODO //TODO //TODO //TODO Complex res_s1t1 = new Complex();//HINT Complex res_s2t2 = new Complex();//HINT FHatNext.setRealAtIndex(s1, t1, (float) res_s1t1.getReal()); FHatNext.setImagAtIndex(s1, t1, (float) res_s1t1.getImag()); FHatNext.setRealAtIndex(s2, t2, (float) res_s2t2.getReal()); FHatNext.setImagAtIndex(s2, t2, (float) res_s2t2.getImag()); } //End iteration step by forming the new error spectrum updateErrorSpectrum(G, FHatNext, FHat, W, s1, t1, specialCase); //Get rid of rounding errors //G(t1,s1) and G(t2,s2) should be zero G.setAtIndex(s1, t1, 0); if(!specialCase) { G.setAtIndex(s2, t2, 0); } FHat = new Grid2DComplex(FHatNext); if(i % 100 == 0) { //For visualization, apply IFFT to the estimation Grid2DComplex FHatV = new Grid2DComplex(FHat); FHatV.transformInverse(); Grid2D vis = new Grid2D(image); //Fill in the defect mask pixels with current estimation and remove the zero padding for(int x = 0; x < vis.getWidth(); x++) { for(int y = 0; y < vis.getHeight(); y++) { if(mask.getAtIndex(x, y) == 0) { vis.setAtIndex(x, y, FHatV.getRealAtIndex(x, y)); } } } visualize.setSubGrid(i/100, vis); } } visualize.show(); //Compute the inverse fourier transform of the estimated image FHat.transformInverse(); //Fill in the defect mask pixels with the current estimation and remove the zero padding Grid2D result = new Grid2D(image); //TODO //TODO //TODO //TODO return result; } /** * * Do the convolution of the m-times-n matrix F and W * s,t is the position of the selected line pair, the convolution is simplified in the following way: * G(k1,k2) = F(k1,k2) 'conv' W(k1,k2) * = (F(s,t)W(k1-s,k2-t) + F*(s,t)W(k1+s,k2+t)) / (MN) * where F* is the conjugate complex. * * @param G Fourier transformation of input image * @param FHatNext currently estimated FT of the fixed image * @param FHat previous estimated FT of the fixed image * @param W Fourier transformation of the mask image * @param s1 position of the selected line pair * @param t1 position of the selected line pair * @param specialCase */ private static void updateErrorSpectrum(Grid2DComplex G, Grid2DComplex FHatNext, Grid2DComplex FHat, Grid2DComplex W, int s1, int t1, boolean specialCase) { int[] sz = FHatNext.getSize(); // Accumulation: Update pair (s1,t1),(s2,t2) Complex F_st = new Complex(FHatNext.getRealAtIndex(s1, t1) - FHat.getRealAtIndex(s1, t1), FHatNext.getImagAtIndex(s1, t1) - FHat.getImagAtIndex(s1, t1)); Complex F_st_conj = F_st.getConjugate(); int MN = sz[0] * sz[1]; // Compute the new error spectrum for(int j = 0; j < sz[1]; j++) { for(int i = 0; i < sz[0]; i++) { Complex GVal; if(specialCase) { int xneg = (i - s1) % sz[0]; int yneg = (j - t1) % sz[1]; if(xneg < 0) { xneg = sz[0] + xneg; } if(yneg < 0) { yneg = sz[1] + yneg; } GVal = new Complex(G.getRealAtIndex(i, j), G.getImagAtIndex(i, j)); Complex WNeg = new Complex(W.getRealAtIndex(xneg, yneg), W.getImagAtIndex(xneg, yneg)); GVal.sub( ( F_st.mul(WNeg) ).div(MN) ); } else { int xpos = (i + s1) % sz[0]; int ypos = (j + t1) % sz[1]; int xneg = (i - s1) % sz[0]; int yneg = (j - t1) % sz[1]; if(xneg < 0) { xneg = sz[0] + xneg; } if(yneg < 0) { yneg = sz[1] + yneg; } Complex WPos = new Complex(W.getRealAtIndex(xpos, ypos), W.getImagAtIndex(xpos, ypos)); Complex WNeg = new Complex(W.getRealAtIndex(xneg, yneg), W.getImagAtIndex(xneg, yneg)); GVal = new Complex(G.getRealAtIndex(i, j), G.getImagAtIndex(i, j)); GVal = GVal.sub( ( ( F_st.mul(WNeg) ).add( F_st_conj.mul(WPos) ) ).div(MN) ); } G.setRealAtIndex(i, j, (float) GVal.getReal()); G.setImagAtIndex(i, j, (float) GVal.getImag()); } } } public Grid2D interpolateMedian(Grid2D image, Grid2D defects, int kernelWidth, int kernelHeight) { //Pad the image. Otherwise, the filter will ignore kernelWidth/2 at each side of the image Grid2D paddedImage = new Grid2D(image.getWidth()+kernelWidth, image.getHeight()+kernelHeight); for(int i = 0; i <image.getWidth(); i++) { for(int j = 0; j < image.getHeight(); j++) { float val = image.getAtIndex(i, j); paddedImage.setAtIndex(i+kernelWidth/2, j+kernelHeight/2, val); } } paddedImage.show(); MedianFilteringTool medFilt = new MedianFilteringTool(); medFilt.configure(kernelWidth, kernelHeight); Grid2D medianFiltered = medFilt.applyToolToImage(paddedImage); Grid2D result = new Grid2D(image); for(int i = 0; i < image.getWidth(); i++) { for(int j = 0; j < image.getHeight(); j++) { if(defects.getAtIndex(i, j) == 0) { //medianFilteredImage is larger than original image result.setAtIndex(i, j, medianFiltered.getAtIndex(i+kernelWidth/2, j+kernelHeight/2)); } } } return result; } public static void main(String[] args) { ImageJ ij =new ImageJ(); DefectPixelInterpolation dpi = new DefectPixelInterpolation(); //Load an image from file String filename = "D:/02_lectures/DMIP/exercises/2014/3/testimg.bmp"; String filenameMask = "D:/02_lectures/DMIP/exercises/2014/3/mask.bmp"; Grid2D image = ImageUtil.wrapImagePlus(IJ.openImage(filename)).getSubGrid(0); image.show("Ideal Input Image"); Grid2D mask = ImageUtil.wrapImagePlus(IJ.openImage(filenameMask)).getSubGrid(0); //Set some pixels as defect, elementwise multiply with defect pixel mask Grid2D defectImage = new Grid2D(image); //TODO defectImage.show("Defect Image"); //Spatial Interpolation //Median Filter: int kernelWidth = 20; int kernelHeight = kernelWidth; Grid2D medianFiltered = dpi.interpolateMedian(defectImage, mask, kernelWidth, kernelHeight); medianFiltered.show("Median Filtered Image"); //show difference image |Median - Original| Grid2D absDiffMedian = new Grid2D(image.getWidth(), image.getHeight()); for(int i = 0; i < absDiffMedian.getWidth(); i++) { for(int j = 0; j < absDiffMedian.getHeight(); j++) { float val = Math.abs(medianFiltered.getAtIndex(i, j) - image.getAtIndex(i, j)); absDiffMedian.setAtIndex(i, j, val); } } absDiffMedian.show("|Median - Original|"); //Spectral Interpolation boolean zeroPadding = true; int maxIter = 4000; //TODO Grid2D spectralFiltered = new Grid2D(1,1);//TODO spectralFiltered.show("Spectral Filtered Image"); //show difference image |Spectral - Original| Grid2D absDiffSpectral = new Grid2D(image.getWidth(), image.getHeight()); for(int i = 0; i < absDiffSpectral.getWidth(); i++) { for(int j = 0; j < absDiffSpectral.getHeight(); j++) { float val = Math.abs(spectralFiltered.getAtIndex(i, j) - image.getAtIndex(i, j)); absDiffSpectral.setAtIndex(i, j, val); } } absDiffSpectral.show("|Spectral - Original|"); //show difference between median and spectral Grid2D absDiffSpectralMedian = new Grid2D(image.getWidth(), image.getHeight()); for(int i = 0; i < absDiffSpectralMedian.getWidth(); i++) { for(int j = 0; j < absDiffSpectralMedian.getHeight(); j++) { float val = Math.abs(spectralFiltered.getAtIndex(i, j) - medianFiltered.getAtIndex(i, j)); absDiffSpectralMedian.setAtIndex(i, j, val); } } absDiffSpectralMedian.show("|Spectral - Median|"); } }