MATLAB Answers

Custom deep learning training loop: gradient computation using dlgradient

Niko Picello on 13 May 2021
Commented: Niko Picello on 14 May 2021
I'm trying to train a CNN with semi-supervised learning, but I can't compute the automatic gradient correctly. When I call dlgradient (with the loss and net.Learnables as arguments), it calls several internal functions, and the program fails at backwardTape, the method that (through further nested functions) actually computes the gradient. backwardTape appears to be skipped: it does return an output grad, but if I try to step into it with the debugger, I can't, and execution jumps to the next line instead. The line is:
grad = backwardTape(tm,{y},{initialAdjoint},x,retainData,false);
in backwardPass.m of Deep Learning Toolbox. The output grad is just a vector of empty arrays.
P.S. The dlnetwork I created is based on AlexNet, using transfer learning.
The relevant part of my code is:
loss = labeledLoss + unlabeledLoss; % these two statements are inside a training loop
gradients = dlfeval(@computeModelGradients,net,loss);

function gradients = computeModelGradients(network,loss)
gradients = dlgradient(loss,network.Learnables);
% studentNet is a 1x1 dlnetwork with 24 layers (22 taken from AlexNet,
% plus a fully connected layer and a softmax layer)
% loss is a 1x1 dlarray containing a double
end
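For reference, here is a sketch of one way to build a dlnetwork with the 22 + 2 layer structure described in the comments. It assumes the Deep Learning Toolbox Model for AlexNet Network support package is installed; numClasses and the new layer names are hypothetical, and this is not necessarily how the network in the question was built.

```matlab
% Hypothetical reconstruction of a 24-layer AlexNet-based dlnetwork.
% Requires the Deep Learning Toolbox Model for AlexNet Network support package.
numClasses = 10;                     % assumption: set to your number of classes

baseNet = alexnet;                   % SeriesNetwork with 25 layers
layers = baseNet.Layers(1:end-3);    % keep the first 22 layers (drop fc8, prob, output)

layers = [
    layers
    fullyConnectedLayer(numClasses, 'Name', 'fc_new')   % hypothetical name
    softmaxLayer('Name', 'softmax')                      % hypothetical name
    ];

% A dlnetwork has no classification output layer, so none is added back.
studentNet = dlnetwork(layerGraph(layers));
```

The final classification layer is left out because dlnetwork does not accept output layers; in a custom training loop the loss is computed explicitly instead.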

Answers (1)

Mohamed Marei on 14 May 2021
I think I ran into a similar problem when training a ResNet-18-based model with transfer learning. I had to write the evaluation and update steps myself, which was by no means straightforward.
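In your case, you should compute the loss inside the function that you pass to dlfeval, not before calling it. dlgradient can only differentiate operations that are traced while dlfeval is running. Because labeledLoss + unlabeledLoss is evaluated outside dlfeval, nothing connects loss to network.Learnables, which would explain why the gradients come back empty. The same applies to the forward passes that produce the predictions, so they must be evaluated inside that function as well.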
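To show why the placement matters, here is a minimal sketch with a scalar instead of a network. The function name squareAndGradient is hypothetical. The operation x.^2 runs inside the function evaluated by dlfeval, so it is recorded and dlgradient can differentiate it.

```matlab
% Minimal illustration of tracing with dlfeval and dlgradient.
x = dlarray(3);

% y = x.^2 is computed inside the function called by dlfeval,
% so the operation is traced and its gradient is available.
[y, dydx] = dlfeval(@squareAndGradient, x);
disp(extractdata(dydx))   % 6, because d(x^2)/dx = 2*x = 6 at x = 3

function [y, dydx] = squareAndGradient(x)
y = x.^2;
dydx = dlgradient(y, x);
end
```

If y were computed before the dlfeval call and then passed in, there would be no recorded path from x to y to differentiate.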
function [loss, gradients] = computeModelGradients(network, X_labelled, T_labelled, X_unlabelled)
pred_labelled = forward(network, X_labelled);     % forward passes must run inside dlfeval
pred_unlabelled = forward(network, X_unlabelled);
labelled_loss = crossentropy(pred_labelled, T_labelled); % your loss definition here
unlabelled_loss = myfunction(pred_unlabelled); % your loss function for the unlabeled predictions
loss = labelled_loss + unlabelled_loss;
gradients = dlgradient(loss, network.Learnables); % differentiate w.r.t. the learnables table
end
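Below is a self-contained sketch of how such a gradient function can be called from a training loop. Everything here is an assumption for illustration: a tiny fully connected network with random data stands in for the AlexNet-based model, the names modelGradients, XL, TL and XU are hypothetical, and the entropy term is only a placeholder for whatever unlabelled loss myfunction implements.

```matlab
% Toy semi-supervised training loop (all sizes and data are assumptions).
rng(0);
numFeatures = 4;
numClasses = 3;

layers = [
    featureInputLayer(numFeatures, 'Name', 'in')
    fullyConnectedLayer(numClasses, 'Name', 'fc')
    softmaxLayer('Name', 'sm')];
net = dlnetwork(layerGraph(layers));

% Labelled batch with one-hot targets, plus an unlabelled batch ('CB' = channel x batch)
labels = randi(numClasses, 1, 8);
T = zeros(numClasses, 8);
T(sub2ind(size(T), labels, 1:8)) = 1;
XL = dlarray(randn(numFeatures, 8), 'CB');
TL = dlarray(T, 'CB');
XU = dlarray(randn(numFeatures, 16), 'CB');

avgGrad = [];
avgSqGrad = [];
learnRate = 0.01;
for iteration = 1:5
    % Forward passes, loss and gradients are all evaluated inside dlfeval
    [loss, gradients] = dlfeval(@modelGradients, net, XL, TL, XU);
    [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, ...
        avgGrad, avgSqGrad, iteration, learnRate);
end
fprintf('Loss after %d iterations: %.4f\n', iteration, extractdata(loss));

function [loss, gradients] = modelGradients(net, XL, TL, XU)
YL = forward(net, XL);
YU = stripdims(forward(net, XU));   % unformatted numClasses-by-16 dlarray
labelledLoss = crossentropy(YL, TL);
% Hypothetical unlabelled loss: mean entropy of the unlabelled predictions
unlabelledLoss = mean(-sum(YU .* log(YU + eps), 1));
loss = labelledLoss + unlabelledLoss;
gradients = dlgradient(loss, net.Learnables);
end
```

The key design point is that the training loop only passes data into dlfeval; the network evaluation, the combined loss and the call to dlgradient all happen inside modelGradients, so the whole computation is traced.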
