
dlupdate

Update parameters using custom function

Description

netUpdated = dlupdate(fun,net) updates the learnable parameters of the dlnetwork object net by evaluating the function fun with each learnable parameter as an input. fun is a function handle to a function that takes one parameter array as an input argument and returns an updated parameter array.

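For example, this minimal sketch scales every learnable parameter of an initialized dlnetwork object net by 0.9 (the factor is an arbitrary illustrative value):

netUpdated = dlupdate(@(p) 0.9*p,net);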

params = dlupdate(fun,params) updates the learnable parameters in params by evaluating the function fun with each learnable parameter as an input.
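For example, this minimal sketch rectifies a single dlarray of parameters:

params = dlarray(randn(3,3));
params = dlupdate(@(p) max(p,0),params); % set negative entries to zero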

[___] = dlupdate(fun,___,A1,...,An) also specifies additional input arguments, in addition to the input arguments in previous syntaxes, when fun is a function handle to a function that requires n+1 input values.

[___,X1,...,Xm] = dlupdate(fun,___) returns multiple outputs X1,...,Xm when fun is a function handle to a function that returns m+1 output values.
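For example, this minimal sketch uses a function that returns both the clipped parameter and a count of clipped elements. The name clipFun is hypothetical, and deal lets the anonymous function return two outputs:

% Clip each parameter to [-1,1] and count how many elements were clipped.
clipFun = @(p) deal(max(min(p,1),-1),sum(abs(p)>1,"all"));
[params,numClipped] = dlupdate(clipFun,params);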

Examples


L1 Regularization with dlupdate

Perform L1 regularization on a structure of parameter gradients.

Create the sample input data.

dlX = dlarray(rand(100,100,3),'SSC');

Initialize the learnable parameters for the convolution operation.

params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));

Calculate the gradients for the convolution operation using the helper function convGradients, defined at the end of this example.

gradients = dlfeval(@convGradients,dlX,params);

Define the regularization factor.

L1Factor = 0.001;

Create an anonymous function that regularizes the gradients. By using an anonymous function to pass a scalar constant to the function, you can avoid having to expand the constant value to the same size and structure as the parameter variable.

L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);

Use dlupdate to apply the regularization function to each of the gradients.

gradients = dlupdate(L1Regularizer,gradients,params);

The gradients in the gradients structure are now regularized according to the function L1Regularizer.

convGradients Function

The convGradients helper function takes the learnable parameters of the convolution operation and a mini-batch of input data dlX, and returns the gradients with respect to the learnable parameters.

function gradients = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
gradients = dlgradient(dlY,params);
end

Train Network Using Custom Update Function

Use dlupdate to train a network using a custom update function that implements the stochastic gradient descent algorithm (without momentum).

Load Training Data

Load the digits training data.

[XTrain,TTrain] = digitTrain4DArrayData;
classes = categories(TTrain);
numClasses = numel(classes);

Define the Network

Define the network architecture and specify the average image value using the Mean option in the image input layer.

layers = [
    imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
    convolution2dLayer(5,20)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

Create a dlnetwork object from the layer array.

net = dlnetwork(layers);

Define Model Loss Function

Create the helper function modelLoss, listed at the end of this example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.

Define Stochastic Gradient Descent Function

Create the helper function sgdFunction, listed at the end of this example. The function takes the parameters and the gradients of the loss with respect to the parameters, and returns the updated parameters using the stochastic gradient descent algorithm, expressed as

$\theta_{l+1} = \theta_l - \alpha \nabla E(\theta_l)$

where $l$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, $E(\theta)$ is the loss function, and $\nabla E(\theta_l)$ is the gradient of the loss with respect to the parameters at iteration $l$.

Specify Training Options

Specify the options to use during training.

miniBatchSize = 128;
numEpochs = 30;
numObservations = numel(TTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Specify the learning rate.

learnRate = 0.01;

Train Network

Calculate the total number of iterations for the training progress monitor.

numIterations = numEpochs * numIterationsPerEpoch;

Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.

monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");

Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters by calling dlupdate with the function sgdFunction defined at the end of this example. At the end of each epoch, display the training progress.

Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

iteration = 0;
epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    
    % Shuffle data.
    idx = randperm(numel(TTrain));
    XTrain = XTrain(:,:,:,idx);
    TTrain = TTrain(idx);

    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;        
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        T = zeros(numClasses, miniBatchSize,"single");
        for c = 1:numClasses
            T(c,TTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to dlarray.
        X = dlarray(single(X),"SSCB");

        % If training on a GPU, then convert data to a gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function.
        updateFcn = @(net,gradients) sgdFunction(net,gradients,learnRate);
        net = dlupdate(updateFcn,net,gradients);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end

Test Network

Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.

[XTest,TTest] = digitTest4DArrayData;

Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.

XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end

To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.

YTest = predict(net,XTest);
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);

Evaluate the classification accuracy.

accuracy = mean(YTest==TTest)
accuracy = 0.9040

Model Loss Function

The helper function modelLoss takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.

function [loss,gradients] = modelLoss(net,X,T)

Y = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

Stochastic Gradient Descent Function

The helper function sgdFunction takes the learnable parameters parameters, the gradients of the loss with respect to the learnable parameters, and the learning rate learnRate, and returns the updated parameters using the stochastic gradient descent algorithm, expressed as

$\theta_{l+1} = \theta_l - \alpha \nabla E(\theta_l)$

where $l$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, $E(\theta)$ is the loss function, and $\nabla E(\theta_l)$ is the gradient of the loss with respect to the parameters at iteration $l$.

function parameters = sgdFunction(parameters,gradients,learnRate)

parameters = parameters - learnRate .* gradients;

end

Input Arguments


fun — Function to apply to the learnable parameters, specified as a function handle.

dlupdate evaluates fun with each network learnable parameter as an input. fun is evaluated as many times as there are arrays of learnable parameters in net or params.
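For example, in this minimal sketch the structure holds two parameter arrays, so dlupdate evaluates fun twice, once for each array:

params.Weights = dlarray(ones(2));
params.Bias = dlarray(zeros(2,1));
params = dlupdate(@(p) p + 1,params); % fun is applied to Weights and Bias separately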

net — Network, specified as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:

  • Layer — Layer name, specified as a string scalar.

  • Parameter — Parameter name, specified as a string scalar.

  • Value — Value of parameter, specified as a cell array containing a dlarray.

params — Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

If you specify params as a table, it must contain the following three variables.

  • Layer — Layer name, specified as a string scalar.

  • Parameter — Parameter name, specified as a string scalar.

  • Value — Value of parameter, specified as a cell array containing a dlarray.

You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
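For example, this minimal sketch nests parameters in a structure of structures and mixes dlarray and single values (the field names are arbitrary):

params.conv.Weights = dlarray(rand(3,3,1,8));
params.conv.Bias = dlarray(zeros(8,1));
params.fc.Weights = rand(10,72,"single"); % numeric values of type single are also allowed
params = dlupdate(@(p) 0.5*p,params);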

The input arguments A1,...,An must have exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.

Data Types: single | double | struct | table | cell

A1,...,An — Additional input arguments to fun, specified as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.

The exact form of A1,...,An depends on the input network or learnable parameters. The following table shows the required format for A1,...,An for possible inputs to dlupdate.

| Input | Learnable parameters | A1,...,An |
| --- | --- | --- |
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input arguments for the function fun to apply to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
| params | Numeric array | Numeric array with the same data type and ordering as params. |
| params | Cell array | Cell array with the same data types, structure, and ordering as params. |
| params | Structure | Structure with the same data types, fields, and ordering as params. |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input argument for the function fun to apply to each learnable parameter. |
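For example, when the input is a dlnetwork object, the gradients table that dlgradient returns with respect to net.Learnables already has matching Layer, Parameter, and Value variables, so you can pass it directly as an additional argument. This minimal sketch assumes gradients was computed as in the training example above:

net = dlupdate(@(p,g) p - 0.01*g,net,gradients);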

Output Arguments


netUpdated — Updated network, returned as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object.

params — Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.

X1,...,Xm — Additional output arguments from the function fun, where fun is a function handle to a function that returns multiple outputs, returned as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.

The exact form of X1,...,Xm depends on the input network or learnable parameters. The following table shows the returned format of X1,...,Xm for possible inputs to dlupdate.

| Input | Learnable parameters | X1,...,Xm |
| --- | --- | --- |
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output arguments of the function fun applied to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
| params | Numeric array | Numeric array with the same data type and ordering as params. |
| params | Cell array | Cell array with the same data types, structure, and ordering as params. |
| params | Structure | Structure with the same data types, fields, and ordering as params. |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output argument of the function fun applied to each learnable parameter. |


Version History

Introduced in R2019b