How do I create and plot a confusion matrix for my trained convolutional neural network?

Steven Mozarowski on 26 Nov 2021
Answered: yanqi liu on 2 Dec 2021
I can't seem to create a confusion matrix for the validation results of my trained convolutional neural network. Below is the code I am using. Thanks in advance for any help!
-----------------------------------------------------------------------------------
clear
rng('shuffle')

% Load the training and validation images; labels come from the subfolder names.
outputFolder = fullfile('D:\Large_grains\Training_set');
trainDigitData = imageDatastore(outputFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
outputFolder = fullfile('D:\Large_grains\Validation_set');
testDigitData = imageDatastore(outputFolder,'IncludeSubfolders',true,'LabelSource','foldernames');

% Resize to the ResNet-18 input size and convert grayscale images to RGB.
inputSize = [224 224 3];
augimdsTrain = augmentedImageDatastore(inputSize,trainDigitData,'ColorPreprocessing','gray2rgb');
augimdsValidation = augmentedImageDatastore(inputSize,testDigitData,'ColorPreprocessing','gray2rgb');

numClasses = 9;
problem2; % user script: loads ResNet-18 and defines lgraph, the layer graph passed to trainNetwork

% Validate roughly once per epoch (one epoch = floor(numTrain/miniBatchSize) iterations).
miniBatchSize = 32;
validationFrequency = floor(numel(trainDigitData.Labels)/miniBatchSize);
options = trainingOptions('sgdm',...
    'LearnRateSchedule','piecewise',...
    'LearnRateDropFactor',0.1,...
    'LearnRateDropPeriod',2,...
    'MaxEpochs',10,...
    'InitialLearnRate',0.001,...
    'MiniBatchSize',miniBatchSize,...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',validationFrequency);
convnet = trainNetwork(augimdsTrain,lgraph,options);

% Predict a label for every validation image.
YPred = classify(convnet,augimdsValidation);

% augmentedImageDatastore has no Labels property, so take the true labels
% from the underlying imageDatastore instead.
plotconfusion(testDigitData.Labels,YPred)
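As a minimal sketch of getting the numbers behind the plot, assuming the Statistics and Machine Learning Toolbox (for `confusionmat`) and R2018b or later (for `confusionchart`), the snippet below uses small made-up categorical vectors (`trueLabels`, `predLabels`, hypothetical names) in place of `testDigitData.Labels` and `YPred`:

```matlab
% Toy labels standing in for testDigitData.Labels (true) and YPred (predicted).
trueLabels = categorical({'A';'A';'B';'B';'C';'C'});
predLabels = categorical({'A';'B';'B';'B';'C';'A'});

% Rows of cm are true classes, columns are predicted classes,
% both in the class order returned in 'order'.
[cm, order] = confusionmat(trueLabels, predLabels);

% Overall accuracy: correctly classified samples lie on the diagonal.
accuracy = sum(diag(cm)) / sum(cm(:));

% Per-class recall: fraction of each true class that was predicted correctly.
recall = diag(cm) ./ sum(cm, 2);

disp(order')
disp(cm)
fprintf('Accuracy: %.3f\n', accuracy);

% Chart object alternative to plotconfusion (rows = true class).
figure;
confusionchart(trueLabels, predLabels);
```

In the real script you would pass `testDigitData.Labels` and `YPred` instead of the toy vectors. Note that `confusionmat` and `confusionchart` put the true class on the rows, whereas `plotconfusion` puts the true (Target) class on the columns.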
Steven Mozarowski on 29 Nov 2021
Thanks for your response, Shivam! I actually managed to get the script to produce a confusion matrix earlier today and was about to take this post down when I saw your comment!
To answer your questions:
lgraph is a chart that shows information (like validation accuracy, epoch, time elapsed etc.) as training progresses.
The dataset is a collection of starch grain micrographs I captured using an imaging flow cytometer. The images are organized in folders on a hard drive, with 300 training and 200 validation images per species.
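A quick worked check of what those counts imply for the script, assuming all 9 classes (`numClasses = 9`) are species with exactly 300 training and 200 validation images each. The per-class variable names are hypothetical; the other values come from the code in the question:

```matlab
numClasses    = 9;    % from the script
trainPerClass = 300;  % from the comment above (assumed identical for every class)
valPerClass   = 200;
miniBatchSize = 32;   % from the script

numTrain = numClasses * trainPerClass;   % total training images
numVal   = numClasses * valPerClass;     % total validation images

% Same formula as the script: about one validation pass per epoch.
validationFrequency = floor(numTrain / miniBatchSize);

fprintf('Training images: %d\n', numTrain);
fprintf('Validation images: %d\n', numVal);
fprintf('Validation every %d iterations\n', validationFrequency);
```

Under these assumptions the validation confusion matrix should contain 1800 samples in total, with 200 samples per true class.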
Thanks again for reaching out, I really appreciate it!
-Steven


Answers (1)

yanqi liu on 2 Dec 2021
Yes. If you want to get the numerical data behind the plot, you can use
[c,cm,ind,per] = confusion(testDigitData.Labels,YPred)
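The `confusion` function is documented to take S-by-Q matrices (S classes, Q samples) in one-of-N form, so if it rejects the categorical labels directly, a sketch like the following converts them first. It assumes R2020b or later for `onehotencode` and uses made-up toy labels (`trueLabels`, `predLabels`, hypothetical names):

```matlab
% Toy labels standing in for testDigitData.Labels and YPred.
% Both must share the same categories in the same order so rows line up.
trueLabels = categorical({'A';'A';'B';'B';'C';'C'});
predLabels = categorical({'A';'B';'B';'B';'C';'A'});

% onehotencode along dimension 2 gives Q-by-S; transpose to S-by-Q.
targets = onehotencode(trueLabels, 2, 'double')';
outputs = onehotencode(predLabels, 2, 'double')';

% c   : fraction of misclassified samples
% cm  : cm(i,j) = number of samples of true class i classified as class j
% ind : ind{i,j} = indices of those samples
% per : per-class [false-negative, false-positive, true-positive, true-negative] rates
[c, cm, ind, per] = confusion(targets, outputs);

fprintf('Misclassification rate: %.3f\n', c);
disp(cm)
```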
