package edu.stanford.rsl.tutorial.dmip; import java.util.ArrayList; import java.util.Iterator; import java.util.TreeMap; import java.util.Map.Entry; import edu.stanford.rsl.conrad.data.numeric.Grid1D; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.Grid3D; import edu.stanford.rsl.conrad.geometry.transforms.AffineTransform; import edu.stanford.rsl.conrad.numerics.SimpleMatrix; import edu.stanford.rsl.conrad.utils.ImageUtil; import edu.stanford.rsl.conrad.utils.VisualizationUtil; import edu.stanford.rsl.conrad.numerics.SimpleVector; import edu.stanford.rsl.jpop.FunctionOptimizer; import edu.stanford.rsl.jpop.OptimizableFunction; import edu.stanford.rsl.jpop.OptimizationOutputFunction; import edu.stanford.rsl.jpop.FunctionOptimizer.OptimizationMode; import ij.IJ; import ij.ImageJ; import ij.gui.PlotWindow; /** * Exercise 7 of Diagnostic Medical Image Processing (DMIP) * Using Mutual Information to solve the registration problem * @author Bastian Bier * */ public class Registration2 { public Grid2D reference = null; public Grid2D image = null; public Grid3D movingStack = null; public int iteration = 0; double[] saveParameters; class CostFunction implements OptimizableFunction, OptimizationOutputFunction{ private TreeMap<Integer,Double> resultVisualizer; private PlotWindow resultVisualizerPlot; @Override public void setNumberOfProcessingBlocks(int number) { } @Override public int getNumberOfProcessingBlocks() { return 1; } @Override public double evaluate(double[] x, int block) { int nrChanges = 0; for (int i = 0; i < x.length; i++) { if(saveParameters[i]!=x[i]){ nrChanges++; } } // Define Rotation SimpleMatrix r = new SimpleMatrix(2, 2); double phi2 = x[0] * (2 * Math.PI) / 360; r.setElementValue(0, 0, Math.cos(phi2)); r.setElementValue(0, 1, -Math.sin(phi2)); r.setElementValue(1, 0, Math.sin(phi2)); r.setElementValue(1, 1, Math.cos(phi2)); // Define translation double t_x = x[1]; double t_y = x[2]; Grid2D im_tmp = new Grid2D(image); // Perform rotation/translation SimpleVector t = new SimpleVector(t_x, t_y); AffineTransform affine = new AffineTransform(r, t); im_tmp.applyTransform(affine); if (nrChanges >= 3) { movingStack.setSubGrid(iteration, im_tmp); iteration++; } // Calculate the cost function double cost = calculateMutualInformation(reference, im_tmp); System.arraycopy(x, 0, saveParameters, 0, x.length); return cost; } @Override public void optimizerCallbackFunction(int currIterationNumber, double[] x, double currFctVal, double[] gradientAtX) { // Visualization of cost function value over time if (this.resultVisualizer == null) resultVisualizer = new TreeMap<Integer, Double>(); resultVisualizer.put(currIterationNumber, currFctVal); if (resultVisualizerPlot != null) resultVisualizerPlot.close(); Grid1D out = new Grid1D(resultVisualizer.size()); Iterator<Entry<Integer,Double>> it = resultVisualizer.entrySet().iterator(); while (it.hasNext()) { Entry<Integer,Double> e = it.next(); out.setAtIndex(e.getKey(), e.getValue().floatValue()); } resultVisualizerPlot = VisualizationUtil.createPlot(out.getBuffer()).show(); } } /** * Method to calculate the Mutual Information * @param ref reference image * @param mov moving image * @return negative mutual information */ private double calculateMutualInformation(Grid2D ref, Grid2D mov){ int histSize = 256; // Step 1: Calculate joint histogram SimpleMatrix jointHistogram = calculateJointHistogram(ref, mov); // Step 2: Get histogram for a single image from the joint histogram // a) for the first image SimpleVector histo1 = new SimpleVector(histSize); histo1 = getHistogramFromJointHistogram(jointHistogram); // b) for the second image SimpleVector histo2 = new SimpleVector(histSize); SimpleMatrix jh_t = jointHistogram.transposed(); histo2 = getHistogramFromJointHistogram(jh_t); // Step 3: Calculate the marginal entropies and the joint entropy double entropy_jointHisto = 0; double entropy_histo1 = 0; double entropy_histo2 = 0; for(int i = 0; i < histSize; i++) { if(histo1.getElement(i) != 0) { // TODO: calculate entropy for histogram 1 } if(histo2.getElement(i) != 0) { // TODO: calculate entropy for histogram 2 } } for (int i = 0; i < histSize; i++) { for (int j = 0; j < histSize; j++) { if(jointHistogram.getElement(i, j) != 0) { // TODO: calculate entropy of the joint histogram } } } // make sure to consider the - in from of the sum (Entropy formula) // TODO // TODO // TODO // Step 4: Calculate the mutual information // Note: The mutual information is high for a good match // but we require a minimization problem --> the result is inverted to fit the optimizer double mutual_information = 0; // TODO: calculate the mutual information return mutual_information * 1000; } /** * Method to calculate the joint histogram of two images * @param im1 image1 * @param im2 image2 * @return a SimpleMatrix corresponding to the joint histogram */ private SimpleMatrix calculateJointHistogram(Grid2D im1, Grid2D im2){ // Calculate joint histogram int histSize = 256; SimpleMatrix jH = new SimpleMatrix(histSize, histSize); for (int i = 0; i < histSize; i++) { for (int j = 0; j < histSize; j++) { // TODO } } // Divide by the number of elements in order to get probabilities for (int i = 0; i < histSize; i++) { for (int j = 0; j < histSize; j++) { // TODO } } return jH; } /** * Method to calculate a histogram from a joint histogram * @param jH The joint histogram * @return a SimpleVector corresponding to the marginal histogram */ private SimpleVector getHistogramFromJointHistogram(SimpleMatrix jH){ // Calculate histogram from joint histogram int histSize = 256; SimpleVector hist = new SimpleVector(histSize); hist.zeros(); for(int i = 0; i < histSize; i++) { for(int j = 0; j < histSize; j++) { // TODO: sum up over the columns } } return hist; } private double[] performOptimization(){ FunctionOptimizer fo = new FunctionOptimizer(); fo.setDimension(3); fo.setNdigit(6); fo.setItnlim(50); fo.setMsg(16); fo.setInitialX(new double[]{0,0,0}); fo.setMaxima(new double[]{50,50,50}); fo.setMinima(new double[]{-50,-50,-50}); fo.setOptimizationMode(OptimizationMode.Function); CostFunction cF = new CostFunction(); movingStack = new Grid3D(reference.getWidth(), reference.getHeight(), 1000, false); iteration = 0; saveParameters = new double[]{Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY}; // Optimization visualized ArrayList<OptimizationOutputFunction> visFcts = new ArrayList<OptimizationOutputFunction>(); visFcts.add(cF); fo.setCallbackFunctions(visFcts); double result[] = fo.optimizeFunction(cF); return result; } public static void main(String[] args){ ImageJ ij = new ImageJ(); // Load images // TODO Adjust paths String filename1 = "C:/StanfordRepo/CONRAD/src/edu/stanford/rsl/tutorial/dmip/T1.png"; String filename2 = "C:/StanfordRepo/CONRAD/src/edu/stanford/rsl/tutorial/dmip/Proton.png"; Grid2D image1 = ImageUtil.wrapImagePlus(IJ.openImage(filename1)).getSubGrid(0); Grid2D image2 = ImageUtil.wrapImagePlus(IJ.openImage(filename2)).getSubGrid(0); image1.show("Input Image 1"); image2.show("Input Image 2"); // Set the Origin of the image in its center // The default origin of an image is in its top left corner // Default Origin: [0.0, 0.0] int w = image1.getWidth(); int h = image1.getHeight(); image1.setOrigin(-(w-1) / 2 , -(h-1)/2); image2.setOrigin(-(w-1) / 2 , -(h-1)/2); image1.setSpacing(1); image2.setSpacing(1); // Blurred Images for the registration to avoid local minima during optimization Grid2D image1_blurred = new Grid2D(image1); Grid2D image2_blurred = new Grid2D(image2); IJ.run(ImageUtil.wrapGrid(image1_blurred,""),"Gaussian Blur...", "sigma=4"); IJ.run(ImageUtil.wrapGrid(image2_blurred,""),"Gaussian Blur...", "sigma=4"); Registration2 reg2 = new Registration2(); reg2.reference = image1_blurred; reg2.image = image2_blurred; // Perform Optimization double res[] = reg2.performOptimization(); // Stack for visualization purposes only Grid3D optimizationStepsGrid = new Grid3D(reg2.reference.getWidth(), reg2.reference.getHeight(), reg2.iteration, false); for (int i = 0; i < optimizationStepsGrid.getSize()[2]; i++) { optimizationStepsGrid.setSubGrid(i, reg2.movingStack.getSubGrid(i)); } optimizationStepsGrid.show("Optimization Steps"); // Transform image back SimpleMatrix r = new SimpleMatrix(2,2); Grid2D registeredImage = new Grid2D(image2); double phi = (2 * Math.PI) / 360 * res[0]; r.setElementValue(0, 0, Math.cos(phi)); r.setElementValue(0, 1, -Math.sin(phi)); r.setElementValue(1, 0, Math.sin(phi)); r.setElementValue(1, 1, Math.cos(phi)); SimpleVector t2 = new SimpleVector(res[1], res[2]); AffineTransform affine2 = new AffineTransform(r, t2); registeredImage.applyTransform(affine2); registeredImage.show("Registered Image"); } }