Specify Training Options in Custom Training Loop
For most tasks, you can control the training algorithm details using the trainingOptions and trainnet functions. If the trainingOptions function does not provide the options you need for your task (for example, a custom solver), then you can define your own custom training loop.
To specify the same options as the trainingOptions function, use these examples as a guide:
Training Option | trainingOptions Argument | Example |
---|---|---|
Adam solver | trainingOptions("adam") | Adaptive Moment Estimation (ADAM) |
RMSProp solver | trainingOptions("rmsprop") | Root Mean Square Propagation (RMSProp) |
SGDM solver | trainingOptions("sgdm") | Stochastic Gradient Descent with Momentum (SGDM) |
L-BFGS solver | trainingOptions("lbfgs") | Limited-Memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS) |
Learn rate | InitialLearnRate | Learn Rate |
Learn rate schedule | LearnRateSchedule, LearnRateDropFactor, LearnRateDropPeriod | Piecewise Learn Rate Schedule |
Training progress | Plots | Plots |
Verbose output | Verbose, VerboseFrequency | Verbose Output |
Mini-batch size | MiniBatchSize | Mini-Batch Size |
Number of epochs | MaxEpochs | Number of Epochs |
Validation | ValidationData, ValidationFrequency, ValidationPatience | Validation |
L2 regularization | L2Regularization | L2 Regularization |
Gradient clipping | GradientThreshold, GradientThresholdMethod | Gradient Clipping |
Single CPU or GPU training | ExecutionEnvironment | Single CPU or GPU Training |
Checkpoints | CheckpointPath | Checkpoints |
Solver Options
To specify the solver, use the adamupdate, rmspropupdate, and sgdmupdate functions for the update step in your training loop. To implement your own custom solver, update the learnable parameters using the dlupdate function.
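For example, a basic gradient descent step can be sketched using the dlupdate function. This sketch assumes that the variables net, gradients, and learnRate already exist in your training loop; the anonymous update function is illustrative, not a built-in solver.

```matlab
% Sketch of a custom solver step using dlupdate: plain gradient descent,
% w <- w - learnRate*g, applied to every learnable parameter.
net = dlupdate(@(w,g) w - learnRate*g,net,gradients);
```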
Adaptive Moment Estimation (ADAM)
To update your network parameters using Adam, use the adamupdate function. Specify the gradient decay and the squared gradient decay factors using the corresponding input arguments.
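For example, an Adam update step inside a training loop might look like this sketch. It assumes that net, gradients, iteration, and learnRate exist in your loop and that the state variables averageGrad and averageSqGrad are initialized to [] before training; the decay factor values are illustrative.

```matlab
% Adam update step. gradDecay and sqGradDecay are the gradient decay
% and squared gradient decay factors.
gradDecay = 0.9;
sqGradDecay = 0.999;
[net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
    averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
```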
Root Mean Square Propagation (RMSProp)
To update your network parameters using RMSProp, use the rmspropupdate function. Specify the denominator offset (epsilon) value using the corresponding input argument.
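For example, an RMSProp update step might look like this sketch. It assumes that net, gradients, and learnRate exist in your loop and that averageSqGrad is initialized to [] before training; the option values are illustrative.

```matlab
% RMSProp update step. epsilon is the denominator offset.
sqGradDecay = 0.9;
epsilon = 1e-8;
[net,averageSqGrad] = rmspropupdate(net,gradients,averageSqGrad, ...
    learnRate,sqGradDecay,epsilon);
```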
Stochastic Gradient Descent with Momentum (SGDM)
To update your network parameters using SGDM, use the sgdmupdate function. Specify the momentum using the corresponding input argument.
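For example, an SGDM update step might look like this sketch. It assumes that net, gradients, and learnRate exist in your loop and that the velocity vel is initialized to [] before training; the momentum value is illustrative.

```matlab
% SGDM update step. vel is the parameter velocity maintained across iterations.
momentum = 0.9;
[net,vel] = sgdmupdate(net,gradients,vel,learnRate,momentum);
```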
Limited-Memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS)
To update your network parameters using L-BFGS, use the lbfgsupdate function with an lbfgsState object. Specify the line search method and the maximum number of line search iterations using the corresponding arguments of the lbfgsupdate function, and specify the history size and the initial inverse Hessian factor using the corresponding properties of the lbfgsState object.
To configure early stopping based on the number of iterations, adapt the technique used in Number of Epochs to specify the maximum number of iterations. To stop training early based on the gradient or step differences, adapt the techniques used in Validation to check the gradient and step differences by using the GradientNorm and StepNorm properties, respectively.
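For example, an L-BFGS update step might look like this sketch. It assumes a user-defined model loss function modelLoss and training data X and T; the option values are illustrative.

```matlab
% Before training, create the solver state and a loss function handle
% that returns both the loss and the gradients.
solverState = lbfgsState(HistorySize=10,InitialInverseHessianFactor=1);
lossFcn = @(net) dlfeval(@modelLoss,net,X,T);

% Inside the training loop, apply the L-BFGS update step.
[net,solverState] = lbfgsupdate(net,lossFcn,solverState, ...
    LineSearchMethod="weak-wolfe",MaxNumLineSearchIterations=20);
```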
Learn Rate
To specify the learn rate, use the learn rate input arguments of the adamupdate, rmspropupdate, and sgdmupdate functions.
To easily adjust the learn rate or use it for custom learn rate schedules, set the initial learn rate before the custom training loop.
learnRate = 0.01;
Piecewise Learn Rate Schedule
To automatically drop the learn rate during training using a piecewise learn rate schedule, multiply the learn rate by a given drop factor after a specified interval.
To easily specify a piecewise learn rate schedule, create the variables learnRate, learnRateSchedule, learnRateDropFactor, and learnRateDropPeriod, where learnRate is the initial learn rate, learnRateSchedule contains either "piecewise" or "none", learnRateDropFactor is a scalar in the range [0, 1] that specifies the factor for dropping the learn rate, and learnRateDropPeriod is a positive integer that specifies the number of epochs between drops of the learn rate.
learnRate = 0.01;
learnRateSchedule = "piecewise";
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
Inside the training loop, at the end of each epoch, drop the learn rate when the learnRateSchedule option is "piecewise" and the current epoch number is a multiple of learnRateDropPeriod. Set the new learn rate to the product of the learn rate and the learn rate drop factor.
if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end
Plots
To plot the training loss and accuracy during training, calculate the mini-batch loss and either the accuracy or the root-mean-squared error (RMSE) in the model loss function and plot them using a TrainingProgressMonitor object.
To easily specify whether the plot is on or off, set the Visible property of the TrainingProgressMonitor object. By default, Visible is set to true. When Visible is set to false, the software logs the training metrics and information but does not display the Training Progress window. You can display the Training Progress window after training by changing the Visible property. To also plot validation metrics, use the same validationFrequency as described in Validation.
validationFrequency = 50;
Before training, initialize a TrainingProgressMonitor object. The monitor automatically tracks the elapsed time since the construction of the object. To use this elapsed time as a proxy for training time, make sure you create the TrainingProgressMonitor object close to the start of the training loop.
For classification tasks, create a plot to track the loss and accuracy for the training and validation data. Also track the epoch number and the training progress percentage.
monitor = trainingProgressMonitor;
monitor.Metrics = ["TrainingAccuracy" "ValidationAccuracy" "TrainingLoss" "ValidationLoss"];
groupSubPlot(monitor,"Accuracy",["TrainingAccuracy","ValidationAccuracy"]);
groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"]);
monitor.Info = "Epoch";
monitor.XLabel = "Iteration";
monitor.Progress = 0;
For regression tasks, adjust the code by changing the variable names and labels so that it initializes plots for the training and validation RMSE instead of the training and validation accuracy.
Inside the training loop, at the end of an iteration, use the recordMetrics and updateInfo functions to record the appropriate metrics and information for the training loop. For classification tasks, add points corresponding to the mini-batch accuracy and the mini-batch loss. If the current iteration is either 1 or a multiple of the validation frequency option, then also add points for the validation data.
recordMetrics(monitor,iteration, ...
    TrainingLoss=lossTrain, ...
    TrainingAccuracy=accuracyTrain);
updateInfo(monitor,Epoch=string(epoch) + " of " + string(maxEpochs));

if iteration == 1 || mod(iteration,validationFrequency) == 0
    recordMetrics(monitor,iteration, ...
        ValidationLoss=lossValidation, ...
        ValidationAccuracy=accuracyValidation);
end

monitor.Progress = 100*iteration/numIterations;
accuracyTrain and lossTrain correspond to the mini-batch accuracy and loss calculated in the model loss function. For regression tasks, use the mini-batch RMSE instead of the mini-batch accuracy.
You can stop training using the Stop button in the Training Progress window. When you click Stop, the Stop property of the monitor changes to 1 (true). Training stops if your training loop exits when the Stop property is 1.
while epoch < maxEpochs && ~monitor.Stop
    % Custom training loop code.
end
For more information about plotting and recording metrics during training, see Monitor Custom Training Loop Progress During Training.
To learn how to compute validation metrics, see Validation.
Verbose Output
To display the training loss and accuracy during training in a verbose table, calculate the mini-batch loss and either the accuracy (for classification tasks) or the RMSE (for regression tasks) in the model loss function and display them using the disp function.
To easily specify whether the verbose table is on or off, create the variables verbose and verboseFrequency, where verbose is true or false and verboseFrequency specifies how many iterations between printing verbose output. To display validation metrics, use the same validationFrequency as described in Validation.
verbose = true;
verboseFrequency = 50;
validationFrequency = 50;
Before training, display the verbose output table headings and initialize a timer using the tic function.
fprintf("|======================================================================================================================|\n")
fprintf("| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning |\n")
fprintf("|       |           |  (hh:mm:ss)  |  Accuracy  |  Accuracy  |    Loss    |    Loss    |     Rate      |\n")
fprintf("|======================================================================================================================|\n")
start = tic;
Inside the training loop, at the end of an iteration, print the verbose output when the verbose option is true and it is either the first iteration or the iteration number is a multiple of verboseFrequency.
if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0)
    D = duration(0,0,toc(start),Format="hh:mm:ss");
    fprintf("| %7d | %11d | %14s | %12.4f | %12.4f | %12.4f | %12.4f | %15.4f |\n", ...
        epoch,iteration,D,accuracyTrain,accuracyValidation,lossTrain,lossValidation,learnRate)
end
For regression tasks, adjust the code so that it displays the training and validation RMSE instead of the training and validation accuracy.
When training is finished, print the last border of the verbose table.
fprintf("|======================================================================================================================|\n")
To learn how to compute validation metrics, see Validation.
Mini-Batch Size
How you set the mini-batch size depends on the format of the data and the type of datastore that you use.
To easily specify the mini-batch size, create a variable miniBatchSize.
miniBatchSize = 128;
For data in an image datastore, before training, set the ReadSize property of the datastore to the mini-batch size.
imds.ReadSize = miniBatchSize;
For data in an augmented image datastore, before training, set the MiniBatchSize property of the datastore to the mini-batch size.
augimds.MiniBatchSize = miniBatchSize;
For in-memory data, during training at the start of each iteration, read the observations directly from the array.
idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);
Number of Epochs
Specify the maximum number of epochs for training in the outer loop of the training loop.
To easily specify the maximum number of epochs, create the variable maxEpochs that contains the maximum number of epochs.
maxEpochs = 30;
In the outer loop of the training loop, specify to loop over the range 1, 2, …, maxEpochs.
epoch = 0;
while epoch < maxEpochs
    epoch = epoch + 1;
    ...
end
Validation
To validate your network during training, set aside a held-out validation set and evaluate how well the network performs on that data.
To easily specify validation options, create a variable validationFrequency that specifies how many iterations between validating the network.
validationFrequency = 50;
During the training loop, after updating the network parameters, test how well the network performs on the held-out validation set using the testnet function (since R2024b). Validate the network only when validation data is specified and it is either the first iteration or the current iteration is a multiple of the validationFrequency option.
if iteration == 1 || mod(iteration,validationFrequency) == 0
    metrics = testnet(net,XValidation,TValidation,["crossentropy" "accuracy"], ...
        MiniBatchSize=miniBatchSize);
    lossValidation = metrics(1);
    accuracyValidation = metrics(2);
end
Evaluate different metrics using the testnet function based on your task. For example, for a regression task, adjust the code so that it evaluates only the mean squared error.
For an example showing how to calculate and plot validation metrics during training, see Monitor Custom Training Loop Progress During Training.
Before R2024b: To test how well the network performs on the held-out validation set, use the predict function to make predictions with the validation data, and calculate the loss between the predictions and the target values using a loss function, such as crossentropy or mse.
Early Stopping
To stop training early when the loss on the held-out validation set stops decreasing, use a flag to break out of the training loops.
To easily specify the validation patience (the number of times that the validation loss can be larger than or equal to the previously smallest loss before network training stops), create the variable validationPatience.
validationPatience = 5;
Before training, initialize the variables earlyStop and validationLosses, where earlyStop is a flag to stop training early and validationLosses contains the losses to compare. Initialize the early stopping flag with false and the array of validation losses with inf.
earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end
Inside the training loop, in the loop over mini-batches, add the earlyStop flag to the loop condition.
while hasdata(ds) && ~earlyStop
    ...
end
During the validation step, append the new validation loss to the array validationLosses. If the first element of the array is the smallest, then set the earlyStop flag to true. Otherwise, remove the first element.
if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end
L2 Regularization
To apply L2 regularization to the weights, use the dlupdate function.
To easily specify the L2 regularization factor, create the variable l2Regularization that contains the L2 regularization factor.
l2Regularization = 0.0001;
During training, after computing the model loss and gradients, for each of the weight parameters, add the product of the L2 regularization factor and the weights to the computed gradients using the dlupdate function. To update only the weight parameters, extract the parameters with the name "Weights".
idx = net.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, ...
    gradients(idx,:),net.Learnables(idx,:));
After adding the L2 regularization parameter to the gradients, update the network parameters.
Gradient Clipping
To clip the gradients, use the dlupdate function.
To easily specify gradient clipping options, create the variables gradientThresholdMethod and gradientThreshold, where gradientThresholdMethod contains "global-l2norm", "l2norm", or "absolute-value", and gradientThreshold is a positive scalar containing the threshold or inf.
gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;
Create functions named thresholdGlobalL2Norm, thresholdL2Norm, and thresholdAbsoluteValue that apply the "global-l2norm", "l2norm", and "absolute-value" threshold methods, respectively.
For the "global-l2norm" option, the function operates on all gradients of the model.
function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end
For the "l2norm" and "absolute-value" options, the functions operate on each gradient independently.
function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end
During training, after computing the model loss and gradients, apply the appropriate gradient clipping method to the gradients using the dlupdate function. Because the "global-l2norm" option requires all the gradient values, apply the thresholdGlobalL2Norm function directly to the gradients. For the "l2norm" and "absolute-value" options, update the gradients independently using the dlupdate function.
switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients,gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g,gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g,gradientThreshold),gradients);
end
After applying the gradient threshold operation, update the network parameters.
Single CPU or GPU Training
By default, the software performs calculations using only the CPU. To train on a single GPU, convert the data to gpuArray objects. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
To easily specify the execution environment, create the variable executionEnvironment that contains "cpu", "gpu", or "auto".
executionEnvironment = "auto";
During training, after reading a mini-batch, check the execution environment option and convert the data to a gpuArray object if necessary. The canUseGPU function checks for usable GPUs.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    X = gpuArray(X);
end
Checkpoints
To save checkpoint networks during training, save the network using the save function.
To easily specify whether checkpoints are on or off, create the variable checkpointPath that contains the folder for the checkpoint networks or is empty.
checkpointPath = fullfile(tempdir,"checkpoints");
If the checkpoint folder does not exist, then before training, create the checkpoint folder.
if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end
During training, at the end of an epoch, save the network in a MAT file. Specify a file name containing the current iteration number, date, and time.
if ~isempty(checkpointPath)
    D = string(datetime("now",Format="yyyy_MM_dd__HH_mm_ss"));
    filename = "net_checkpoint__" + iteration + "__" + D + ".mat";
    save(fullfile(checkpointPath,filename),"net")
end
net is the dlnetwork object to be saved.
See Also
adamupdate | rmspropupdate | sgdmupdate | dlupdate | dlarray | dlgradient | dlfeval | dlnetwork
Related Topics
- Define Custom Training Loops, Loss Functions, and Networks
- Define Model Loss Function for Custom Training Loop
- Train Network Using Custom Training Loop
- Train Network Using Model Function
- Make Predictions Using dlnetwork Object
- Make Predictions Using Model Function
- Initialize Learnable Parameters for Model Function
- Update Batch Normalization Statistics in Custom Training Loop
- Update Batch Normalization Statistics Using Model Function
- Train Generative Adversarial Network (GAN)
- List of Functions with dlarray Support