Can you plot the gradient for CNNs using trainNetwork?

I am using the trainNetwork command to train my network, but noticed that there is no way to plot the gradients over iterations. The trainInfo output does contain some information, but does not seem to contain any information about the gradient.


Snehal on 27 Mar 2025
I understand that you want to extract gradient information while training a CNN and plot it over iterations. The ‘trainNetwork’ function in MATLAB does not expose gradients during training, but there are two possible workarounds:
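  • Replace ‘trainNetwork’ with a custom training loop, which gives you direct access to the gradients at every iteration. The required functions (‘dlnetwork’, ‘dlarray’, ‘dlfeval’, ‘dlgradient’) are available in R2019b and later. Below is a sample code snippet that computes the gradients with ‘dlgradient’: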
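% 'layers' is a layer array defined earlier; it must end in a softmaxLayer and contain no output layer (such as classificationLayer)
net = dlnetwork(layerGraph(layers));
% Assume 'XBatch' is a formatted dlarray (for example 'SSCB') and 'YBatch' holds the one-hot targets
% dlgradient can only be called inside a function that is evaluated with dlfeval
[loss, gradients] = dlfeval(@modelLoss, net, XBatch, YBatch);
% 'gradients' is a table with the same layout as net.Learnables

% Local functions must appear at the end of the script
function [loss, gradients] = modelLoss(net, X, T)
    YPred = forward(net, X);                       % forward pass
    loss = crossentropy(YPred, T);                 % cross-entropy loss
    gradients = dlgradient(loss, net.Learnables);  % gradients of the loss w.r.t. the learnable parameters
end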
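  • If you want to keep using ‘trainNetwork’, pass a custom output function through the 'OutputFcn' option of ‘trainingOptions’. The info structure passed to that function does not contain gradients, but it does report values such as ‘TrainingLoss’ and ‘ValidationLoss’ over iterations. Plotting these lets you monitor training and infer gradient-related problems indirectly, for example a loss that stalls (vanishing gradients) or diverges (exploding gradients).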
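The first workaround can be extended into a loop that records and plots the gradients themselves. The sketch below is a minimal example under stated assumptions: it needs R2019b or later, random synthetic data stands in for a real dataset, and the layer sizes, learning rate, momentum and iteration count are arbitrary illustrative values. It summarizes the gradients at each iteration by their global L2 norm (one choice among several; per-layer norms could be collected the same way) and updates the weights with ‘sgdmupdate’.

```matlab
% Minimal sketch: custom training loop that records and plots the global
% gradient L2 norm at every iteration. Assumes Deep Learning Toolbox R2019b+
% (dlnetwork, dlarray, dlfeval, dlgradient, sgdmupdate).
% Random data, layer sizes and hyperparameters are illustrative only.
rng(0);
numClasses = 3;
batchSize = 16;
layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];                  % no output layer: dlnetwork does not accept one
net = dlnetwork(layerGraph(layers));

% Loss-and-gradient function; dlgradient must run inside dlfeval
gradFcn = @(n, X, T) dlgradient(crossentropy(forward(n, X), T), n.Learnables);

numIterations = 50;
learnRate = 0.01;
momentum = 0.9;
vel = [];                           % sgdmupdate initializes the velocity when empty
gradNorms = zeros(numIterations, 1);

for iter = 1:numIterations
    % Synthetic mini-batch: random images and random one-hot targets
    X = dlarray(rand(28, 28, 1, batchSize, 'single'), 'SSCB');
    idx = randi(numClasses, 1, batchSize);
    T = zeros(numClasses, batchSize, 'single');
    T(sub2ind(size(T), idx, 1:batchSize)) = 1;
    T = dlarray(T, 'CB');

    gradients = dlfeval(gradFcn, net, X, T);

    % Global L2 norm over all learnable parameters (gradients.Value is a cell array)
    sq = cellfun(@(g) double(sum(extractdata(g).^2, 'all')), gradients.Value);
    gradNorms(iter) = sqrt(sum(sq));

    % SGD-with-momentum update of the network weights
    [net, vel] = sgdmupdate(net, gradients, vel, learnRate, momentum);
end

figure;
plot(1:numIterations, gradNorms);
xlabel('Iteration');
ylabel('Gradient L2 norm');
title('Gradient norm during training');
```

Storing one scalar per iteration keeps the plot readable; to inspect individual layers, keep the per-parameter values in ‘sq’ instead of summing them.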
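For the second workaround, a custom output function is passed to ‘trainNetwork’ through the 'OutputFcn' option of ‘trainingOptions’. The sketch below assumes random synthetic data and a small illustrative network, and the function name ‘plotLossOutputFcn’ is hypothetical. Because the info structure carries no gradients, it plots the training loss live as a proxy.

```matlab
% Sketch: live plot of the training loss through trainNetwork's OutputFcn.
% Synthetic data and layer sizes are illustrative; plotLossOutputFcn is a
% hypothetical name chosen for this example.
rng(0);
XTrain = rand(28, 28, 1, 64);               % 64 random 28x28 grayscale images
YTrain = categorical(randi(3, 64, 1));       % random labels from 3 classes
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3, 8, 'Padding', 1)
    reluLayer
    fullyConnectedLayer(3)
    softmaxLayer
    classificationLayer];
options = trainingOptions('sgdm', ...
    'MaxEpochs', 3, ...
    'MiniBatchSize', 16, ...
    'Verbose', false, ...
    'OutputFcn', @plotLossOutputFcn);
net = trainNetwork(XTrain, YTrain, layers, options);

% Local functions must appear at the end of the script
function stop = plotLossOutputFcn(info)
    % trainNetwork calls this at the start, after every iteration, and at the end
    persistent lossLine
    stop = false;                             % never request early stopping
    if strcmp(info.State, 'start')
        figure;
        lossLine = animatedline('Marker', '.');
        xlabel('Iteration');
        ylabel('Training loss');
    elseif strcmp(info.State, 'iteration')
        addpoints(lossLine, info.Iteration, double(info.TrainingLoss));
        drawnow limitrate
    end
end
```

Returning true from the output function would stop training early, which is another way to react to a loss that diverges.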
For more information, refer to the following documentation links:
Hope this helps.

Release

R2018a
