package edu.stanford.rsl.tutorial.modelObserver;
import java.util.Arrays;
import edu.stanford.rsl.conrad.data.numeric.Grid2D;
import edu.stanford.rsl.conrad.numerics.SimpleMatrix;
import edu.stanford.rsl.conrad.numerics.SimpleVector;
import edu.stanford.rsl.conrad.utils.VisualizationUtil;
/**
* @author Iris Kellermann
*/
public class ROC {
/**
* Calculates and displays the ROC curve for the given test images. Calculates and displays the SNR.
* @param testImages The object test images.
* @param emptyTestImages The test images without objects.
* @param Ntest The number of test images in each category.
* @param template The template of the observer.
* @param channelMatrix The matrix with the channel images.
*/
public static void ShowROC(Grid2D[] testImages, Grid2D[] emptyTestImages, int Ntest, SimpleVector template, SimpleMatrix channelMatrix)
{
SimpleVector vals = new SimpleVector(2 * Ntest);
// get observer result values
for(int i = 0; i < Ntest; ++i)
{
vals.setElementValue(i,Observer.GetResultValue(testImages[i], template, channelMatrix));
vals.setElementValue(i+Ntest,Observer.GetResultValue(emptyTestImages[i], template, channelMatrix));
}
// compute SNR
// mean
double sumObj = 0;
double sumEmp = 0;
for(int i = 0; i < Ntest; ++i)
{
sumObj += vals.getElement(i);
sumEmp += vals.getElement(i+Ntest);
}
double meanObj = sumObj / Ntest;
double meanEmp = sumEmp / Ntest;
// variance
double sumVarObj = 0;
double sumVarEmp = 0;
for(int i = 0; i < Ntest; ++i)
{
sumVarObj += (meanObj - vals.getElement(i)) * (meanObj - vals.getElement(i));
sumVarEmp += (meanEmp - vals.getElement(i + Ntest)) * (meanEmp - vals.getElement(i + Ntest));
}
double varObj = sumVarObj / Ntest;
double varEmp = sumVarEmp / Ntest;
double SNR = (meanObj - meanEmp) / (Math.sqrt((varObj + varEmp) / 2));
System.out.println(SNR);
SimpleVector t = new SimpleVector(2 * Ntest + 2);
int[] TP = new int[2 * Ntest + 2];
int[] FP = new int[2 * Ntest + 2];
int[] TN = new int[2 * Ntest + 2];
int[] FN = new int[2 * Ntest + 2];
double[] tmp = vals.copyAsDoubleArray();
Arrays.sort(tmp);
t.setElementValue(0, vals.min() - 1);
for (int i=1; i < 2 * Ntest + 1; ++i)
{
t.setElementValue(i, tmp[i-1]);
}
t.setElementValue(2 * Ntest + 1,vals.max() + 1);
// count number of false/true-positives and negatives
for (int i = 0; i < 2 * Ntest + 2; ++i)
{
TP[i] = 0;
for (int j = 0; j < Ntest; ++j){
if (vals.getElement(j) < t.getElement(i))
TP[i]++;
else
FN[i]++;
if (vals.getElement(j + Ntest) < t.getElement(i))
FP[i]++;
else
TN[i]++;
}
}
double[] fpf = new double[2 * Ntest + 2];
double[] sens = new double[2 * Ntest + 2];
for (int i = 0; i < 2 * Ntest + 2; ++i){
fpf[i] = (double)FP[i]/(double)(FP[i]+TN[i]);
sens[i] = (double)TP[i]/(double)(TP[i]+FN[i]);
}
//create ROC plot
VisualizationUtil.createPlot(fpf,sens,"ROC","fpf","sens").show();
}
}
/*
* Copyright (C) 2010-2014 - Iris Kellermann
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/