sgdmupdate
Update parameters using stochastic gradient descent with momentum (SGDM)
Syntax
Description
Update the network learnable parameters in a custom training loop using the stochastic gradient descent with momentum (SGDM) algorithm.
Note
This function applies the SGDM optimization algorithm to update network parameters in custom training loops. To train a neural network with the SGDM solver using the trainnet function, use the trainingOptions function and set the solver to "sgdm".
[netUpdated,vel] = sgdmupdate(net,grad,vel) updates the learnable parameters of the network net using the SGDM algorithm. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.
Examples
Update Learnable Parameters Using sgdmupdate
Perform a single SGDM update step with a global learning rate of 0.05 and momentum of 0.95.
Create the parameters and parameter gradients as numeric arrays.
params = rand(3,3,4);
grad = ones(3,3,4);
Initialize the parameter velocities for the first iteration.
vel = [];
Specify custom values for the global learning rate and momentum.
learnRate = 0.05;
momentum = 0.95;
Update the learnable parameters using sgdmupdate.
[params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);
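As a check on what this first call computes, the sketch below repeats the step by hand. It assumes that an empty vel behaves as a zero velocity and that the update has the velocity form vel = momentum*vel - learnRate*grad followed by params = params + vel. On the first iteration either reading reduces to params - learnRate*grad, so every entry should move by -0.05. The names paramsLib, paramsManual, and velManual are illustrative only.

```matlab
% Same inputs as the example above.
params = rand(3,3,4);
grad = ones(3,3,4);
learnRate = 0.05;
momentum = 0.95;

% Library call with an empty velocity (first iteration).
[paramsLib,velLib] = sgdmupdate(params,grad,[],learnRate,momentum);

% Hand-computed step, assuming the empty velocity acts as zero.
velManual = momentum*zeros(size(params)) - learnRate*grad;  % every entry is -0.05
paramsManual = params + velManual;

% Largest disagreement between the two results (expected to be at rounding level).
maxDiff = max(abs(paramsLib - paramsManual),[],"all")
```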
Train Network Using sgdmupdate
Use sgdmupdate to train a network using the SGDM algorithm.
Load Training Data
Load the digits training data.
[XTrain,TTrain] = digitTrain4DArrayData;
classes = categories(TTrain);
numClasses = numel(classes);
Define Network
Define the network architecture and specify the average image value using the Mean option in the image input layer.
layers = [
    imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
    convolution2dLayer(5,20)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
Create a dlnetwork object from the layer array.
net = dlnetwork(layers);
Define Model Loss Function
Create the helper function modelLoss, listed at the end of the example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(TTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train Network
Initialize the velocity parameter.
vel = [];
Calculate the total number of iterations for the training progress monitor.
numIterations = numEpochs * numIterationsPerEpoch;
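As a worked example of these counts, assume numel(TTrain) returns 5000 for this data set (check the value in your own session). Because floor drops the incomplete final mini-batch, a few observations are skipped in each epoch; since the data is reshuffled every epoch, a different few are skipped each time.

```matlab
numObservations = 5000;   % assumed training-set size; compare with numel(TTrain)
miniBatchSize = 128;
numEpochs = 20;

numIterationsPerEpoch = floor(numObservations./miniBatchSize)            % 39
numIterations = numEpochs*numIterationsPerEpoch                          % 780
numDiscardedPerEpoch = numObservations - numIterationsPerEpoch*miniBatchSize  % 8
```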
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the sgdmupdate function. At the end of each iteration, display the training progress.
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
iteration = 0;
epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    idx = randperm(numel(TTrain));
    XTrain = XTrain(:,:,:,idx);
    TTrain = TTrain(idx);

    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        T = zeros(numClasses,miniBatchSize,"single");
        for c = 1:numClasses
            T(c,TTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to a dlarray.
        X = dlarray(single(X),"SSCB");

        % If training on a GPU, then convert data to a gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using the SGDM optimizer.
        [net,vel] = sgdmupdate(net,gradients,vel);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
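The label conversion inside the loop builds a one-hot (dummy-variable) matrix with one row per class and one column per observation. A minimal way to see what it produces is to run the same loop on a tiny categorical vector; the labels below are hypothetical and unrelated to the digits data.

```matlab
% Hypothetical labels for four observations.
labels = categorical(["2";"0";"2";"1"]);
classes = categories(labels);          % {'0';'1';'2'} (sorted category names)
numClasses = numel(classes);

% Same conversion as in the training loop: rows are classes, columns are observations.
T = zeros(numClasses,numel(labels),"single");
for c = 1:numClasses
    T(c,labels==classes(c)) = 1;
end
T
```

Each column contains a single 1 in the row of that observation's class, which is the target format the crossentropy call in modelLoss compares against.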
Test the Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest,TTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.
XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
YTest = predict(net,XTest);
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YTest==TTest)
accuracy = 0.9910
Model Loss Function
The modelLoss function takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
Input Arguments
net
— Network
dlnetwork
object
Network, specified as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:
- Layer: Layer name, specified as a string scalar.
- Parameter: Parameter name, specified as a string scalar.
- Value: Value of parameter, specified as a cell array containing a dlarray.
The input argument grad must be a table of the same form as net.Learnables.
params
— Network learnable parameters
dlarray
| numeric array | cell array | structure | table
Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
If you specify params as a table, it must contain the following three variables:
- Layer: Layer name, specified as a string scalar.
- Parameter: Parameter name, specified as a string scalar.
- Value: Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
The input argument grad must have exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
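For example, a structure of numeric parameters needs a gradient structure with the same fields in the same order. The sketch below uses hypothetical field names W and b and passes an empty velocity, so on this first step each field should become params - learnRate*grad.

```matlab
% Hypothetical parameter and gradient structures with matching fields and order.
params.W = [0.5 -0.5; 1 2];
params.b = [0; 1];
grad.W = ones(2,2);
grad.b = [1; -1];

learnRate = 0.1;
momentum = 0.9;

% Empty velocity: first update in a series of iterations.
[params,vel] = sgdmupdate(params,grad,[],learnRate,momentum);

params.W   % expected [0.4 -0.6; 0.9 1.9]
params.b   % expected [-0.1; 1.1]
```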
The learnables can be complex-valued. (since R2024a) Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.
Data Types: single | double | struct | table | cell
grad
— Gradients of the loss
dlarray
| numeric array | cell array | structure | table
Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to sgdmupdate.
Input | Learnable Parameters | Gradients |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
You can obtain grad from a call to dlfeval that evaluates a function that contains a call to dlgradient.
For more information, see Use Automatic Differentiation In Deep Learning Toolbox.
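A minimal sketch of that pattern, assuming Deep Learning Toolbox is available: the loss sum(x.^2,"all") has gradient 2*x, and an empty velocity with the default learning rate of 0.01 gives a first step of x - 0.01*2*x. An anonymous function stands in for a separate loss function here only to keep the example in one block; in a real training loop you would normally write a named function such as modelLoss.

```matlab
x = dlarray([1 2 3]);

% dlgradient must run inside a function evaluated by dlfeval.
grad = dlfeval(@(x) dlgradient(sum(x.^2,"all"),x),x);   % gradient is 2*x

% Default learnRate (0.01) and momentum (0.9); empty velocity.
[xUpdated,vel] = sgdmupdate(x,grad,[]);

result = extractdata(xUpdated)   % approximately [0.98 1.96 2.94]
```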
The gradients can be complex-valued. (since R2024a) Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
vel
— Parameter velocities
[]
| dlarray
| numeric array | cell array | structure | table
Parameter velocities, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.
The exact form of vel depends on the input network or learnable parameters. The following table shows the required format for vel for possible inputs to sgdmupdate.
Input | Learnable Parameters | Velocities |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. vel must have a Value variable consisting of cell arrays that contain the velocity of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. vel must have a Value variable consisting of cell arrays that contain the velocity of each learnable parameter. |
If you specify vel as an empty array, the function assumes no previous velocities and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the vel output of a previous call to sgdmupdate as the vel input.
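The sketch below feeds each vel output back in as the next vel input while minimizing the illustrative quadratic loss 0.5*sum(params.^2), whose gradient is params itself. The loss, starting point, and iteration count are arbitrary choices for illustration; with these settings the parameters should settle close to the minimum at zero.

```matlab
params = [1 -2 3];        % illustrative starting point
vel = [];                 % no previous velocity on the first call
learnRate = 0.05;
momentum = 0.9;

for iteration = 1:300
    grad = params;        % gradient of the loss 0.5*sum(params.^2)
    % Reuse the velocity returned by the previous call.
    [params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);
end

finalNorm = norm(params)  % small: the iterates spiral in toward 0
```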
The velocity can be complex-valued. (since R2024a) Using complex-valued gradients and velocities can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
learnRate
— Global learning rate
0.01
(default) | positive scalar
Learning rate, specified as a positive scalar. The default value of learnRate is 0.01.
If you specify the network parameters as a dlnetwork object, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
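As a worked example of that multiplication, assume a fully connected layer whose weight learning rate factor is set to 2 and whose bias factor is left at its default of 1. With a global learning rate of 0.01, the per-parameter rates are as follows. The snippet only computes the products; it does not inspect what sgdmupdate does internally.

```matlab
learnRate = 0.01;   % global learning rate passed to sgdmupdate

% Illustrative layer: weight factor set to 2, bias factor left at its default.
layer = fullyConnectedLayer(10,'WeightLearnRateFactor',2);

effectiveWeightRate = learnRate*layer.WeightLearnRateFactor   % 0.02
effectiveBiasRate = learnRate*layer.BiasLearnRateFactor       % 0.01
```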
momentum
— Momentum
0.9 (default) | positive scalar between 0 and 1
Momentum, specified as a positive scalar between 0 and 1. The default value of momentum is 0.9.
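One way to read the momentum value: if the gradient stayed constant at g, the velocity would approach -learnRate*g/(1 - momentum), so with the default momentum of 0.9 the steps eventually become about ten times larger than plain gradient descent would take. The loop below checks this for the default values, assuming the velocity update vel = momentum*vel - learnRate*grad; the constant gradient is a hypothetical input.

```matlab
learnRate = 0.01;   % default learning rate
momentum = 0.9;     % default momentum
g = 1;              % hypothetical constant gradient

vel = 0;            % empty velocity treated as zero
for k = 1:200
    vel = momentum*vel - learnRate*g;
end

vel                                   % close to -0.1
limit = -learnRate*g/(1 - momentum)   % -0.1
```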
Output Arguments
netUpdated
— Updated network
dlnetwork
object
Updated network, returned as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object.
params
— Updated network learnable parameters
dlarray
| numeric array | cell array | structure | table
Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
The learnables can be complex-valued. (since R2024a) Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.
vel
— Updated parameter velocities
dlarray
| numeric array | cell array | structure | table
Updated parameter velocities, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
Algorithms
Stochastic Gradient Descent
The standard gradient descent algorithm updates the network parameters (weights and biases) to minimize the loss function by taking small steps at each iteration in the direction of the negative gradient of the loss,
$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}),$$
where $\ell$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function. In the standard gradient descent algorithm, the gradient of the loss function, $\nabla E(\theta)$, is evaluated using the entire training set at once.
By contrast, at each iteration the stochastic gradient descent algorithm evaluates the gradient and updates the parameters using a subset of the training data. A different subset, called a mini-batch, is used at each iteration. One full pass of the training algorithm over the entire training set using mini-batches is one epoch. Stochastic gradient descent is stochastic because the parameter update computed using a mini-batch is a noisy estimate of the parameter update that would result from using the full data set.
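A minimal sketch of mini-batch stochastic gradient descent on a synthetic least-squares problem. The data, the model y = 3*x + 1, the learning rate, and the mini-batch size are illustrative assumptions. Each update uses a gradient estimated from one shuffled mini-batch only, and one pass over all mini-batches is one epoch.

```matlab
rng(0);                          % reproducible data and shuffling
x = rand(1000,1);                % synthetic inputs (illustrative only)
y = 3*x + 1;                     % noiseless targets
w = 0;                           % parameters theta = [w; b]
b = 0;

learnRate = 0.5;
miniBatchSize = 100;
numEpochs = 100;
numIterationsPerEpoch = floor(numel(x)/miniBatchSize);

for epoch = 1:numEpochs
    idx = randperm(numel(x));                  % new shuffle every epoch
    for i = 1:numIterationsPerEpoch
        batch = idx((i-1)*miniBatchSize+1:i*miniBatchSize);
        r = w*x(batch) + b - y(batch);         % residuals on this mini-batch
        % Noisy gradient of E = mean(0.5*r.^2), estimated from the mini-batch only.
        gradW = mean(r.*x(batch));
        gradB = mean(r);
        w = w - learnRate*gradW;               % plain SGD step, no momentum
        b = b - learnRate*gradB;
    end
end

[w b]   % close to [3 1]
```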
Stochastic Gradient Descent with Momentum
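The stochastic gradient descent algorithm can oscillate along the path of steepest descent towards the optimum. Adding a momentum term to the parameter update is one way to reduce this oscillation [1]. The stochastic gradient descent with momentum (SGDM) update is
$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}) + \gamma (\theta_{\ell} - \theta_{\ell-1}),$$
where $\alpha$ is the learning rate and the momentum value $\gamma$ determines the contribution of the previous gradient step to the current iteration. In sgdmupdate, $\alpha$ corresponds to learnRate and $\gamma$ to momentum.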
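In code, SGDM is often written in an equivalent velocity form, v = momentum*v - learnRate*grad followed by theta = theta + v, where v holds the previous step $\theta_{\ell} - \theta_{\ell-1}$. The sketch below runs both forms on an illustrative quadratic loss E(theta) = 0.5*theta'*A*theta and checks that they produce the same iterates up to rounding. It does not claim that sgdmupdate is implemented this way; it only shows that the two formulas agree.

```matlab
A = diag([1 10]);             % Hessian of the illustrative quadratic loss
gradE = @(theta) A*theta;     % gradient of E(theta) = 0.5*theta'*A*theta
learnRate = 0.05;             % alpha
momentum = 0.9;               % gamma
theta0 = [1; 1];

% Form 1: velocity form, starting from zero velocity.
theta1 = theta0;
v = zeros(2,1);
for l = 1:50
    v = momentum*v - learnRate*gradE(theta1);
    theta1 = theta1 + v;
end

% Form 2: theta_{l+1} = theta_l - alpha*grad + gamma*(theta_l - theta_{l-1}).
thetaPrev = theta0;           % theta_{-1} = theta_0 means zero initial velocity
theta2 = theta0;
for l = 1:50
    thetaNext = theta2 - learnRate*gradE(theta2) + momentum*(theta2 - thetaPrev);
    thetaPrev = theta2;
    theta2 = thetaNext;
end

maxDifference = max(abs(theta1 - theta2))   % at rounding level
```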
References
[1] Murphy, K. P. Machine Learning: A Probabilistic Perspective. The MIT Press, Cambridge, Massachusetts, 2012.
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
The sgdmupdate function supports GPU array input with these usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:
- grad
- params
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2019b
R2024a: Complex-valued learnable parameters and gradients
The learnable parameters, gradients, and velocity can be complex-valued. When the updated learnable parameters are complex-valued, ensure that the corresponding operations support complex-valued parameters.