The example of Train Network Using Federated Learning that is given in MathWorks documentation is not working
I am trying to run the Train Network Using Federated Learning example from the MathWorks documentation ( https://in.mathworks.com/help/deeplearning/ug/train-network-using-federated-learning.html ), but I am getting the following error:
The error says Undefined function 'preprocessMiniBatch', even though I have included the 'preprocessMiniBatch' function as given in the MathWorks documentation.
Every time I run the code I get this error, and I cannot work out where I am making a mistake. I am using MATLAB R2023a on a CPU with 8 GB of RAM. I have been looking for a solution for a month now. Could someone please help me solve this problem? I would be grateful.
2 Comments
Walter Roberson
on 25 Dec 2023
If you are not already doing so, try putting preprocessMiniBatch into its own .m file
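A minimal sketch of that suggestion, assuming the function body matches the one listed in the documentation example (it may differ between releases, so compare it with your copy): save the following as preprocessMiniBatch.m in the same folder as the script, or in a folder on the MATLAB path. It requires Deep Learning Toolbox for 'onehotencode'.

```matlab
function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)
% Contents of preprocessMiniBatch.m (assumed to mirror the documentation example).

% Concatenate the images along the fourth (batch) dimension.
X = cat(4,XCell{1:end});

% Extract the label data from the cell array and concatenate it.
Y = cat(2,YCell{1:end});

% One-hot encode the labels along the first dimension.
Y = onehotencode(Y,1,'ClassNames',classes);
end
```

Once the file is saved, running "which preprocessMiniBatch" at the command line should print the path to that file, which confirms MATLAB can find it.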
Answers (1)
Harsha Vardhan
on 5 Jan 2024
Hi Debojit Sharma,
I understand that you faced an error while using the Federated Learning example. It appears that you were able to resolve this issue following a comment from the community. Later, you wanted to plot a confusion matrix for this example.
A confusion matrix can be computed using the 'confusionmat' function. Please check the relevant documentation here - https://www.mathworks.com/help/stats/confusionmat.html
To integrate confusion matrix computation for the training and testing phases into your existing federated learning code, you can collect the predictions and actual labels from the global model for both the training and testing datasets and then use these to compute the confusion matrices. There are other possible ways of calculating the confusion matrix too.
You can check the code modifications below.
Just as datastores were created for the test and validation data, create a datastore for the training data as shown below.
fileList = [];
labelList = [];
for i = 1:numWorkers
    tmp = imdsTestVal{i};
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);
end
imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;
[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");
augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);
%% Code for creating a datastore for training data
fileList = [];
labelList = [];
for i = 1:numWorkers
    tmp = imdsTrain{i};
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);
end
imdsGlobalTrainVal = imageDatastore(fileList);
imdsGlobalTrainVal.Labels = labelList;
augimdsGlobalTrain = augmentedImageDatastore(inputSize(1:2),imdsGlobalTrainVal);
Similarly, create a 'minibatchqueue' object for training data.
mbqGlobalVal = minibatchqueue(augimdsGlobalVal, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat=["SSCB",""]);
%Code for creating a minibatchqueue for training data
mbqGlobalTrain = minibatchqueue(augimdsGlobalTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat=["SSCB",""]);
After calculating the accuracy, you can plot the confusion matrices for all the training and testing data as below.
accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes);
%Code for displaying training confusion matrix
trainConfusionMat = createConfusionMatrix(globalModel, mbqGlobalTrain, classes);
figure;
confusionchart(trainConfusionMat);
title('Training Confusion Matrix');
%Code for displaying testing confusion matrix
testConfusionMat = createConfusionMatrix(globalModel, mbqGlobalTest, classes);
figure;
confusionchart(testConfusionMat);
title('Testing Confusion Matrix');
The function below creates a confusion matrix using the 'confusionmat' MATLAB function.
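%function for calculating Confusion Matrix
function confusionMat = createConfusionMatrix(net, mbq, classes)
    allYPred = [];
    allTTest = [];
    % shuffle also resets the queue, so it is read from the start
    shuffle(mbq);
    while hasdata(mbq)
        [XTest, TTest] = next(mbq);
        % Convert one-hot encoded targets to a column of categorical labels
        TTest = onehotdecode(TTest, classes, 1)';
        YPred = predict(net, XTest);
        % Convert predicted scores to a column of categorical labels
        YPred = onehotdecode(YPred, classes, 1)';
        allTTest = [allTTest; TTest];
        allYPred = [allYPred; YPred];
    end
    % Rows correspond to true classes, columns to predicted classes
    confusionMat = confusionmat(categorical(allTTest), categorical(allYPred));
end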
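As a quick sanity check, the overall accuracy derived from a confusion matrix should agree with the value returned by 'computeAccuracy' for the same data. The sketch below uses a made-up 3-class matrix C (hypothetical numbers, not results from this example) to show how accuracy and per-class recall and precision can be read off a matrix laid out like the output of 'confusionmat', with rows as true classes and columns as predicted classes.

```matlab
% Hypothetical confusion matrix: rows = true class, columns = predicted class.
C = [50  2  1;
      3 45  4;
      0  5 40];

% Overall accuracy: correctly classified samples (the diagonal) over all samples.
overallAccuracy = trace(C) / sum(C, "all");

% Recall per class: correct predictions divided by the number of true samples (row sums).
recallPerClass = diag(C) ./ sum(C, 2);

% Precision per class: correct predictions divided by the number of predictions (column sums).
precisionPerClass = diag(C) ./ sum(C, 1)';
```

In your code, you would replace C with trainConfusionMat or testConfusionMat.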
Hope this helps in resolving your query!