
confusion

Classification confusion matrix

Description

Tip

To plot a confusion chart for a deep learning workflow, use the confusionchart function.

[c,cm,ind,per] = confusion(targets,outputs) takes a target matrix, targets, and an output matrix, outputs. It returns four values:

- the confusion value, c;
- the confusion matrix, cm;
- a cell array, ind, whose element ind{i,j} contains the indices of samples with target class i that were classified as class j;
- a matrix of percentages, per, in which row i summarizes four percentages associated with the i-th class.

Examples

This example shows how to generate the confusion matrix of the simpleclass_dataset dataset using the confusion function.

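Load the simpleclass_dataset data set. Then define a pattern recognition network with 10 hidden neurons, train it, and compute the confusion values from the targets and the network outputs.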

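[x,t] = simpleclass_dataset;        % inputs x and one-hot targets t
net = patternnet(10);               % pattern recognition network, 10 hidden neurons
net = train(net,x,t);               % train the network on the data set
y = net(x);                         % network outputs for the inputs
[c,cm,ind,per] = confusion(t,y)     % no semicolon, so the results are displayed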

Input Arguments

targets

Matrix of targets, specified as an S-by-Q matrix, where S is the number of classes and Q is the number of samples. Each column contains a single element equal to 1, and all other elements are equal to 0. The row index of the 1 indicates which of the S categories that sample belongs to.
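The following base-MATLAB sketch builds a targets matrix in this format from a vector of class indices. The labels vector and the class count S are hypothetical example values and are not part of the confusion interface.

```matlab
% Hypothetical class indices for 6 samples drawn from S = 3 classes
labels = [1 1 2 2 3 3];
S = 3;

% Column k of the S-by-S identity matrix is the one-hot vector for class k,
% so indexing its columns with labels gives an S-by-Q targets matrix
I = eye(S);
targets = I(:, labels);

% Each column holds exactly one 1; its row index is the class of that sample
[~, recovered] = max(targets, [], 1);
disp(targets)
```

Indexing columns of eye(S) avoids an explicit loop and yields a full (not sparse) double matrix.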

outputs

Matrix of outputs, specified as an S-by-Q matrix, where each column contains values in the range [0,1]. The row index of the largest element in a column indicates which of the S categories that sample is assigned to.
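In the following sketch, max reads off the category that each column of outputs represents. The outputs values are hypothetical. Because max returns the index of the first maximum, a tie in this sketch resolves to the lower-numbered class.

```matlab
% Hypothetical network outputs for S = 3 classes and 6 samples
outputs = [0.8 0.3 0.2 0.1 0.2 0.5
           0.1 0.6 0.7 0.5 0.2 0.1
           0.1 0.1 0.1 0.4 0.6 0.4];

% The assigned class of each sample is the row index of its largest output
[~, predicted] = max(outputs, [], 1);
disp(predicted)   % 1 2 2 2 3 1
```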

Output Arguments

c

Confusion value, that is, the fraction of misclassified samples, returned as a scalar.
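The following sketch computes the fraction of misclassified samples directly. It converts hypothetical one-hot targets and network outputs to class indices, then counts mismatches. This sketch illustrates the definition and is not the toolbox implementation.

```matlab
% Hypothetical one-hot targets (3 classes, 6 samples) and network outputs
targets = [1 1 0 0 0 0
           0 0 1 1 0 0
           0 0 0 0 1 1];
outputs = [0.8 0.3 0.2 0.1 0.2 0.5
           0.1 0.6 0.7 0.5 0.2 0.1
           0.1 0.1 0.1 0.4 0.6 0.4];

% Convert each column to a class index: the row holding its largest value
[~, targetClass] = max(targets, [], 1);
[~, outputClass] = max(outputs, [], 1);

% Fraction of samples whose output class differs from the target class
c = mean(outputClass ~= targetClass);
disp(c)   % 0.3333 (samples 2 and 6 are misclassified)
```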

cm

Confusion matrix, returned as an S-by-S matrix, where cm(i,j) is the number of samples whose target is the i-th class and that were classified as the j-th class.
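The following sketch builds a matrix with this layout using accumarray and per-sample class indices. The index vectors are hypothetical. In practice they would come from the row index of the largest element in each column of targets and outputs.

```matlab
S = 3;
% Hypothetical target and output class indices for 6 samples
targetClass = [1 1 2 2 3 3];
outputClass = [1 2 2 2 3 1];

% accumarray adds 1 at position (targetClass(q), outputClass(q)) for every
% sample q, so cm(i,j) counts class-i targets classified as class j
cm = accumarray([targetClass(:) outputClass(:)], 1, [S S]);
disp(cm)
```

Here the row sums of cm give the number of samples in each target class. The expression 1 - trace(cm)/sum(cm(:)) gives the fraction of misclassified samples, which is 1/3 for this data.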

ind

Sample indices, returned as an S-by-S cell array, where ind{i,j} contains the indices of samples whose target is the i-th class but that were classified as the j-th class.
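The following sketch fills a cell array with the same meaning using find, based on hypothetical target and output class indices. The orientation of each index vector (row or column) is an assumption of this sketch.

```matlab
S = 3;
% Hypothetical target and output class indices for 6 samples
targetClass = [1 1 2 2 3 3];
outputClass = [1 2 2 2 3 1];

ind = cell(S, S);
for i = 1:S
    for j = 1:S
        % Indices of samples whose target class is i and output class is j
        ind{i,j} = find(targetClass == i & outputClass == j);
    end
end
```

Each ind{i,j} lists exactly the samples counted in cm(i,j), so cellfun(@numel, ind) recovers the confusion matrix.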

per

Matrix of percentages, returned as an S-by-4 matrix, where row i summarizes four percentages associated with the i-th class:

per(i,1) false negative rate
          = (false negatives)/(all output negatives)
per(i,2) false positive rate
          = (false positives)/(all output positives)
per(i,3) true positive rate
          = (true positives)/(all output positives)
per(i,4) true negative rate
          = (true negatives)/(all output negatives)
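The rates above are stated in terms of true and false positives and negatives for one class at a time. The following sketch derives those one-vs-rest counts for a single class k from a hypothetical confusion matrix. It only illustrates the counts and is not a reimplementation of per.

```matlab
% Hypothetical 3-class confusion matrix: rows are target classes,
% columns are output classes
cm = [1 1 0
      0 2 0
      1 0 1];
k = 1;                           % class treated as "positive"

TP = cm(k,k);                    % class-k targets classified as class k
FN = sum(cm(k,:)) - TP;          % class-k targets classified as another class
FP = sum(cm(:,k)) - TP;          % other targets classified as class k
TN = sum(cm(:)) - TP - FN - FP;  % other targets classified as another class
fprintf('TP=%d FN=%d FP=%d TN=%d\n', TP, FN, FP, TN)   % TP=1 FN=1 FP=1 TN=3
```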

Introduced in R2006a