How can we use the Nadam optimizer in place of sgdm when training deep learning networks?
Training_Options = trainingOptions('sgdm', ...
'MiniBatchSize', 32, ...
'MaxEpochs', 50, ...
"InitialLearnRate", 1e-5, ...
'Shuffle', 'every-epoch', ...
'ValidationData', Resized_Validation_Data, ...
'ValidationFrequency', 40, ...
'Plots','training-progress');
Answers (2)
Joss Knight
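You cannot do this using trainNetwork: it only supports the solvers built into trainingOptions (such as 'sgdm', 'rmsprop', and 'adam'), and Nadam is not one of them. To use Nadam, you need a dlnetwork with a custom training loop so that you can write your own update rule. Perhaps the built-in 'adam' solver will work for you instead.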
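If the built-in Adam solver is acceptable, the options from the question could be adapted roughly as follows. This is only a sketch: the decay factors and epsilon shown are Adam's documented defaults written out explicitly, and Resized_Validation_Data is the asker's own validation data.

```matlab
% Same options as in the question, but using the built-in Adam solver.
% Resized_Validation_Data is assumed to exist in the workspace (from the question).
Training_Options = trainingOptions('adam', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 50, ...
    'InitialLearnRate', 1e-5, ...
    'GradientDecayFactor', 0.9, ...          % beta_1 (Adam default)
    'SquaredGradientDecayFactor', 0.999, ... % beta_2 (Adam default)
    'Epsilon', 1e-8, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', Resized_Validation_Data, ...
    'ValidationFrequency', 40, ...
    'Plots', 'training-progress');
```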
Amanjit Dulai
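You can train with Nadam by defining a custom training loop. The function dlupdate can be used to apply custom update rules to the learnable parameters. The rules for Nadam are shown below, where $g_t$ is the gradient at iteration $t$, $\theta$ denotes the learnable parameters, and $\eta$, $\beta_1$, $\beta_2$, $\psi$, and $\epsilon$ correspond to learnRate, gradientDecay, squaredGradientDecay, momentumDecay, and epsilon in the example:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t$$
$$n_t = \beta_2 n_{t-1} + (1-\beta_2)\,g_t^2$$
$$\hat{m}_t = \frac{\mu_{t+1}\,m_t}{1-\prod_{i=1}^{t+1}\mu_i} + \frac{(1-\mu_t)\,g_t}{1-\prod_{i=1}^{t}\mu_i}$$
$$\hat{n}_t = \frac{n_t}{1-\beta_2^{\,t}}$$
$$\theta_t = \theta_{t-1} - \frac{\eta\,\hat{m}_t}{\sqrt{\hat{n}_t}+\epsilon}$$

where the momentum is given by:

$$\mu_t = \beta_1\left(1 - \frac{1}{2}\,0.96^{\,t\psi}\right)$$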
Below is an example of how to train a digit classification network using Nadam in a custom training loop:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
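% Layers after the convolution are an assumed completion (the original listing was cut off)
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer]);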
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
% Shuffle data (shuffle also resets the queue for the new epoch)
shuffle(mbq);
while hasdata(mbq) && ~monitor.Stop
% Read mini-batch of data.
[XBatch, TBatch] = next(mbq);
% Evaluate the model gradients, state, and loss.
[loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
net.State = state;
% Update the dlnetwork according to Nadam
nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
momentums = [momentums nextMomentum]; %#ok<AGROW>
velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
velocity, gradients);
squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
net.Learnables, ...
velocityHat, ...
squaredGradientsHat );
% Update the training progress monitor.
recordMetrics(monitor, iteration, Loss = loss);
updateInfo(monitor, Epoch = epoch);
monitor.Progress = 100 * iteration/numIterations;
iteration = iteration + 1;
end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
[Y, state] = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
X = cat(4,XCell{1:end});
T = cat(2,TCell{1:end});
T = onehotencode(T,1);
end
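The Nadam arithmetic inside the loop above can also be checked without a network. The following is a minimal, toolbox-free sketch that applies the same update rule to a plain numeric vector. The quadratic objective f(w) = 0.5*sum(w.^2), the starting point, the number of iterations, and the larger learning rate of 0.1 are illustrative assumptions, not part of the answer.

```matlab
% Toolbox-free sketch of the Nadam update on a plain numeric vector.
% Hypothetical objective f(w) = 0.5*sum(w.^2), whose gradient is g = w.
gradFcn = @(w) w;

learnRate = 0.1;              % eta (larger than in the answer so progress is visible)
gradientDecay = 0.9;          % beta_1
squaredGradientDecay = 0.99;  % beta_2
momentumDecay = 0.004;        % psi
epsilon = 1e-8;

% Momentum schedule mu_t = beta_1*(1 - 0.5*0.96^(t*psi))
muFcn = @(t) gradientDecay*(1 - 0.5*0.96.^(t*momentumDecay));

w0 = [1; -2];
w = w0;
m = zeros(size(w));   % first-moment estimate m_t
n = zeros(size(w));   % second-moment estimate n_t
muProd = 1;           % running product mu_1*...*mu_t

numIterations = 100;
for t = 1:numIterations
    g = gradFcn(w);
    muT = muFcn(t);
    muNext = muFcn(t + 1);
    muProd = muProd*muT;   % prod_{i=1}^{t} mu_i
    m = gradientDecay*m + (1 - gradientDecay)*g;
    n = squaredGradientDecay*n + (1 - squaredGradientDecay)*g.^2;
    % Bias-corrected estimates, as in the velocityHat / squaredGradientsHat lines
    mHat = muNext*m/(1 - muProd*muNext) + (1 - muT)*g/(1 - muProd);
    nHat = n/(1 - squaredGradientDecay^t);
    w = w - learnRate*mHat./(sqrt(nHat) + epsilon);
end

fprintf("norm(w0) = %.4f, norm(w) after %d steps = %.4f\n", norm(w0), numIterations, norm(w));
```

One design difference from the training loop: instead of appending to a growing momentums array (the %#ok<AGROW> line), this sketch keeps a running product of the momentum values, which yields the same bias-correction denominators in constant memory.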
Amanjit Dulai
Also, if you want to apply weight decay (L2 regularization) only to the weights and not to the biases, you can modify the example as shown below:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
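% Layers after the convolution are an assumed completion (the original listing was cut off)
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer]);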
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
l2Indices = ~(net.Learnables.Parameter == "Bias");
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
% Shuffle data (shuffle also resets the queue for the new epoch)
shuffle(mbq);
while hasdata(mbq) && ~monitor.Stop
% Read mini-batch of data.
[XBatch, TBatch] = next(mbq);
% Evaluate the model gradients, state, and loss.
[loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
net.State = state;
% Apply L2 weight regularization to the gradients (non-bias parameters only)
gradients(l2Indices,:) = dlupdate( @(g,w)g + l2RegularizationFactor*w, ...
gradients(l2Indices,:), net.Learnables(l2Indices,:) );
% Update the dlnetwork according to Nadam
nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
momentums = [momentums nextMomentum]; %#ok<AGROW>
velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
velocity, gradients);
squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
net.Learnables, ...
velocityHat, ...
squaredGradientsHat );
% Update the training progress monitor.
recordMetrics(monitor, iteration, Loss = loss);
updateInfo(monitor, Epoch = epoch);
monitor.Progress = 100 * iteration/numIterations;
iteration = iteration + 1;
end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
[Y, state] = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
X = cat(4,XCell{1:end});
T = cat(2,TCell{1:end});
T = onehotencode(T,1);
end
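To see what the l2Indices selection does, here is a small sketch that builds a hypothetical table shaped like net.Learnables (Layer, Parameter, and Value variables) and adds the L2 term only to the non-bias rows. The layer names and sizes are made up, the loss gradients are assumed to be zero, and cellfun stands in for dlupdate because the values here are plain arrays rather than dlarray objects.

```matlab
% Hypothetical stand-in for net.Learnables (names and sizes are illustrative).
Layer = ["conv"; "conv"; "fc"; "fc"];
Parameter = ["Weights"; "Bias"; "Weights"; "Bias"];
Value = {ones(5,5,1,20); zeros(1,1,20); ones(10,20); zeros(10,1)};
learnables = table(Layer, Parameter, Value);

% Same selection as in the answer: every row except the biases
l2Indices = ~(learnables.Parameter == "Bias");

% Assume all loss gradients are zero, then add the L2 term to the selected rows only
l2RegularizationFactor = 0.0001;
gradients = learnables;
gradients.Value = cellfun(@(w) zeros(size(w)), learnables.Value, UniformOutput=false);
gradients.Value(l2Indices) = cellfun(@(g,w) g + l2RegularizationFactor*w, ...
    gradients.Value(l2Indices), learnables.Value(l2Indices), UniformOutput=false);

% Parameters that received the L2 term
disp(learnables.Parameter(l2Indices))
```

Only the two Weights rows receive the l2RegularizationFactor*w term; the Bias rows keep their original (here zero) gradients.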
One thing to note: with adaptive learning rules such as Adam and Nadam, it has been found that applying weight decay directly to the weights is often more effective than adding an L2 term to the gradients (this is the decoupled weight decay used in AdamW). Applying this idea to Nadam gives the NadamW algorithm. Below is an example of how to use NadamW:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
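% Layers after the convolution are an assumed completion (the original listing was cut off)
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer]);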
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
l2Indices = ~(net.Learnables.Parameter == "Bias");
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
% Shuffle data (shuffle also resets the queue for the new epoch)
shuffle(mbq);
while hasdata(mbq) && ~monitor.Stop
% Read mini-batch of data.
[XBatch, TBatch] = next(mbq);
% Evaluate the model gradients, state, and loss.
[loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
net.State = state;
% Apply decoupled weight decay directly to the weights (NadamW)
net.Learnables(l2Indices,:) = dlupdate( @(w)w - learnRate*l2RegularizationFactor*w, ...
net.Learnables(l2Indices,:) );
% Update the dlnetwork according to Nadam
nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
momentums = [momentums nextMomentum]; %#ok<AGROW>
velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
velocity, gradients);
squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
net.Learnables, ...
velocityHat, ...
squaredGradientsHat );
% Update the training progress monitor.
recordMetrics(monitor, iteration, Loss = loss);
updateInfo(monitor, Epoch = epoch);
monitor.Progress = 100 * iteration/numIterations;
iteration = iteration + 1;
end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
[Y, state] = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
X = cat(4,XCell{1:end});
T = cat(2,TCell{1:end});
T = onehotencode(T,1);
end
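The difference between coupled L2 and decoupled decay can be illustrated with a deliberately contrived, toolbox-free sketch: one Nadam step (t = 1, zero moment estimates) on a single hypothetical weight w = 1 whose loss gradient is zero, using the hyperparameters from the examples above. The setup is an assumption chosen only to isolate the effect of the decay term.

```matlab
% Hyperparameters as in the examples above
learnRate = 0.001; gradientDecay = 0.9; squaredGradientDecay = 0.99;
momentumDecay = 0.004; epsilon = 1e-8; l2RegularizationFactor = 0.0001;
muFcn = @(t) gradientDecay*(1 - 0.5*0.96.^(t*momentumDecay));
mu1 = muFcn(1);
mu2 = muFcn(2);

w = 1;          % hypothetical current weight
lossGrad = 0;   % hypothetical gradient of the loss alone

% Size of the Nadam step at t = 1 (starting from m_0 = n_0 = 0) for a gradient g:
% mHat = mu2*(1-beta1)*g/(1-mu1*mu2) + (1-mu1)*g/(1-mu1),  nHat = g^2
nadamStep = @(g) learnRate * ...
    (mu2*(1 - gradientDecay)*g/(1 - mu1*mu2) + (1 - mu1)*g/(1 - mu1)) ./ ...
    (sqrt((1 - squaredGradientDecay)*g.^2/(1 - squaredGradientDecay)) + epsilon);

% Coupled L2: the decay term is added to the gradient, then normalized by sqrt(nHat)
coupledChange = nadamStep(lossGrad + l2RegularizationFactor*w);

% Decoupled (NadamW): shrink the weight directly, Nadam sees only the loss gradient
decoupledChange = learnRate*l2RegularizationFactor*w + nadamStep(lossGrad);

fprintf("coupled change: %.3g, decoupled change: %.3g\n", coupledChange, decoupledChange);
```

With these numbers the coupled change comes out roughly four orders of magnitude larger than learnRate*l2RegularizationFactor*w, because dividing by sqrt(nHat) cancels the size of the L2 term. More generally, with coupled L2 the effective decay depends on each parameter's gradient history, whereas NadamW always shrinks the weight by exactly learnRate*l2RegularizationFactor*w.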