package org.seqcode.gseutils; import java.util.*; /* from a set of values for positive and negative examples, computes ROC curve areas, returns plots, and the sensitivity and specificity values */ public class ROC { private ArrayList<SS> vals; /* creates a new ROC object with a set of positive and negative examples. The double values are whatever value the threshold can vary over */ public ROC (double[] positive, double[] negative, Comparator c) { double min, max, step; min = positive[0]; max = positive[0]; for (int i = 0; i < positive.length; i++) { if (positive[i] > max) {max = positive[i];} if (positive[i] < min) {min = positive[i];} } for (int i = 0; i < negative.length; i++) { if (negative[i] > max) {max = negative[i];} if (negative[i] < min) {min = negative[i];} } step = (max - min) / 100; vals = new ArrayList<SS>(); int i = 0; for (double threshold = min - step; threshold <= max + step; threshold += step) { int tp = 0, tn = 0, fp = 0, fn = 0; for (int j = 0; j < positive.length; j++) { if (c.compare(positive[j],threshold) == 1) { tp++; } else { fn++; } } for (int j = 0; j < negative.length; j++) { if (c.compare(negative[j],threshold) == 1){ fp++; } else { tn++; } } SS n = new SS(); vals.add(n); n.sens = tp / (tp+fn+.000001); n.spec = tn / (tn+fp+.000001); i++; } } /* 2D ROC. second dimension must have size 2. Takes two comparators. The comparator will be called on the value and the threshold and should return 1 iff the value meets the threshold */ public ROC(double[][] positive, double[][] negative, Comparator c1, Comparator c2) { double minone, mintwo, maxone, maxtwo; minone = positive[0][0]; mintwo = positive[0][1]; maxone = positive[0][0]; maxtwo = positive[0][1]; for (int i = 0; i < positive.length; i++) { if (positive[i][0] > maxone) {maxone = positive[i][0];} if (positive[i][0] < minone) {minone = positive[i][0];} if (positive[i][1] > maxtwo) {maxtwo = positive[i][1];} if (positive[i][1] < mintwo) {mintwo = positive[i][1];} } for (int i = 0; i < negative.length; i++) { if (negative[i][0] > maxone) {maxone = negative[i][0];} if (negative[i][0] < minone) {minone = negative[i][0];} if (negative[i][1] > maxtwo) {maxtwo = negative[i][1];} if (negative[i][1] < mintwo) {mintwo = negative[i][1];} } double step1 = (maxone - minone)/20; double step2 = (maxtwo - mintwo)/20; if (step2 == 0) { step2 = .01; step1 = (maxone - minone) / 400; } vals = new ArrayList<SS>(); System.err.println("Minone=" +minone +" maxone=" + maxone + " step1="+ step1); System.err.println("Mintwo=" +mintwo +" maxtwo=" + maxtwo + " step2="+ step2); for (double t1 = minone; t1 <= maxone; t1 += step1) { for (double t2 = mintwo; t2 <= maxtwo; t2 += step2) { int tp = 0, tn = 0, fp = 0, fn = 0; for (int j = 0; j < positive.length; j++) { if (c1.compare(positive[j][0],t1) == 1 && c2.compare(positive[j][1],t2) == 1) { tp++; } else { fn++; } } for (int j = 0; j < negative.length; j++) { if (c1.compare(negative[j][0],t1) == 1 && c2.compare(negative[j][1],t2) == 1) { fp++; } else { tn++; } } SS n = new SS(); vals.add(n); n.sens = tp / (tp + fn + .000001); n.spec = tn / (tn + fp + .000001); // System.err.println("At t1=" + t1 + " and t2=" + t2 + " sens=" + n.sens + " and spec="+ n.spec); } } } public double getROCArea () { double area = 0; SS vals[] = new SS[this.vals.size()]; this.vals.toArray(vals); Arrays.sort(vals); for (int i = 0 ; i < vals.length - 1; i++) { area += .5 * (vals[i].spec + vals[i+1].spec) * (vals[i+1].sens - vals[i].sens); } if (vals[0].spec > vals[vals.length-1].spec) { area += vals[0].sens * .5 * ( 1 + vals[0].spec); area += (1 - vals[vals.length-1].sens) * .5 * vals[vals.length-1].spec; } else { area += vals[vals.length-1].sens * .5 * ( 1 + vals[vals.length-1].spec); area += (1 - vals[0].sens) * .5 * vals[0].spec; } if (area < .5) { return 1 - area; } else { return area; } } } class SS implements Comparable{ public double sens, spec; public int compareTo(Object o) { double val = ((SS)o).sens; if (sens < val) { return -1; } else if (sens > val) { return 1; } else { return 0; } } }