package edu.stanford.rsl.tutorial.iterative;
import edu.stanford.rsl.conrad.data.numeric.Grid3D;
/**
* @author Mario Amrehn
*
*/
public final class GridOp {
private GridOp() {
}
// a = a+b
public static Grid3D addInPlace(Grid3D sumA, Grid3D sumB) throws Exception {
sumA.getGridOperator().addBy(sumA, sumB);
return sumA;
}
// a = a + scalar
public static Grid3D addInPlace(Grid3D sumA, float sumB) {
sumA.getGridOperator().addBy(sumA, sumB);
return sumA;
}
// c = a+b
public static Grid3D add(Grid3D sumA, Grid3D sumB) throws Exception {
return add(sumA, sumB, 0, 0, 0);
}
// c = a+b
public static Grid3D add(Grid3D sumA, Grid3D sumB, int aXOff, int aYOff, int aZOff) throws Exception {
int[] sA = sumA.getSize();
int[] sB = sumB.getSize();
if (sA[0] != sB[0] || sA[1] != sB[1] || sA[2] != sB[2])
throw new Exception("GridOperation: Sizes of grids don't match");
Grid3D res = new Grid3D(sumA);
for (int x = aXOff; x < sA[0]+aXOff; ++x)
for (int y = aYOff; y < sA[1]+aYOff; ++y)
for (int z = aZOff; z < sA[2]+aZOff; ++z){
int xIdx = (x >= sA[0] || x < 0) ? Math.min(Math.max(0, x), sA[0]-1) : x;
int yIdx = (y >= sA[1] || y < 0) ? Math.min(Math.max(0, y), sA[1]-1) : y;
int zIdx = (z >= sA[2] || z < 0) ? Math.min(Math.max(0, z), sA[2]-1) : z;
res.setAtIndex(x, y, z, sumA.getAtIndex(xIdx,yIdx,zIdx) + sumB.getAtIndex(x,y,z));
}
res.setSpacing(sumA.getSpacing());
res.setOrigin(sumA.getOrigin());
return res;
}
// c = a + scalar
public static Grid3D add(Grid3D sumA, float sumB) throws Exception {
Grid3D res = new Grid3D(sumA);
res.getGridOperator().addBy(res, sumB);
res.setSpacing(sumA.getSpacing());
res.setOrigin(sumA.getOrigin());
return res;
}
// c = a-b
public static Grid3D sub(Grid3D min, Grid3D sub) throws Exception {
return sub(min, sub, 0, 0, 0, true);
}
// c = a'-b
public static Grid3D sub(Grid3D min, Grid3D sub, int aXOff, int aYOff, int aZOff, final boolean offsetLeft) throws Exception {
int[] sA = min.getSize();
int[] sB = sub.getSize();
if (sA[0] != sB[0] || sA[1] != sB[1] || sA[2] != sB[2])
throw new Exception("GridOperation: Sizes of grids don't match");
Grid3D res = new Grid3D(min);
for (int x = aXOff; x < sA[0]+aXOff; ++x)
for (int y = aYOff; y < sA[1]+aYOff; ++y)
for (int z = aZOff; z < sA[2]+aZOff; ++z){
int xIdx = (x >= sA[0] || x < 0) ? Math.min(Math.max(0, x), sA[0]-1) : x;
int yIdx = (y >= sA[1] || y < 0) ? Math.min(Math.max(0, y), sA[1]-1) : y;
int zIdx = (z >= sA[2] || z < 0) ? Math.min(Math.max(0, z), sA[2]-1) : z;
if(offsetLeft)
res.setAtIndex(x-aXOff,y-aYOff,z-aZOff, min.getAtIndex(xIdx,yIdx,zIdx) - sub.getAtIndex(x-aXOff,y-aYOff,z-aZOff));
else
res.setAtIndex(x-aXOff,y-aYOff,z-aZOff, min.getAtIndex(x-aXOff,y-aYOff,z-aZOff) - sub.getAtIndex(xIdx,yIdx,zIdx));
}
res.setSpacing(min.getSpacing());
res.setOrigin(min.getOrigin());
return res;
}
// a = a-b
public static Grid3D subInPlace(Grid3D min, Grid3D sub) throws Exception {
min.getGridOperator().subtractBy(min, sub);
return min;
}
// // a = a-b'
// public static void subInPlaceForDivergence(Grid3D min, Grid3D sub, int aXOff, int aYOff, int aZOff) throws Exception {
// int[] sA = min.getSize();
// int[] sB = sub.getSize();
// if (sA[0] != sB[0] || sA[1] != sB[1] || sA[2] != sB[2])
// throw new Exception("GridOperation: Sizes of grids don't match");
//
// float[][][] a = min.getBuffer();
// float[][][] b = sub.getBuffer();
// for (int x = aXOff; x < sA[0]+aXOff; ++x)
// for (int y = aYOff; y < sA[1]+aYOff; ++y)
// for (int z = aZOff; z < sA[2]+aZOff; ++z){
// int xIdx = (x >= sA[0] || x < 0) ? Math.min(Math.max(0, x), sA[0]-1) : x;
// int yIdx = (y >= sA[1] || y < 0) ? Math.min(Math.max(0, y), sA[1]-1) : y;
// int zIdx = (z >= sA[2] || z < 0) ? Math.min(Math.max(0, z), sA[2]-1) : z;
// if(0 != aXOff){
// if(0 == xIdx) continue;
// if(sA[0]-1 == xIdx){
// a[x-aXOff][y-aYOff][z-aZOff]
// }
// }
// a[x-aXOff][y-aYOff][z-aZOff] -= b[xIdx][yIdx][zIdx];
// }
// }
// c = a/b (per element)
public static Grid3D div(Grid3D divid, Grid3D divis) throws Exception {
Grid3D res = new Grid3D(divid);
res.getGridOperator().divideBy(res, divis);
res.setSpacing(divid.getSpacing());
res.setOrigin(divid.getOrigin());
return res;
}
public static Grid3D divInPlace(Grid3D divid, Grid3D divis) throws Exception {
divid.getGridOperator().divideBy(divid, divis);
return divid;
}
// c = a * scalar
public static Grid3D mul(Grid3D multA, float multB) {
Grid3D res = new Grid3D(multA);
res.getGridOperator().multiplyBySave(res, multB);
res.setSpacing(multA.getSpacing());
res.setOrigin(multA.getOrigin());
return res;
}
// c = a*b (per element)
public static Grid3D mul(Grid3D multA, Grid3D multB) throws Exception {
Grid3D res = new Grid3D(multA);
res.getGridOperator().multiplyBySave(res, multB);
res.setSpacing(multA.getSpacing());
res.setOrigin(multA.getOrigin());
return res;
}
// a = a*b
public static Grid3D mulInPlace(Grid3D multA, Grid3D multB) throws Exception {
multA.getGridOperator().multiplyBySave(multA, multB);
return multA;
}
// a = a * scalar
public static Grid3D mulInPlace(Grid3D multA, float multB) throws Exception {
multA.getGridOperator().multiplyBySave(multA, multB);
return multA;
}
public static double rmse(Grid3D a, Grid3D b) throws Exception {
return a.getGridOperator().rmse(a, b);
}
public static double l1Norm(Grid3D c) {
return c.getGridOperator().normL1(c);
}
/**
* @param a
* @return minimum value of grid a
*/
public static double min(Grid3D a) {
return a.getGridOperator().min(a);
}
/**
* @param a
* @return maximum value of grid a
*/
public static double max(Grid3D a) {
return a.getGridOperator().max(a);
}
/**
* @param a
* @return element wise square root values of grid a
*/
public static Grid3D sqrtInPlace(Grid3D a) {
a.getGridOperator().pow(a, 2);
return a;
}
public static int numNeg(Grid3D a) {
return a.getGridOperator().countNegativeElements(a);
}
public static double sum(Grid3D c) {
return c.getGridOperator().sumSave(c);
}
///////////////////
// Readability improvements
///////////////////
// a = a^2
public static Grid3D square(Grid3D a) throws Exception {
return mul(a, a);
}
// d = a+b+c
public static Grid3D add(Grid3D a, Grid3D b, Grid3D c) throws Exception {
Grid3D res = add(a, b);
addInPlace(res, c);
return res;
}
// d = a+b+c + s
public static Grid3D add(Grid3D a, Grid3D b, Grid3D c, float s)
throws Exception {
Grid3D res = add(a, b);
addInPlace(res, c);
addInPlace(res, s);
return res;
}
}