Train Model Using Custom Backward Function
This example shows how to train a deep learning model that contains an operation with a custom backward function.
When you define a custom loss function, a custom layer forward function, or a deep learning model as a function, and the software does not provide the deep learning operation that you require for your task, you can define your own function using dlarray objects.
Most deep learning workflows use gradients to train the model. If your function uses only functions that support dlarray objects, then you can use those functions directly and the software determines the gradients automatically using automatic differentiation. For example, you can pass dlarray object functions such as crossentropy as the loss function to the trainnet function, or use dlarray object functions such as dlconv in custom layer functions. For a list of functions that support dlarray objects, see List of Functions with dlarray Support.
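As a minimal sketch of this automatic workflow, you can compute gradients with dlfeval and dlgradient. The function sumOfSquares and the variable x below are illustrative only and are not part of this example.

x = dlarray(rand(3,1));
[y,dydx] = dlfeval(@sumOfSquares,x)

function [y,dydx] = sumOfSquares(x)
% Illustrative only: a scalar function built from dlarray-supporting operations.
y = sum(x.^2,"all");
dydx = dlgradient(y,x);   % gradient computed by automatic differentiation
end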
If you want to use functions that do not support dlarray objects, or want to use a specific algorithm to compute the gradients, then you can define a custom deep learning operation as a differentiable function object. This example trains a simple classification neural network, defined as a function, that uses a custom SReLU [1] operation with a custom backward function.
For an example showing how to create the custom function, see Specify Custom Operation Backward Function.
Load Training Data
Load the digits data. The training set contains 5000 images of handwritten digits and their corresponding digit labels and angles of rotation.
load digitsDataTrain
View the class names of the data set.
classNames = categories(labelsTrain)
classNames = 10×1 cell
{'0'}
{'1'}
{'2'}
{'3'}
{'4'}
{'5'}
{'6'}
{'7'}
{'8'}
{'9'}
Model Parameters
Define the parameters for each of the operations and include them in a struct. Use the format parameters.OperationName.ParameterName, where parameters is the structure, OperationName is the name of the operation (for example, "conv"), and ParameterName is the name of the parameter (for example, "Weights").
Create an empty structure for the learnable parameters.
parameters = struct;
Initialize the learnable weights and biases using example functions such as initializeGlorot and initializeHe. To access these functions, open the example as a live script.
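If you are not running the live script, you can use minimal stand-ins for the supporting initialization functions, based on the schemes described in Initialize Learnable Parameters for Model Function. Treat these as a sketch; the files attached to the example may differ in detail.

function weights = initializeGlorot(sz,numOut,numIn)
% Glorot (Xavier) uniform initialization.
Z = 2*rand(sz,"single") - 1;
bound = sqrt(6/(numIn + numOut));
weights = dlarray(bound*Z);
end

function weights = initializeHe(sz,numIn)
% He initialization.
weights = dlarray(sqrt(2/numIn)*randn(sz,"single"));
end

function parameter = initializeZeros(sz)
parameter = dlarray(zeros(sz,"single"));
end

function parameter = initializeOnes(sz)
parameter = dlarray(ones(sz,"single"));
end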
Initialize the weights and biases for the convolution operation "conv" using initializeGlorot and initializeZeros, respectively.
filterSize = [5 5];
numFilters = 20;
numChannels = size(XTrain,3);

numOut = numFilters*prod(filterSize);
numIn = numChannels*prod(filterSize);

sz = [filterSize(1) filterSize(2) numChannels numFilters];
parameters.conv.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv.Bias = initializeZeros([numFilters 1]);
Initialize the offset and scale for the layer normalization operation "layernorm" using initializeZeros and initializeOnes, respectively.
parameters.layernorm.Offset = initializeZeros([numFilters 1]);
parameters.layernorm.Scale = initializeOnes([numFilters 1]);
Initialize the parameters for the SReLU operation "srelu" using initializeHe.
numIn = numFilters;
sz = [1 1 numIn];
parameters.srelu.LeftThreshold = initializeHe(sz,numIn);
parameters.srelu.LeftSlope = initializeHe(sz,numIn);
parameters.srelu.RightThreshold = initializeHe(sz,numIn);
parameters.srelu.RightSlope = initializeHe(sz,numIn);
Initialize the weights and biases for the fully connect operation "fc" using initializeGlorot and initializeZeros, respectively.
numClasses = numel(classNames);

numOut = numClasses;
numIn = 15680;   % flattened conv output: 28-by-28 spatial positions times 20 channels
sz = [numOut numIn];
parameters.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc.Bias = initializeZeros([numOut 1]);
Create Custom SReLU Function
Create a custom function that applies the SReLU operation. The SReLU (S-shaped ReLU) operation [1] is piecewise linear: it applies the slope al below the left threshold tl, the slope ar above the right threshold tr, and the identity function in between. To specify a custom backward function, create a sreluFunction object using the class definition attached to this example as a supporting file. To access this file, open this example as a live script. Specify the data format using the first argument of sreluFunction.
function Y = srelu(X,tl,al,tr,ar)
% Apply the SReLU operation using a sreluFunction object, which defines a
% custom backward function.
format = dims(X);
fcn = sreluFunction(format);
Y = fcn(X,tl,al,tr,ar);
Y = dlarray(Y,format);
end
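As a quick, illustrative check, you can apply the wrapper to a random formatted input with channelwise parameters of size [1 1 numFilters], matching how the model calls it. This snippet is not part of the original workflow, assumes the sreluFunction class file is on the path, and uses arbitrary parameter values.

% Illustrative check only.
XCheck = dlarray(randn(28,28,20,4,"single"),"SSCB");
tl = dlarray(-ones(1,1,20,"single"));
al = dlarray(0.1*ones(1,1,20,"single"));
tr = dlarray(ones(1,1,20,"single"));
ar = dlarray(0.1*ones(1,1,20,"single"));
YCheck = srelu(XCheck,tl,al,tr,ar);
size(YCheck)   % same size as XCheck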
Define Model Function
Create the function model that takes the learnable parameters and input data as input and returns the model output. The model applies the convolution, layer normalization, SReLU, fully connect, and softmax operations to the input data.
function Y = model(parameters,X)

% Convolution
weights = parameters.conv.Weights;
bias = parameters.conv.Bias;
Y = dlconv(X,weights,bias,Padding="same");

% Layer normalization
offset = parameters.layernorm.Offset;
scaleFactor = parameters.layernorm.Scale;
Y = layernorm(Y,offset,scaleFactor);

% SReLU
tl = parameters.srelu.LeftThreshold;
al = parameters.srelu.LeftSlope;
tr = parameters.srelu.RightThreshold;
ar = parameters.srelu.RightSlope;
Y = srelu(Y,tl,al,tr,ar);

% Fully connect and softmax
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
Y = fullyconnect(Y,weights,bias);
Y = softmax(Y);

end
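Optionally, before training, sanity-check the model on a small random batch. The digits images are 28-by-28 grayscale, so the output should be numClasses-by-batch size. This check is illustrative and not part of the original workflow.

% Illustrative check only.
XSample = dlarray(rand(28,28,1,4,"single"),"SSCB");
YSample = model(parameters,XSample);
size(YSample)   % numClasses-by-4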
Define Model Loss Function
Create the function modelLoss that takes the model parameters, a mini-batch of input data X with corresponding targets T, and returns the loss and the gradients of the loss with respect to the learnable parameters.
function [loss,gradients] = modelLoss(parameters,X,T)

Y = model(parameters,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,parameters);

end
Specify Training Options
Specify the training options. Train for 20 epochs with a mini-batch size of 128.
numEpochs = 20;
miniBatchSize = 128;
Train Model
Use a minibatchqueue object to process and manage the mini-batches of images. For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to one-hot encode the class labels.
- Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.
- Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
adsXTrain = arrayDatastore(XTrain,IterationDimension=4);
adsTTrain = arrayDatastore(labelsTrain);
cdsTrain = combine(adsXTrain,adsTTrain);

mbq = minibatchqueue(cdsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" ""]);
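Optionally, preview one mini-batch to confirm the "SSCB" format of the image data and the size of the one-hot encoded targets. This check is illustrative; reset the queue afterward so that training starts from the first mini-batch.

% Illustrative check only.
[XPreview,TPreview] = next(mbq);
dims(XPreview)    % "SSCB"
size(TPreview)    % numClasses-by-miniBatchSize
reset(mbq)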
Create a mini-batch preprocessing function that concatenates the input data and one-hot encodes the targets.
function [X,T] = preprocessMiniBatch(dataX,dataT)

% Concatenate the image data over the fourth (batch) dimension.
X = cat(4,dataX{:});

% Concatenate the labels and one-hot encode them.
T = cat(2,dataT{:});
T = onehotencode(T,1);

end
Train using the Adam solver. Initialize the training parameters for Adam.
trailingAvg = [];
trailingAvgSq = [];
To monitor training, create a training progress monitor. To update the progress bar of the training progress monitor, calculate the total number of training iterations.
numObservations = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservations/miniBatchSize);
numIterations = numIterationsPerEpoch * numEpochs;

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");
Train the model using a custom training loop.
epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        [X,T] = next(mbq);

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,parameters,X,T);

        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=(epoch+" of "+numEpochs));
        monitor.Progress = 100 * iteration/numIterations;
    end
end
Test Model
Test the model by evaluating the classification accuracy on the test data set.
load digitsDataTest

XTest = dlarray(XTest,"SSCB");
scoresTest = model(parameters,XTest);
YTest = scores2label(scoresTest,classNames);

acc = mean(labelsTest==YTest')
acc = 0.9882
References
Hu, Xiaobin, Peifeng Niu, Jianmei Wang, and Xinxin Zhang. “A Dynamic Rectified Linear Activation Units.” IEEE Access 7 (2019): 180409–16. https://doi.org/10.1109/ACCESS.2019.2959036.
See Also
dlarray | dlgradient | dlfeval | trainnet | trainingOptions | dlnetwork | functionLayer
Related Topics
- Define Custom Deep Learning Operations
- Specify Custom Operation Backward Function
- Define Custom Deep Learning Layers
- Specify Custom Layer Backward Function
- List of Functions with dlarray Support
- Train Network Using Model Function
- Update Batch Normalization Statistics Using Model Function
- Initialize Learnable Parameters for Model Function