How can I extract the partial derivatives of a specific layer's output during back-propagation in a deep learning model?
Say I have a deep learning model which, after training, I call net.
When I input some images into net, I want the partial derivatives $\partial h / \partial \theta$, where $h$ is the output of the relu1 layer and $\theta$ denotes the trainable parameters of the layers before relu1.
Flattened into a vector, $h$ (the output of relu1) has some size $m$, i.e. $h \in \mathbb{R}^{m}$. Likewise, I write the trainable weights before relu1 as $\theta \in \mathbb{R}^{n}$, where $\theta$ collects all trainable parameters of the layers before relu1. Therefore $\partial h / \partial \theta$ should have size $m \times n$.
How can I get $\partial h / \partial \theta$ in code? Many thanks!
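For the architecture in the code that follows (28×28×1 input; conv1 with 20 filters of size 5×5, using the default stride of 1 and padding of 0), the sizes can be worked out per image. bn1's trained mean and variance are not trainable, so only its offset and scale count toward $\theta$:

$$
\begin{aligned}
m &= \underbrace{(28-5+1)}_{24} \cdot 24 \cdot 20 = 11520 && (h \in \mathbb{R}^{24 \times 24 \times 20}) \\
n &= \underbrace{5 \cdot 5 \cdot 1 \cdot 20}_{\text{conv1 weights}} + \underbrace{20}_{\text{conv1 bias}} + \underbrace{20}_{\text{bn1 offset}} + \underbrace{20}_{\text{bn1 scale}} = 560 \\
\frac{\partial h}{\partial \theta} &\in \mathbb{R}^{11520 \times 560}
\end{aligned}
$$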
My current code
%% Load Data
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
numTrainFiles = 50;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
%% Define Network Architecture
inputSize = [28 28 1];
numClasses = 10;
layers = [
imageInputLayer(inputSize)
convolution2dLayer(5,20,'Name','conv1')
batchNormalizationLayer('Name','bn1')
reluLayer('Name','relu1')
fullyConnectedLayer(numClasses,'Name','fc2')
softmaxLayer('Name','softmax')
classificationLayer];
%% Train Network
options = trainingOptions('sgdm', ...
'MaxEpochs',4, ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(imdsTrain,layers,options);
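One way to make relu1 differentiable with `dlfeval`/`dlgradient` is to copy the trained layers up to and including relu1 into a `dlnetwork`. A minimal sketch, where `relu1Subnetwork` is a hypothetical helper name; it relies on `dlnetwork` only initializing learnables that are empty, so the trained weights in `net.Layers` are carried over unchanged.

```matlab
function dlnet = relu1Subnetwork(net)
% Hypothetical helper: build a dlnetwork containing only the layers up to
% and including 'relu1', reusing the trained weights from net.Layers.
    layers = net.Layers;
    names = {layers.Name};
    idx = find(strcmp(names, 'relu1'));      % position of relu1 (4 here)
    lgraph = layerGraph(layers(1:idx));      % input, conv1, bn1, relu1
    dlnet = dlnetwork(lgraph);               % non-empty learnables are kept
end
```

Usage after training: `dlnet = relu1Subnetwork(net);`. The classification output layer is dropped because `dlnetwork` does not accept output layers.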
Answers (1)
Dinesh Yadav
2019-11-26
Hi,
Please go through the following link and the examples in it.
After the reluLayer, you can use dlgradient to compute the partial derivatives of the relu layer's outputs.
Hope it helps.
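A minimal sketch of that suggestion, assuming `dlnet` is a `dlnetwork` that ends at relu1 (for example, built from `net.Layers(1:4)`) and that `img` is a single image. `dlgradient` only differentiates a scalar, so the sketch takes one element `h(k)` of the relu1 output and returns its gradient with respect to every learnable parameter, i.e. one row of $\partial h/\partial\theta$. The name `reluElementGradient` is hypothetical. It calls `predict` so that bn1 uses its trained statistics, assuming gradients are traced through `predict` inside `dlfeval`.

```matlab
function [grad, hk] = reluElementGradient(dlnet, img, k)
% Hypothetical helper: gradient of the k-th (linear index) element of the
% relu1 output with respect to all learnable parameters of dlnet.
    dlX = dlarray(single(img), 'SSCB');       % height x width x channel x batch(=1)
    [grad, hk] = dlfeval(@elementGradient, dlnet, dlX, k);
end

function [grad, hk] = elementGradient(dlnet, dlX, k)
    h = predict(dlnet, dlX);                  % relu1 activations
    hk = h(k);                                % dlgradient needs a scalar target
    grad = dlgradient(hk, dlnet.Learnables);  % table: Layer, Parameter, Value
end
```

Each row of `grad` holds the gradient for one parameter (conv1 weights, conv1 bias, bn1 offset, bn1 scale), with the same size as that parameter.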
3 Comments
Dinesh Yadav
2019-11-27
I don't think there is a way to do this with dlgradient without using loops. If you want to avoid loops, you will have to write your own custom gradient function.
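A minimal sketch of the loop approach, assuming `dlnet` is a `dlnetwork` that ends at relu1 and one image is processed per call (`relu1Jacobian` is a hypothetical name). Each iteration differentiates one element of $h$, and `'RetainData',true` keeps the trace alive between the repeated `dlgradient` calls. Rows of `J` follow MATLAB's column-major order of $h$; columns follow the row order of `dlnet.Learnables`, with each parameter flattened column-major.

```matlab
function J = relu1Jacobian(dlnet, img)
% Hypothetical helper: full Jacobian dh/dtheta for one image, one row at a
% time. dlnet is assumed to end at the relu1 layer.
    dlX = dlarray(single(img), 'SSCB');
    J = dlfeval(@jacobianLoop, dlnet, dlX);
end

function J = jacobianLoop(dlnet, dlX)
    h = predict(dlnet, dlX);                            % relu1 activations
    m = numel(h);                                       % number of outputs
    n = sum(cellfun(@numel, dlnet.Learnables.Value));   % number of parameters
    J = zeros(m, n, 'single');
    for k = 1:m
        % RetainData keeps the trace so dlgradient can be called again
        g = dlgradient(h(k), dlnet.Learnables, 'RetainData', true);
        % Flatten every parameter gradient and concatenate into one row
        row = cellfun(@(v) reshape(extractdata(v), 1, []), g.Value, ...
            'UniformOutput', false);
        J(k, :) = [row{:}];
    end
end
```

This makes one `dlgradient` call per element of $h$ (11520 per image for the network in the question), so expect it to be slow.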