package edu.stanford.rsl.conrad.cuda;
import edu.stanford.rsl.conrad.utils.CONRAD;
import edu.stanford.rsl.conrad.volume3d.FFTVolumeHandle;
import edu.stanford.rsl.conrad.volume3d.Volume3D;
import edu.stanford.rsl.conrad.volume3d.VolumeOperator;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.jcufft.JCufft;
import jcuda.jcufft.cufftHandle;
import jcuda.jcufft.cufftType;
import jcuda.runtime.JCuda;
public class CUDAFFTVolumeHandle extends FFTVolumeHandle {
/**
 * Selects how non-CUDA volumes are transformed: when true, the volume is copied to the
 * device via CUDAUtil and transformed there; when false, it is flattened to a host float
 * array that is handed to JCufft directly. Nothing in this class reassigns it, so the
 * host-array path is the one taken.
 */
private boolean nativeCopy = false;
/**
 * Creates an FFT handle for the given operator and prepares the JCufft bindings.
 * @param operator the volume operator that is passed on to the superclass
 */
public CUDAFFTVolumeHandle(VolumeOperator operator){
super(operator);
// NOTE(review): presumably this makes JCufft calls throw on failure instead of only
// returning a status code -- confirm against the JCufft documentation.
JCufft.setExceptionsEnabled(true);
JCufft.initialize();
}
/**
 * Symbolic names for the integer status codes 0 through 8, as mapped by
 * {@code getResultEnum(int)}. The declaration order matches the numeric codes.
 */
public enum CUFFTResult {
/** Status code 0. */
CUFFT_SUCCESS,
/** Status code 1. */
CUFFT_INVALID_PLAN,
/** Status code 2. */
CUFFT_ALLOC_FAILED,
/** Status code 3. */
CUFFT_INVALID_TYPE,
/** Status code 4. */
CUFFT_INVALID_VALUE,
/** Status code 5. */
CUFFT_INTERNAL_ERROR,
/** Status code 6. */
CUFFT_EXEC_FAILED,
/** Status code 7. */
CUFFT_SETUP_FAILED,
/** Status code 8. */
CUFFT_INVALID_SIZE
};
/**
 * Translates a raw integer status code into its {@link CUFFTResult} constant.
 * @param i the integer status code
 * @return the matching constant for codes 0 through 8, or {@code null} for any other value
 */
public static CUFFTResult getResultEnum(int i){
	switch (i) {
	case 0:
		return CUFFTResult.CUFFT_SUCCESS;
	case 1:
		return CUFFTResult.CUFFT_INVALID_PLAN;
	case 2:
		return CUFFTResult.CUFFT_ALLOC_FAILED;
	case 3:
		return CUFFTResult.CUFFT_INVALID_TYPE;
	case 4:
		return CUFFTResult.CUFFT_INVALID_VALUE;
	case 5:
		return CUFFTResult.CUFFT_INTERNAL_ERROR;
	case 6:
		return CUFFTResult.CUFFT_EXEC_FAILED;
	case 7:
		return CUFFTResult.CUFFT_SETUP_FAILED;
	case 8:
		return CUFFTResult.CUFFT_INVALID_SIZE;
	default:
		// Codes outside 0..8 have no constant in this enum.
		return null;
	}
}
/**
 * Turns a non-zero status code into an exception.
 * @param i the integer status code; 0 means success
 * @throws Exception if the code is not 0. The message names the matching
 *         {@link CUFFTResult} constant, or reports the raw number when the code has no
 *         constant (previously such codes produced the uninformative text "null").
 */
private static void checkResult(int i) throws Exception {
	if (i == 0) {
		return;
	}
	CUFFTResult result = getResultEnum(i);
	// Keep the message for known codes unchanged; only unknown codes get the raw number.
	String description = (result != null) ? result.toString() : "unknown status code " + i;
	throw new Exception ("CUDA FFT Error: " + description);
}
/**
 * Unpacks a flat float array into a three-dimensional host array.
 * The flat array is read as size[0] slices of size[1] rows, each row holding
 * size[2]*2 consecutive floats (presumably interleaved real/imaginary pairs, since the
 * data is used with complex-to-complex transforms elsewhere in this class).
 * @param cuda the flat source array
 * @param size the three volume dimensions
 * @return a new array of shape [size[0]][size[1]][size[2]*2] filled from {@code cuda}
 */
public static float[][][] toHostFormat(float [] cuda, int [] size){
	final int slices = size[0];
	final int rows = size[1];
	final int rowLength = size[2] * 2;
	float[][][] hostVolume = new float[slices][rows][rowLength];
	// Rows are stored back to back, so a single running offset walks the flat array.
	int offset = 0;
	for (int slice = 0; slice < slices; slice++) {
		float[][] plane = hostVolume[slice];
		for (int row = 0; row < rows; row++) {
			System.arraycopy(cuda, offset, plane[row], 0, rowLength);
			offset += rowLength;
		}
	}
	return hostVolume;
}
/**
 * Flattens a three-dimensional host array into one contiguous float array.
 * Dimensions are taken from the first slice and first row: hostVolume.length slices,
 * hostVolume[0].length rows, and hostVolume[0][0].length/2 value pairs per row. Each row
 * contributes an even number of floats, so a trailing unpaired float in a row is dropped.
 * An empty volume (no slices, or no rows in the first slice) yields an empty array
 * instead of an ArrayIndexOutOfBoundsException, matching toHostFormat, which also
 * accepts zero-sized dimensions.
 * @param hostVolume the volume to flatten
 * @return a new flat array holding the rows of all slices back to back
 */
public static float[] toCUDAFormat(float [][][] hostVolume){
	if (hostVolume.length == 0 || hostVolume[0].length == 0) {
		return new float[0];
	}
	final int slices = hostVolume.length;
	final int rows = hostVolume[0].length;
	// Round down to an even count: only complete pairs are copied.
	final int rowLength = (hostVolume[0][0].length / 2) * 2;
	float[] cuda = new float[slices * rows * rowLength];
	int offset = 0;
	for (int h = 0; h < slices; h++) {
		for (int j = 0; j < rows; j++) {
			System.arraycopy(hostVolume[h][j], 0, cuda, offset, rowLength);
			offset += rowLength;
		}
	}
	return cuda;
}
/**
 * Performs a forward 3-D FFT on the given volume in the CUDA memory.
 * The transform is executed in place (the same pointer is used as input and output).
 * The plan created for the transform is destroyed on both the success and the failure
 * path, so a failing execution no longer leaks the plan.
 * @param deviceX the Pointer to the device's memory
 * @param size the sizes of the volume
 * @throws Exception if planning, executing or destroying the plan reports an error
 */
public void forwardTransform(Pointer deviceX, int [] size) throws Exception {
	if (debug){
		System.out.println("Planning " + size[0] +"x"+size[1]+"x"+size[2] +" Complex FFT");
	}
	cufftHandle plan = new cufftHandle();
	checkResult(JCufft.cufftPlan3d(plan, size[0], size[1], size[2], cufftType.CUFFT_C2C));
	try {
		checkResult(JCufft.cufftExecC2C(plan, deviceX, deviceX, JCufft.CUFFT_FORWARD));
	} catch (Exception e) {
		// Release the plan, but report the execution failure rather than a cleanup failure.
		try {
			JCufft.cufftDestroy(plan);
		} catch (RuntimeException ignored) {
			// intentionally ignored: the original failure is more informative
		}
		throw e;
	}
	checkResult(JCufft.cufftDestroy(plan));
}
/**
 * Performs a forward 3-D FFT on the given volume after making it complex.
 * A CUDAVolume3D is transformed directly through its device pointer. Any other volume is
 * either copied to the device and back (nativeCopy) or flattened into a host array that
 * is handed to JCufft and unpacked again afterwards.
 * Failures are printed to standard error and not propagated; the volume may then have
 * been made complex without having been transformed.
 * @param vol the volume to transform
 */
@Override
public void forwardTransform(Volume3D vol)
{
	try{
		if (debug)
			System.out.println("CUDA vol_fft\n");
		operator.makeComplex(vol);
		if (vol instanceof CUDAVolume3D){
			int [] fftsize = {vol.size[0], vol.size[1], vol.size[2]};
			forwardTransform(((CUDAVolume3D) vol).getDevicePointer(), fftsize);
		} else {
			CONRAD.gc();
			if(nativeCopy) {
				CUdeviceptr deviceX = CUDAUtil.allocateSpace(vol);
				try {
					CUDAUtil.moveToDevice(vol, deviceX);
					forwardTransform(deviceX, vol.size);
					CUDAUtil.fetchFromDevice(vol, deviceX);
				} finally {
					// Free the device buffer even when the copy or the transform fails.
					JCuda.cudaFree(deviceX);
				}
			} else {
				float [] cuda = toCUDAFormat(vol.data);
				executeForwardOnHostArray(cuda, vol.size);
				// Drop the old host data first so it can be collected before the new
				// array is allocated.
				vol.data = null;
				vol.data = toHostFormat(cuda, vol.size);
				cuda = null;
				CONRAD.gc();
			}
		}
	} catch (Exception e) {
		e.printStackTrace();
	}
}

/** Runs an in-place forward complex-to-complex 3-D FFT on a flat host array, always destroying the plan. */
private static void executeForwardOnHostArray(float[] cuda, int[] size) throws Exception {
	cufftHandle plan = new cufftHandle();
	checkResult(JCufft.cufftPlan3d(plan, size[0], size[1], size[2], cufftType.CUFFT_C2C));
	try {
		checkResult(JCufft.cufftExecC2C(plan, cuda, cuda, JCufft.CUFFT_FORWARD));
	} catch (Exception e) {
		// Release the plan, but report the execution failure rather than a cleanup failure.
		try {
			JCufft.cufftDestroy(plan);
		} catch (RuntimeException ignored) {
			// intentionally ignored: the original failure is more informative
		}
		throw e;
	}
	checkResult(JCufft.cufftDestroy(plan));
}
/**
 * Performs an inverse 3-D FFT on the CUDA device memory pointed to by deviceX.
 * The transform is executed in place (the same pointer is used as input and output) and
 * no scaling is applied here. The plan created for the transform is destroyed on both
 * the success and the failure path, so a failing execution no longer leaks the plan.
 * @param deviceX the device pointer
 * @param size the sizes of the volume.
 * @throws Exception if planning, executing or destroying the plan reports an error
 */
public void inverseTransform(Pointer deviceX, int [] size) throws Exception{
	cufftHandle plan = new cufftHandle();
	checkResult(JCufft.cufftPlan3d(plan, size[0], size[1], size[2], cufftType.CUFFT_C2C));
	try {
		checkResult(JCufft.cufftExecC2C(plan, deviceX, deviceX, JCufft.CUFFT_INVERSE));
	} catch (Exception e) {
		// Release the plan, but report the execution failure rather than a cleanup failure.
		try {
			JCufft.cufftDestroy(plan);
		} catch (RuntimeException ignored) {
			// intentionally ignored: the original failure is more informative
		}
		throw e;
	}
	checkResult(JCufft.cufftDestroy(plan));
}
/**
 * Performs an inverse 3-D FFT on the given volume after making it complex, then scales
 * the result by one over the number of voxels.
 * A CUDAVolume3D is transformed directly through its device pointer. Any other volume is
 * either copied to the device and back (nativeCopy) or flattened into a host array that
 * is handed to JCufft and unpacked again afterwards.
 * Failures are printed to standard error and not propagated; in that case the scaling
 * step is skipped.
 * @param vol the volume to transform
 */
@Override
public void inverseTransform(Volume3D vol)
{
	try{
		if (debug)
			System.out.println("CUDA vol_ifft\n");
		operator.makeComplex(vol);
		if (vol instanceof CUDAVolume3D){
			int [] fftsize = {vol.size[0], vol.size[1], vol.size[2]};
			inverseTransform(((CUDAVolume3D) vol).getDevicePointer(), fftsize);
		} else {
			CONRAD.gc();
			if(nativeCopy) {
				CUdeviceptr deviceX = CUDAUtil.allocateSpace(vol);
				try {
					CUDAUtil.moveToDevice(vol, deviceX);
					inverseTransform(deviceX, vol.size);
					CUDAUtil.fetchFromDevice(vol, deviceX);
				} finally {
					// Free the device buffer even when the copy or the transform fails.
					JCuda.cudaFree(deviceX);
				}
			} else {
				float [] cuda = toCUDAFormat(vol.data);
				executeInverseOnHostArray(cuda, vol.size);
				// Drop the old host data first so it can be collected before the new
				// array is allocated.
				vol.data = null;
				vol.data = toHostFormat(cuda, vol.size);
				cuda = null;
				CONRAD.gc();
			}
		}
		// Count voxels in long: the product of three ints can overflow for large volumes,
		// which would give a wrong (or infinite) scale factor.
		long voxelCount = (long) vol.size[0] * vol.size[1] * vol.size[2];
		operator.multiplyScalar(vol, 1.0f / (float) voxelCount, 0.0f);
	} catch (Exception e) {
		e.printStackTrace();
	}
}

/** Runs an in-place inverse complex-to-complex 3-D FFT on a flat host array, always destroying the plan. */
private static void executeInverseOnHostArray(float[] cuda, int[] size) throws Exception {
	cufftHandle plan = new cufftHandle();
	checkResult(JCufft.cufftPlan3d(plan, size[0], size[1], size[2], cufftType.CUFFT_C2C));
	try {
		checkResult(JCufft.cufftExecC2C(plan, cuda, cuda, JCufft.CUFFT_INVERSE));
	} catch (Exception e) {
		// Release the plan, but report the execution failure rather than a cleanup failure.
		try {
			JCufft.cufftDestroy(plan);
		} catch (RuntimeException ignored) {
			// intentionally ignored: the original failure is more informative
		}
		throw e;
	}
	checkResult(JCufft.cufftDestroy(plan));
}
/**
 * Releases CUDA state held for the device-copy path.
 * Does nothing unless nativeCopy is set; otherwise calls JCuda.cudaThreadExit().
 */
@Override
public void cleanUp() {
	if (!nativeCopy) {
		return;
	}
	JCuda.cudaThreadExit();
}
/**
 * Accepts a thread count but ignores it: this implementation has no thread setting to apply.
 * @param number the requested number of threads (unused)
 */
@Override
public void setThreadNumber(int number) {
// Intentionally empty: the requested thread count is not used by this handle.
}
}
/*
* Copyright (C) 2010-2014 - Andreas Maier
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/