Add a custom layer in the model while writing the custom loop

56 views (last 30 days)
Aggraj Gupta on 21 May 2021
Commented: Aggraj Gupta on 2 Jun 2021
I am trying to connect a network (say net) to the front of a pretrained network, and to add a custom layer in my net. The output of this custom layer will go through the pretrained network, and the loss will be back-propagated. To do this I wrote a custom training loop, but when I run my code an error pops up saying:
"Custom layer with backward functions are not supported".
Can you please help me with how to add a custom layer when writing a custom training loop? Any pointers on this will be appreciated. Thanks

Answers (1)

Davide Fantin
Davide Fantin on 24 May 2021
Defining the backward function is optional. If you do not specify a backward function, and the layer forward functions support dlarray objects, then the software automatically determines the backward function using automatic differentiation. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. Define a custom backward function when you want to:
  • Use a specific algorithm to compute the derivatives.
  • Use operations in the forward functions that do not support dlarray objects.
Hence, writing a backward function might not be necessary in your case.
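To make this concrete, here is a minimal sketch of a custom layer that relies on automatic differentiation instead of a backward method. The layer name, class name, and the Scale property are hypothetical; the key point is that predict uses only operations with dlarray support (here tanh and scalar multiplication), so no backward function is defined.

```matlab
% Minimal custom layer sketch: no backward method is defined, so the
% software derives gradients automatically, provided every operation
% in predict supports dlarray (tanh and .* do).
classdef myScalingLayer < nnet.layer.Layer
    properties
        Scale  % fixed scaling factor (hypothetical, non-learnable)
    end
    methods
        function layer = myScalingLayer(scale, name)
            layer.Name = name;
            layer.Scale = scale;
        end
        function Z = predict(layer, X)
            % dlarray-supported operations only; autodiff handles backward
            Z = tanh(layer.Scale .* X);
        end
    end
end
```

If predict instead called a function without dlarray support, that is exactly the case where you would need to supply a backward method yourself.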
I understand that you would like to connect two networks. To connect layers, the layerGraph API is the suggested approach (https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layergraph.html). You should:
  1. create the layerGraph that you want to attach in front (which may contain custom layers with or without backward function)
  2. extract the layerGraph from the pretrainedNet (using layerGraph function)
  3. connect the two graphs using the connectLayers function (here the doc: https://www.mathworks.com/help/deeplearning/ref/connectlayers.html)
  4. train the network with trainNetwork or with a custom training loop, depending on the network and the flexibility that you need during training.
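The steps above can be sketched as follows. All concrete names here are assumptions for illustration: the pretrained network is taken to be squeezenet (whose input layer is 'data' and first convolution is 'conv1'), the front network is a trivial convolution ending in a layer named 'front_out', and the input size matches what squeezenet expects.

```matlab
% Step 1: build the front network as a layer array (it could equally
% contain a custom layer, with or without a backward function)
frontLayers = [
    imageInputLayer([227 227 3], 'Name', 'in')
    convolution2dLayer(3, 3, 'Padding', 'same', 'Name', 'front_out')];

% Step 2: extract the layerGraph from the pretrained network and drop
% its original input layer so the two graphs can be joined
net = squeezenet;
lgraphPre = layerGraph(net);
lgraphPre = removeLayers(lgraphPre, lgraphPre.Layers(1).Name);

% Step 3: merge the graphs and connect the front network's output to
% the first remaining pretrained layer
lgraph = addLayers(lgraphPre, frontLayers);
lgraph = connectLayers(lgraph, 'front_out', 'conv1');

% Step 4: either pass lgraph to trainNetwork, or convert it to a
% dlnetwork for use inside a custom training loop:
% dlnet = dlnetwork(lgraph);
```

With the combined dlnetwork, gradients flow through both the pretrained layers and the front layers in the custom loop, so the custom layer is trained end-to-end.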
Hope this helps!
