package org.streaminer.stream.eval;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * <p>A square matrix of counters for (true label, predicted label) pairs. For statistical
 * information on a per-label basis there are methods which return instance(s) of
 * {@link TableOfConfusion}. That class has methods to calculate values like Precision, Recall, [...].</p>
 *
 * <p>Look at http://en.wikipedia.org/wiki/Confusion_matrix for detailed descriptions.</p>
 *
 * <p>Although labels can be added after construction, this should be done as rarely as possible:
 * every time a new label is registered, the array storing the label-combination-counts is copied
 * into a larger array. Prefer the constructor taking a list of labels, and pass as many labels as
 * are known in advance. If the labels are not known in advance, {@link #addLabels(java.util.List)}
 * is preferable to {@link #addLabel(java.lang.Object)} because it needs at most one array copy
 * per call.</p>
 *
 * <p>Labels are compared with {@code equals}. This class performs no internal synchronization.</p>
 *
 * @param <T> the type of the labels
 * @author Benedikt Kulmann, Lukas Kalabis
 * @see TableOfConfusion
 */
public final class ConfusionMatrix<T> {

    /**
     * The list of all distinct labels. The position of a label in this list is its row and
     * column index in {@link #confusionMatrix}.
     */
    private final List<T> labels;

    /**
     * <p>Array which stores the counts of classification instances (pairs of true and predicted labels).</p>
     *
     * <p>The first dimension represents the true labels while the second dimension represents the
     * predicted labels. It is always square with {@code labels.size()} rows and columns.</p>
     */
    private long[][] confusionMatrix;

    /**
     * Creates a new ConfusionMatrix instance with an empty list of labels. Don't use
     * this constructor if it is possible to construct a list of possible labels in advance.
     */
    public ConfusionMatrix() {
        this(new ArrayList<T>());
    }

    /**
     * Creates a new ConfusionMatrix instance.
     *
     * <p>The provided list is copied, so it may be fixed-size or immutable and later changes to it
     * do not affect this matrix. Duplicates are ignored.</p>
     *
     * @param labels Labels to maintain a label-pair-combination-counter for.
     * @throws NullPointerException if {@code labels} is {@code null}
     */
    public ConfusionMatrix(List<T> labels) {
        this.labels = new ArrayList<T>(labels.size());
        appendMissingLabels(labels);
        final int size = this.labels.size();
        this.confusionMatrix = new long[size][size];
    }

    /**
     * <p>Adds the provided label to the list of labels. If the label is new, a larger array for the
     * internal counters has to be created, so the usage of this method is expensive. Try to
     * add as many labels as possible at a time by using {@link #addLabels(java.util.List)} or,
     * which is even better, at object creation time.</p>
     *
     * <p>Duplicates are ignored.</p>
     *
     * @param additionalLabel The label to add to the internal list of labels
     */
    public void addLabel(T additionalLabel) {
        final List<T> additionalLabelAsList = new ArrayList<T>(1);
        additionalLabelAsList.add(additionalLabel);
        addLabels(additionalLabelAsList);
    }

    /**
     * <p>Adds the provided list of labels to the internal list of labels. If at least one label is
     * new, a larger array for the internal counters has to be created, so the usage of this method
     * is expensive. Try to add as many labels as possible at object creation time.</p>
     *
     * <p>Duplicates are ignored, both against the labels already known and within the provided list.
     * Existing counters are preserved.</p>
     *
     * @param additionalLabels The labels to add to the internal list of labels
     * @throws NullPointerException if {@code additionalLabels} is {@code null}
     */
    public void addLabels(List<T> additionalLabels) {
        if (!appendMissingLabels(additionalLabels)) {
            // Nothing new was registered, so the existing array still has the right size.
            return;
        }
        final int size = labels.size();
        final long[][] grown = new long[size][size];
        for (int row = 0; row < confusionMatrix.length; row++) {
            System.arraycopy(confusionMatrix[row], 0, grown[row], 0, confusionMatrix[row].length);
        }
        confusionMatrix = grown;
    }

    /**
     * Appends every candidate that is not yet contained in {@link #labels}, keeping the order of
     * first occurrence. Does not touch the counter array.
     *
     * @param candidates the labels to register
     * @return {@code true} if at least one label was appended
     */
    private boolean appendMissingLabels(List<T> candidates) {
        boolean appended = false;
        for (T candidate : candidates) {
            if (!labels.contains(candidate)) {
                labels.add(candidate);
                appended = true;
            }
        }
        return appended;
    }

    /**
     * Returns the list of labels this {@link ConfusionMatrix} maintains counters for.
     *
     * <p>The returned list is an unmodifiable view: it reflects labels added later, but it cannot be
     * used to change the labels, because that would break the mapping between labels and counters.</p>
     *
     * @return The list of labels this {@link ConfusionMatrix} maintains counters for.
     */
    public List<T> getLabels() {
        return Collections.unmodifiableList(labels);
    }

    /**
     * <p>Adds a classification instance (true and predicted label) to this {@link ConfusionMatrix}.</p>
     *
     * <p>Each label which didn't exist previously will be added within this method automatically (which is
     * expensive and to be avoided). If both labels are new, they are registered in one step, the true
     * label first.</p>
     *
     * @param truth The true label
     * @param prediction The predicted label
     */
    public void add(T truth, T prediction) {
        int indexOfTruth = labels.indexOf(truth);
        int indexOfPrediction = labels.indexOf(prediction);
        if (indexOfTruth == -1 || indexOfPrediction == -1) {
            // Register both at once so the counter array is reallocated at most once.
            final List<T> unseen = new ArrayList<T>(2);
            unseen.add(truth);
            unseen.add(prediction);
            addLabels(unseen);
            indexOfTruth = labels.indexOf(truth);
            indexOfPrediction = labels.indexOf(prediction);
        }
        confusionMatrix[indexOfTruth][indexOfPrediction]++;
    }

    /**
     * Returns a map which contains a {@link TableOfConfusion} per label.
     *
     * @return a map of {@link TableOfConfusion} instances
     * @see #getTableOfConfusion(java.lang.Object)
     */
    public Map<T, TableOfConfusion> getTablesOfConfusion() {
        final Map<T, TableOfConfusion> tablesOfConfusion = new HashMap<T, TableOfConfusion>();
        for (T label : labels) {
            tablesOfConfusion.put(label, getTableOfConfusion(label));
        }
        return tablesOfConfusion;
    }

    /**
     * Constructs and returns the {@link TableOfConfusion} for the provided label.
     *
     * <p>With rows as true labels and columns as predicted labels, the counts for a label are:</p>
     * <ul>
     * <li>true positive: the diagonal cell of the label</li>
     * <li>false positive: the label's column without the diagonal (predicted as the label, but actually another one)</li>
     * <li>false negative: the label's row without the diagonal (actually the label, but predicted as another one)</li>
     * <li>true negative: every cell outside the label's row and column</li>
     * </ul>
     *
     * @param label The label to construct a {@link TableOfConfusion} for
     * @return A {@link TableOfConfusion} instance for the provided label
     * @throws ArrayIndexOutOfBoundsException if the label is not one of {@link #getLabels()}
     *
     * @see TableOfConfusion
     */
    public TableOfConfusion getTableOfConfusion(T label) {
        final int index = labels.indexOf(label);
        final long truePositives = confusionMatrix[index][index];
        final long actualCount = rowSum(index);
        final long predictedCount = columnSum(index);

        final TableOfConfusion tableOfConfusion = new TableOfConfusion();
        tableOfConfusion.addTruePositive(truePositives);
        // The diagonal cell is part of both the row and the column, so it is added back once.
        tableOfConfusion.addTrueNegative(totalCount() - actualCount - predictedCount + truePositives);
        tableOfConfusion.addFalsePositive(predictedCount - truePositives);
        tableOfConfusion.addFalseNegative(actualCount - truePositives);
        return tableOfConfusion;
    }

    /**
     * Calculates and returns the overall accuracy for this confusion matrix in range [0,1].
     *
     * @return the overall accuracy for this confusion matrix in range [0,1], or
     *         {@link Double#NaN} if no classification instance has been added yet
     */
    public double calculateAccuracy() {
        long correct = 0L;
        for (int i = 0; i < confusionMatrix.length; i++) {
            correct += confusionMatrix[i][i];
        }
        final long total = totalCount();
        if (total == 0L) {
            return Double.NaN;
        }
        return (double) correct / total;
    }

    /**
     * Returns the class size as a weight for per-class-calculations, i.e. the number of added
     * classification instances whose true label is the provided label, regardless of what was
     * predicted for them.
     *
     * @param label The label to get the weight for
     * @return the class size as a weight for per-class-calculations.
     * @throws ArrayIndexOutOfBoundsException if the label is not one of {@link #getLabels()}
     *         (an empty matrix yields 0 instead)
     */
    public long getWeightForLabel(T label) {
        return rowSum(labels.indexOf(label));
    }

    /**
     * Sums one row, i.e. all instances whose true label has the given index.
     *
     * @param row index of the true label
     * @return the sum of all counters in that row
     */
    private long rowSum(int row) {
        long sum = 0L;
        for (int column = 0; column < confusionMatrix.length; column++) {
            sum += confusionMatrix[row][column];
        }
        return sum;
    }

    /**
     * Sums one column, i.e. all instances whose predicted label has the given index.
     *
     * @param column index of the predicted label
     * @return the sum of all counters in that column
     */
    private long columnSum(int column) {
        long sum = 0L;
        for (int row = 0; row < confusionMatrix.length; row++) {
            sum += confusionMatrix[row][column];
        }
        return sum;
    }

    /**
     * Sums every counter of the matrix.
     *
     * @return the number of classification instances added so far
     */
    private long totalCount() {
        long sum = 0L;
        for (int row = 0; row < confusionMatrix.length; row++) {
            sum += rowSum(row);
        }
        return sum;
    }

    /**
     * Returns a multi-line text representation: one row of counters per true label, followed by
     * the {@link TableOfConfusion} of every label.
     *
     * @return a text representation of this confusion matrix
     */
    @Override
    public String toString() {
        final String lineSeparator = System.getProperty("line.separator");
        final StringBuilder sb = new StringBuilder("ConfusionMatrix (rows=truth,columns=prediction)")
                .append(lineSeparator).append("values:").append(lineSeparator);
        for (int i = 0; i < labels.size(); i++) {
            sb.append(labels.get(i));
            for (int j = 0; j < labels.size(); j++) {
                sb.append(" ").append(confusionMatrix[i][j]);
            }
            sb.append(lineSeparator);
        }
        sb.append(lineSeparator).append("results:").append(lineSeparator);
        for (T label : labels) {
            sb.append(label).append(lineSeparator).append(getTableOfConfusion(label));
        }
        return sb.toString();
    }

    /**
     * Renders this confusion matrix as an HTML table with one row per true label and one column
     * per predicted label. The last cell of each row shows the precision of that row's label
     * (true positives divided by all instances predicted as the label), formatted as a percentage.
     * Label texts are HTML-escaped.
     *
     * @return the HTML markup of the table
     */
    public String toHtml() {
        final int size = labels.size();
        final StringBuilder html = new StringBuilder("<table class=\"confusionMatrix\">");
        html.append("<tr>");
        html.append("<td colspan=\"2\" rowspan=\"2\" style=\"border: none;\"></td><th colspan=\"")
                .append(size).append("\">prediction</th>");
        html.append("</tr>");
        html.append("<tr>");
        for (T label : labels) {
            html.append("<th>").append(escapeHtml(label)).append("</th>");
        }
        html.append("<th>Precision</th>");
        html.append("</tr>");

        final DecimalFormat percentFormat = new DecimalFormat("0.00 %");
        for (int row = 0; row < size; row++) {
            html.append("<tr>");
            if (row == 0) {
                html.append("<th rowspan=\"").append(size).append("\">true</th>");
            }
            html.append("<th>").append(escapeHtml(labels.get(row))).append("</th>");
            for (int column = 0; column < size; column++) {
                html.append(" <td>").append(confusionMatrix[row][column]).append("</td>");
            }
            // Precision = TP / (TP + FP) = diagonal cell / column sum. Floating-point division
            // yields NaN when nothing was predicted as this label.
            final double truePositives = confusionMatrix[row][row];
            final double predictedAsLabel = columnSum(row);
            html.append("<td><nobr>").append(percentFormat.format(truePositives / predictedAsLabel))
                    .append("</nobr></td>");
            html.append("</tr>\n");
        }
        html.append("</table>");
        return html.toString();
    }

    /**
     * Converts a value to text and replaces the characters that have a special meaning in HTML
     * (ampersand, angle brackets, double quote) with character entities. {@code null} is rendered
     * as the text "null".
     *
     * @param value the value to render, may be {@code null}
     * @return the escaped text
     */
    private static String escapeHtml(Object value) {
        final String text = String.valueOf(value);
        final StringBuilder escaped = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); i++) {
            final char c = text.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                case '"':
                    escaped.append("&quot;");
                    break;
                default:
                    escaped.append(c);
                    break;
            }
        }
        return escaped.toString();
    }
}