Add a custom layer to a model when writing a custom training loop

I am trying to connect a network (say net) to the front of a pretrained network, and I want to add a custom layer to net. The output of this custom layer will pass through the pretrained network, and the loss will be back-propagated. To do this I wrote a custom training loop. But when I run my code, an error pops up saying
"Custom layer with backward functions are not supported" .
Can you please help me with how to add the custom layer when writing the custom training loop? Any pointers on this will be appreciated. Thanks

Answers (1)

Davide Fantin 2021-5-24
Defining the backward function is optional. If you do not specify a backward function, and the layer forward functions support dlarray objects, then the software automatically determines the backward function using automatic differentiation. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. Define a custom backward function when you want to:
  • Use a specific algorithm to compute the derivatives.
  • Use operations in the forward functions that do not support dlarray objects.
Hence, writing a backward function might not be necessary in your case.
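As a minimal sketch of a layer that relies on automatic differentiation, the class below defines only predict and no backward method. The class name scaleShiftLayer, its Scale and Offset properties, and the operation it performs are hypothetical and chosen only for illustration; the point is that predict uses only element-wise arithmetic, which supports dlarray objects.

```matlab
classdef scaleShiftLayer < nnet.layer.Layer
    % scaleShiftLayer  Hypothetical example layer computing Z = Scale.*X + Offset.
    % No backward method is defined, so the derivatives are expected to be
    % obtained through automatic differentiation.

    properties (Learnable)
        Scale    % learnable scalar multiplier (illustrative)
        Offset   % learnable scalar offset (illustrative)
    end

    methods
        function layer = scaleShiftLayer(name)
            % Set the layer name and initialize the learnable parameters.
            layer.Name = name;
            layer.Description = "Learnable scale and shift";
            layer.Scale = single(1);
            layer.Offset = single(0);
        end

        function Z = predict(layer, X)
            % Element-wise arithmetic only, which supports dlarray input.
            Z = layer.Scale .* X + layer.Offset;
        end
    end
end
```

If your existing layer defines a backward method only out of habit, removing it and keeping predict within dlarray-supported operations may be enough to avoid the error.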
I understand that you would like to connect two networks to each other. To connect layers, the layerGraph API is the suggested approach (https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layergraph.html). You should:
  1. Create the layerGraph that you want to attach in front (it may contain custom layers with or without a backward function).
  2. Extract the layerGraph from the pretrained network (using the layerGraph function).
  3. Connect the two graphs using the connectLayers function (documentation: https://www.mathworks.com/help/deeplearning/ref/connectlayers.html).
  4. Train the network with trainNetwork or with a custom training loop, depending on the network and the flexibility that you need during training.
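The steps above can be sketched roughly as follows. This is an illustration under several assumptions, not a definitive implementation: attachFrontNetwork and modelGradients are hypothetical helper names, the pretrained network is assumed to have a single input layer stored first and a single output layer stored last in its Layers array, and the front layers are assumed to produce data of the size that the pretrained network expects. The calls layerGraph, removeLayers, addLayers, connectLayers, dlnetwork, forward, crossentropy, and dlgradient are Deep Learning Toolbox functions.

```matlab
function dlnet = attachFrontNetwork(frontLayers, pretrainedNet)
% attachFrontNetwork  Hypothetical helper, not a toolbox function.
%   frontLayers   : Layer array that starts with an input layer and whose
%                   layer names do not clash with those in pretrainedNet.
%   pretrainedNet : pretrained network accepted by layerGraph.

lgraph = layerGraph(pretrainedNet);

% Assumption: the input layer is first and the output layer is last.
inputName  = lgraph.Layers(1).Name;
outputName = lgraph.Layers(end).Name;

% Remember which layers were fed directly by the pretrained input layer.
conns = lgraph.Connections;
firstNames = conns.Destination(strcmp(conns.Source, inputName));

% Drop the old input layer and the output layer; a dlnetwork does not
% contain output layers because the loss is computed in the training loop.
lgraph = removeLayers(lgraph, {inputName, outputName});

% Add the front layers (connected sequentially) and wire the last front
% layer to every layer that used to read from the old input layer.
lgraph = addLayers(lgraph, frontLayers);
for i = 1:numel(firstNames)
    lgraph = connectLayers(lgraph, frontLayers(end).Name, firstNames{i});
end

dlnet = dlnetwork(lgraph);
end

function [gradients, loss] = modelGradients(dlnet, dlX, dlT)
% Hypothetical loss function for a classification task with targets dlT.
dlY = forward(dlnet, dlX);
loss = crossentropy(dlY, dlT);
gradients = dlgradient(loss, dlnet.Learnables);
end
```

Inside the custom training loop, the gradients would then be evaluated with something like `[gradients, loss] = dlfeval(@modelGradients, dlnet, dlX, dlT);`, where dlX and dlT are a mini-batch and its targets. Whether every layer of a given pretrained network is supported by dlnetwork depends on the release, so check this for your version.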
Hope this helps!

Release

R2019b
