package edu.stanford.rsl.tutorial.dmip; import java.util.ArrayList; import java.util.Iterator; import java.util.TreeMap; import java.util.Map.Entry; import edu.stanford.rsl.conrad.data.numeric.Grid1D; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.geometry.transforms.AffineTransform; import edu.stanford.rsl.conrad.numerics.SimpleMatrix; import edu.stanford.rsl.conrad.numerics.SimpleVector; import edu.stanford.rsl.conrad.utils.ImageUtil; import edu.stanford.rsl.conrad.utils.VisualizationUtil; import edu.stanford.rsl.jpop.FunctionOptimizer; import edu.stanford.rsl.jpop.FunctionOptimizer.OptimizationMode; import edu.stanford.rsl.jpop.OptimizableFunction; import edu.stanford.rsl.jpop.OptimizationOutputFunction; import edu.stanford.rsl.tutorial.phantoms.SheppLogan; import ij.IJ; import ij.ImageJ; import ij.gui.PlotWindow; /** * Exercise 7 of Diagnostic Medical Image Processing (DMIP) * Solve the Registration Problem using Sum of Squared Difference * @author Bastian Bier * */ public class Registration3 { public Grid2D reference = null; public Grid2D image = null; public Grid3D movingStack = null; public int iteration = 0; double[] saveParameters; public class CostFunction implements OptimizableFunction, OptimizationOutputFunction{ private TreeMap<Integer,Double> resultVisualizer; private PlotWindow resultVisualizerPlot; @Override public void setNumberOfProcessingBlocks(int number) { } @Override public int getNumberOfProcessingBlocks() { return 1; } /* * This function gets a parameter vector and returns the result of the cost function */ @Override public double evaluate(double[] x, int block) { int nrChanges = 0; for (int i = 0; i < x.length; i++) { if(saveParameters[i]!=x[i]){ nrChanges++; } } // Define Rotation SimpleMatrix r = new SimpleMatrix(2,2); double phi2 = x[0] * (2*Math.PI)/360; r.setElementValue(0, 0, Math.cos(phi2)); r.setElementValue(0, 1, - Math.sin(phi2)); r.setElementValue(1, 0, Math.sin(phi2)); r.setElementValue(1, 1, Math.cos(phi2)); // Define translation double t_x = x[1]; double t_y = x[2]; Grid2D im_tmp = new Grid2D(image); // Perform rotation/translation SimpleVector t = new SimpleVector(t_x,t_y); AffineTransform affine = new AffineTransform(r, t); im_tmp.applyTransform(affine); if(nrChanges>=3){ movingStack.setSubGrid(iteration, im_tmp); iteration++; } // Calculate the cost function double cost = SumOfSquaredDifferences(reference, im_tmp); System.arraycopy(x, 0, saveParameters, 0, x.length); return cost; } @Override public void optimizerCallbackFunction(int currIterationNumber, double[] x, double currFctVal, double[] gradientAtX) { // Visualization of cost function value over time if (this.resultVisualizer == null) resultVisualizer = new TreeMap<Integer, Double>(); resultVisualizer.put(currIterationNumber, currFctVal); if (resultVisualizerPlot != null) resultVisualizerPlot.close(); Grid1D out = new Grid1D(resultVisualizer.size()); Iterator<Entry<Integer,Double>> it = resultVisualizer.entrySet().iterator(); while (it.hasNext()) { Entry<Integer,Double> e = it.next(); out.setAtIndex(e.getKey(), e.getValue().floatValue()); } resultVisualizerPlot = VisualizationUtil.createPlot(out.getBuffer()).show(); } } private double SumOfSquaredDifferences(Grid2D ref, Grid2D imageMoving){ double sum = 0.0; for(int i = 0; i < ref.getWidth(); i++) { for(int j = 0; j < ref.getHeight(); j++) { // TODO: calculate SSD } } return sum/ref.getNumberOfElements(); } public double[] performOptimization(){ // Initialize optimization class FunctionOptimizer fo = new FunctionOptimizer(); fo.setDimension(3); fo.setItnlim(50); movingStack = new Grid3D(reference.getWidth(), reference.getHeight(), 1000, false); iteration = 0; fo.setOptimizationMode(OptimizationMode.Function); fo.setNdigit(8); fo.setMsg(16); fo.setInitialX(new double[]{0,0,0}); fo.setMaxima(new double[]{50,50,50}); fo.setMinima(new double[]{-50,-50,-50}); // Initialize the Costfunction of the optimization CostFunction cf = new CostFunction(); saveParameters = new double[]{Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY}; // Optimization visualized ArrayList<OptimizationOutputFunction> visFcts = new ArrayList<OptimizationOutputFunction>(); visFcts.add(cf); fo.setCallbackFunctions(visFcts); // Perform the optimization with the given cost function double[] result = fo.optimizeFunction(cf); return result; } public static void main(String[] args){ ImageJ ij = new ImageJ(); /////////////////////////////////////////////////////// // Part 1: Apply a transformation on a phantom image // /////////////////////////////////////////////////////// // Create Phantom Grid2D phantom = new SheppLogan(256); // Set the Origin of the image in its center // The default origin of an image is in its top left corner // Default Origin: [0.0, 0.0] int w = phantom.getWidth(); int h = phantom.getHeight(); phantom.setOrigin(-(w-1) / 2 , -(h-1)/2); Grid2D phantom_blurred = new Grid2D(phantom); IJ.run(ImageUtil.wrapGrid(phantom_blurred,""),"Gaussian Blur...", "sigma=3"); phantom.show("Phantom"); // Rotate the phantom by 45° and translate it with t = [20, 1] // Define Rotation and translation SimpleMatrix r = new SimpleMatrix(2,2); // TODO: set phi double phi = 0; // TODO: fill the rotation matrix // TODO // TODO // TODO // TODO // TODO: define translation SimpleVector t = new SimpleVector(0,0); // Initialize transformed phantom Grid2D transformedPhantom = new Grid2D(phantom); Grid2D transformedPhantom_blurred = new Grid2D(phantom_blurred); // Create the affine transformation AffineTransform affine = new AffineTransform(r, t); // Apply the transformation transformedPhantom.applyTransform(affine); transformedPhantom_blurred.applyTransform(affine); transformedPhantom.show("Transformed Phantom"); ///////////////////////////////////// // Part 2: Find the transformation // ///////////////////////////////////// // Registration of the transformed image to the initial phantom Registration3 reg3 = new Registration3(); reg3.reference = phantom_blurred; reg3.image = transformedPhantom_blurred; // Optimization double[] res = reg3.performOptimization(); // Stack for visualization purposes only Grid3D optimizationStepsGrid = new Grid3D(reg3.reference.getWidth(),reg3.reference.getHeight(),reg3.iteration,false); for (int i = 0; i < optimizationStepsGrid.getSize()[2]; i++) { optimizationStepsGrid.setSubGrid(i, reg3.movingStack.getSubGrid(i)); } optimizationStepsGrid.show(); // Transform image back Grid2D backtransformedImage = new Grid2D(transformedPhantom); phi = (2*Math.PI)/360 * res[0]; r.setElementValue(0, 0, Math.cos(phi)); r.setElementValue(0, 1, - Math.sin(phi)); r.setElementValue(1, 0, Math.sin(phi)); r.setElementValue(1, 1, Math.cos(phi)); SimpleVector t2 = new SimpleVector(res[1],res[2]); AffineTransform affine2 = new AffineTransform(r, t2); backtransformedImage.applyTransform(affine2); backtransformedImage.show(); } }