package edu.stanford.rsl.tutorial.parallel; import java.io.IOException; import java.nio.FloatBuffer; import java.util.ArrayList; import com.jogamp.opencl.CLBuffer; import com.jogamp.opencl.CLCommandQueue; import com.jogamp.opencl.CLContext; import com.jogamp.opencl.CLDevice; import com.jogamp.opencl.CLImage2d; import com.jogamp.opencl.CLImageFormat; import com.jogamp.opencl.CLKernel; import com.jogamp.opencl.CLProgram; import com.jogamp.opencl.CLImageFormat.ChannelOrder; import com.jogamp.opencl.CLImageFormat.ChannelType; import com.jogamp.opencl.CLMemory.Mem; import edu.stanford.rsl.conrad.data.numeric.Grid1D; import edu.stanford.rsl.conrad.data.numeric.Grid2D; import edu.stanford.rsl.conrad.data.numeric.InterpolationOperators; import edu.stanford.rsl.conrad.geometry.shapes.simple.Box; import edu.stanford.rsl.conrad.geometry.shapes.simple.PointND; import edu.stanford.rsl.conrad.geometry.shapes.simple.StraightLine; import edu.stanford.rsl.conrad.geometry.transforms.Transform; import edu.stanford.rsl.conrad.geometry.transforms.Translation; import edu.stanford.rsl.conrad.numerics.SimpleOperators; import edu.stanford.rsl.conrad.numerics.SimpleVector; import edu.stanford.rsl.conrad.opencl.OpenCLUtil; /** * Implementation of a simple parallel projector. In order to create a * projection, the angular range and the angular sampling and the detector * element size and detector element sampling has to be defined. We show both, a * ray driven and a pixel driven projector. * * See L. Zeng. "Medical Image Reconstruction: A Conceptual tutorial". 2009, page 3 * * * @author Recopra Seminar Summer 2012 * */ public class ParallelProjector2D { double maxTheta, deltaTheta, // [rad] maxS, deltaS; // [mm] int maxThetaIndex, maxSIndex; /** * Sampling of projections is defined in the constructor. * * @param maxTheta the angular range in radians * @param deltaTheta the angular step size in radians * @param maxS the detector size in [mm] * @param deltaS the detector element size in [mm] */ public ParallelProjector2D(double maxTheta, double deltaTheta, double maxS, double deltaS) { this.maxS = maxS; this.maxTheta = maxTheta; this.deltaS = deltaS; this.deltaTheta = deltaTheta; this.maxSIndex = (int) (maxS / deltaS); this.maxThetaIndex = (int) (maxTheta / deltaTheta); } /** * The ray driven solution. * * @param grid the image * @return the sinogram */ public Grid2D projectRayDriven(Grid2D grid) { final double samplingRate = 3.d; // # of samples per pixel Grid2D sino = new Grid2D(new float[maxThetaIndex*maxSIndex], maxSIndex, maxThetaIndex); sino.setSpacing(deltaS, deltaTheta); // set up image bounding box in WC Translation trans = new Translation( -(grid.getSize()[0] * grid.getSpacing()[0])/2, -(grid.getSize()[1] * grid.getSpacing()[1])/2, -1 ); Transform inverse = trans.inverse(); Box b = new Box((grid.getSize()[0] * grid.getSpacing()[0]), (grid.getSize()[1] * grid.getSpacing()[1]), 2); b.applyTransform(trans); for(int e=0; e<maxThetaIndex; ++e){ // compute theta [rad] and angular functions. double theta = deltaTheta * e; double cosTheta = Math.cos(theta); double sinTheta = Math.sin(theta); for (int i = 0; i < maxSIndex; ++i) { // compute s, the distance from the detector edge in WC [mm] double s = deltaS * i - maxS / 2; // compute two points on the line through s and theta // We use PointND for Points in 3D space and SimpleVector for directions. PointND p1 = new PointND(s * cosTheta, s * sinTheta, .0d); PointND p2 = new PointND(-sinTheta + (s * cosTheta), (s * sinTheta) + cosTheta, .0d); // set up line equation StraightLine line = new StraightLine(p1, p2); // compute intersections between bounding box and intersection line. ArrayList<PointND> points = b.intersect(line); // only if we have intersections if (2 != points.size()){ if(points.size() == 0) { line.getDirection().multiplyBy(-1.d); points = b.intersect(line); } if(points.size() == 0) continue; } PointND start = points.get(0); // [mm] PointND end = points.get(1); // [mm] // get the normalized increment SimpleVector increment = new SimpleVector( end.getAbstractVector()); increment.subtract(start.getAbstractVector()); double distance = increment.normL2(); increment.divideBy(distance * samplingRate); double sum = .0; start = inverse.transform(start); // compute the integral along the line. for (double t = 0.0; t < distance * samplingRate; ++t) { PointND current = new PointND(start); current.getAbstractVector().add(increment.multipliedBy(t)); double x = current.get(0) / grid.getSpacing()[0], y = current.get(1) / grid.getSpacing()[1]; if (grid.getSize()[0] <= x + 1 || grid.getSize()[1] <= y + 1 || x < 0 || y < 0) continue; sum += InterpolationOperators.interpolateLinear(grid, x, y); } // normalize by the number of interpolation points sum /= samplingRate; // write integral value into the sinogram. sino.setAtIndex(i, e, (float)sum); } } return sino; } public Grid2D projectRayDrivenCL(Grid2D grid) { boolean debug = true; // create context CLContext context = OpenCLUtil.createContext(); if (debug) System.out.println("Context: " + context); //show OpenCL devices in System CLDevice[] devices = context.getDevices(); if (debug){ for (CLDevice dev: devices) System.out.println(dev); } // select device CLDevice device = context.getMaxFlopsDevice(); if (debug) System.out.println("Device: " + device); int imageSize = grid.getSize()[0] * grid.getSize()[1]; // Length of arrays to process int localWorkSize = Math.min(device.getMaxWorkGroupSize(), 8); // Local work size dimensions int globalWorkSizeT = OpenCLUtil.roundUp(localWorkSize, maxSIndex); // rounded up to the nearest multiple of localWorkSize int globalWorkSizeBeta = OpenCLUtil.roundUp(localWorkSize, maxThetaIndex); // rounded up to the nearest multiple of localWorkSize // load sources, create and build program CLProgram program = null; try { program = context.createProgram(this.getClass().getResourceAsStream("ParallelProjector.cl")) .build(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); System.exit(-1); } // create image from input grid CLImageFormat format = new CLImageFormat(ChannelOrder.INTENSITY, ChannelType.FLOAT); CLBuffer<FloatBuffer> imageBuffer = context.createFloatBuffer(imageSize, Mem.READ_ONLY); // for (int i = 0; i < grid.getSize()[0]; ++i) { // imageBuffer.getBuffer().put(grid.getSubGrid(i).getBuffer()); // } for (int i=0;i<grid.getSize()[1];++i){ for (int j=0;j<grid.getSize()[0];++j) imageBuffer.getBuffer().put(grid.getAtIndex(j, i)); } imageBuffer.getBuffer().rewind(); CLImage2d<FloatBuffer> imageGrid = context.createImage2d( imageBuffer.getBuffer(), grid.getSize()[0], grid.getSize()[1], format); imageBuffer.release(); // create memory for sinogram CLBuffer<FloatBuffer> sinogram = context.createFloatBuffer(maxSIndex * maxThetaIndex, Mem.WRITE_ONLY); // copy params CLKernel kernel = program.createCLKernel("projectRayDriven2DCL"); kernel.putArg(imageGrid).putArg(sinogram) .putArg((float)maxS).putArg((float)deltaS) .putArg((float)maxTheta).putArg((float)deltaTheta) .putArg(maxSIndex).putArg(maxThetaIndex); // TODO: Spacing :) // createCommandQueue CLCommandQueue queue = device.createCommandQueue(); queue .putWriteImage(imageGrid, true) .finish() .put2DRangeKernel(kernel, 0, 0, globalWorkSizeBeta, globalWorkSizeT, localWorkSize, localWorkSize).putBarrier() .putReadBuffer(sinogram, true) .finish(); // write sinogram back to grid2D Grid2D sino = new Grid2D(new float[maxThetaIndex*maxSIndex], maxSIndex, maxThetaIndex); sino.setSpacing(deltaS, deltaTheta); sinogram.getBuffer().rewind(); for (int i = 0; i < sino.getBuffer().length; ++i) { sino.getBuffer()[i] = sinogram.getBuffer().get(); } // clean up queue.release(); imageGrid.release(); sinogram.release(); kernel.release(); program.release(); context.release(); return sino; } /** * The pixel driven solution. * * @param grid the image * @return the sinogram */ public Grid2D projectPixelDriven(Grid2D grid) { Grid2D sino = new Grid2D(new float[maxThetaIndex*maxSIndex], maxSIndex, maxThetaIndex); sino.setSpacing(deltaS, deltaTheta); for (int i = 0; i < maxThetaIndex; i++) { double theta = deltaTheta * i; double cosTheta = Math.cos(theta); double sinTheta = Math.sin(theta); SimpleVector dirDetector = new SimpleVector(cosTheta, sinTheta); // loop over all grid points // x,y are in the grid coordinate system // wx,wy are in the world coordinate system for (int x = 0; x < grid.getSize()[0]; x++) { for (int y = 0; y < grid.getSize()[1]; y++) { float val = (float) (grid.getAtIndex(x, y)/deltaS); val *= grid.getSpacing()[0] * grid.getSpacing()[1]; // assuming isometric pixels double[] w = grid.indexToPhysical(x, y); double wx = w[0], wy = w[1]; // convenience SimpleVector pixel = new SimpleVector(wx, wy); double s = SimpleOperators.multiplyInnerProd(pixel, dirDetector); s += maxS/2; s /= deltaS; Grid1D subgrid = sino.getSubGrid(i); if (subgrid.getSize()[0] <= s + 1 || s < 0) continue; InterpolationOperators .addInterpolateLinear(subgrid, s, val); } } } return sino; } } /* * Copyright (C) 2010-2014 Andreas Maier * CONRAD is developed as an Open Source project under the GNU General Public License (GPL). */