Custom Loss Functions
Training a deep neural network is an optimization task. By considering a deep learning model as a function f(X;θ), where X is the model input and θ is the set of learnable parameters, you can optimize θ so that it minimizes some loss value based on the training data. You typically optimize the learnable parameters θ such that, for a given input X with corresponding targets T, they minimize the error between the predictions Y=f(X;θ) and T. For example, for classification and regression tasks, you can use cross-entropy and mean squared error (MSE) loss, respectively.
The trainnet function provides several built-in loss functions to use for training. You can use cross-entropy loss for classification and mean squared error loss for regression by specifying "crossentropy" and "mse" as the lossFcn argument, respectively. For example, to train a neural network using the trainnet function with cross-entropy loss, use
net = trainnet(X,T,layers,"crossentropy",options);
If the trainnet function does not provide the loss function that you
need for your task, then you can define a custom loss function using one of these options:
need for your task, then you can define a custom loss function using one of these options:
Specify a loss function as a function handle — Use this option when the loss function depends on the targets, neural network predictions, and optionally additional information from external sources. For example, you can specify a custom loss function that uses a weighted sum of neural network outputs. For more information, see Define Custom Loss Function for trainnet Function.
Specify a loss function using a custom training loop — Use this option when the loss function requires additional information from the training algorithm. For example, you can define a loss function for physics-informed neural networks (PINNs) that evaluates second-order derivatives. For more information, see Define Custom Loss Function for Custom Training Loops.
Define Custom Loss Function for trainnet Function
To define a custom loss function for the trainnet function, specify the loss function as a function handle that has the syntax loss = f(Y1,...,Yn,T1,...,Tm), where Y1,...,Yn are dlarray objects that correspond to the n network predictions and T1,...,Tm are dlarray objects that correspond to the m targets. The function must support automatic differentiation using dlarray objects.
The loss value that the loss function outputs must be a scalar. Most loss operations are element-wise, so when you implement the loss function, you must define how to reduce an array of loss values to a scalar value. Many loss functions reduce the values to a scalar by taking the sum. Depending on the magnitude and size of the data that the loss function processes, the function can scale the computed loss using a scaling factor based on the number of elements in the data (for example, the number of observations, time steps, or pixels).
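For illustration, here is a minimal sketch of such a loss function. It reduces the element-wise errors to a scalar by summing, then scales by the number of observations. The function name scaledL1Loss and the assumption that observations lie along the second dimension are hypothetical, not part of any toolbox API.

```matlab
function loss = scaledL1Loss(Y,T)
% Element-wise absolute error between predictions and targets.
err = abs(Y - T);

% Reduce to a scalar by summing over all elements, then scale by
% the number of observations (assumed here to lie along the
% second dimension of Y).
numObservations = size(Y,2);
loss = sum(err,"all") / numObservations;
end
```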
For named functions that already have this syntax, you can specify the function
directly. For example, if you have a function named odeLoss that has
the required syntax, then you can specify the loss function directly using the
syntax:
net = trainnet(X,T,layers,@odeLoss,options);
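For illustration, such a named function might look like this sketch, where the body is only a placeholder for your own loss computation:

```matlab
function loss = odeLoss(Y,T)
% Placeholder body illustrating the required syntax: dlarray
% predictions and targets in, scalar loss out. Replace this
% computation with the loss for your task.
loss = l2loss(Y,T);
end
```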
You can also specify the loss function as an anonymous function. This approach can be easier when you want to independently combine losses computed for different network outputs. For example,
lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + l2loss(Y2,T2);
net = trainnet(X,T1,T2,net,lossFcn,options);
For loss functions that depend on external data, such as a fixed set of class weights, you can parameterize the function handle so that it only depends on the network predictions and the targets. For example, to specify a loss function for weighted cross-entropy, you can use:
lossFcn = @(Y,T) crossentropy(Y,T,Weights=weights);
net = trainnet(X,T,layers,lossFcn,options);
Here, the loss function is an anonymous function with the inputs Y
and T only.
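Note that an anonymous function captures the value of weights at the time you create the handle, so the external data stays fixed for the rest of training. A short sketch (the weight values here are hypothetical):

```matlab
% Hypothetical class weights captured by the anonymous function.
weights = [0.7 0.3];
lossFcn = @(Y,T) crossentropy(Y,T,Weights=weights);

% Changing weights afterward does not affect lossFcn; the handle
% keeps the values it captured when it was created.
weights = [0.5 0.5];
```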
If you store the loss function handle in a variable, then you can reuse the function to check loss values for debugging purposes. For example, to evaluate lossFcn on a batch of test predictions and targets, you can use
lossTest = lossFcn(YTest,TTest);
To speed up training, you can accelerate your custom loss function using the dlaccelerate function. You can then use the accelerated loss function
with the trainnet function. For
example,
accLossFcn = dlaccelerate(lossFcn);
net = trainnet(X,T,layers,accLossFcn,options);
Not all deep learning functions fully support acceleration. For more information, see Deep Learning Function Acceleration.
Define Custom Loss Function for Custom Training Loops
For loss functions that require more inputs than just the predictions and targets (for example, loss functions that require access to the neural network or additional inputs), train the model using a custom training loop. For an example, see Train Network Using Custom Training Loop.
To train a deep learning model using a custom training loop, create a loss function that takes the neural network or model parameters and the training data as input, and returns the loss. Most deep learning training algorithms use the gradients of the loss with respect to the learnable parameters to perform the update steps, so you can include the gradients as an output too.
To learn more about defining model loss functions for custom training loops, see Custom Training Loop Model Loss Functions.
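For illustration, a minimal sketch of such a training loop, assuming a dlnetwork object net, training data X and T as dlarray objects, a model loss function modelLoss as described in the next sections, and a hypothetical iteration count numIterations:

```matlab
avgGrad = [];
avgSqGrad = [];

for iteration = 1:numIterations
    % Evaluate the model loss and gradients using dlfeval, which
    % enables automatic differentiation inside modelLoss.
    [loss,gradients] = dlfeval(@modelLoss,net,X,T);

    % Update the learnable parameters using the Adam update step.
    [net,avgGrad,avgSqGrad] = adamupdate(net,gradients, ...
        avgGrad,avgSqGrad,iteration);
end
```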
Create Model Loss Function for Model Defined as dlnetwork Object
For a model specified as a dlnetwork object, create a function of the form
[loss,gradients] = modelLoss(net,X,T), where net
is the network, X is the network input, T contains the
targets, and loss and gradients are the returned loss
and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated network state).
For example, this function returns the cross-entropy loss and the gradients of the loss with respect to the learnable parameters in the specified dlnetwork object net, given input data X, and targets T.
function [loss,gradients] = modelLoss(net,X,T)

% Forward data through the dlnetwork object.
Y = forward(net,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,net.Learnables);

end
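If the network contains layers that maintain state, such as batch normalization layers, you can also return the updated network state. For example, a sketch that extends the function above:

```matlab
function [loss,gradients,state] = modelLoss(net,X,T)
% Forward data through the dlnetwork object, also returning the
% updated network state.
[Y,state] = forward(net,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,net.Learnables);
end
```

In the training loop, apply the returned state after each iteration using net.State = state;.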
For an example showing how to train a neural network using a custom training loop, see Train Network Using Custom Training Loop.
To speed up training, you can accelerate your custom loss function using the dlaccelerate
function. For
example,
accLossFcn = dlaccelerate(@modelLoss);
Not all deep learning functions fully support acceleration. For more information, see Deep Learning Function Acceleration.
Create Model Loss Function for Model Defined as Function
For a model specified as a function, create a function of the form [loss,gradients] =
modelLoss(parameters,X,T), where parameters contains the
learnable parameters, X is the model input, T contains
the targets, and loss and gradients are the returned
loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated model state).
For example, to compute the model loss and gradients for a model specified by the function
model and learnable parameters parameters,
use:
function [loss,gradients,state] = modelLoss(parameters,X,T)

% Forward data through the model function.
[Y,state] = model(parameters,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,parameters);

end
For an example showing how to train a deep learning model defined as a function using a custom training loop, see Train Network Using Model Function.
For more information, see Custom Training Loop Model Loss Functions.
To speed up training, you can accelerate your custom loss function using the dlaccelerate
function. For
example,
accLossFcn = dlaccelerate(@modelLoss);
Not all deep learning functions fully support acceleration. For more information, see Deep Learning Function Acceleration.
Functions for Building Custom Loss Functions
To help create a custom loss function, you can use the deep learning functions in this
table. You can also pass these functions to the trainnet function
directly as a function handle.
| Function | Description |
|---|---|
| softmax | The softmax activation operation applies the softmax function to the channel dimension of the input data. |
| sigmoid | The sigmoid activation operation applies the sigmoid function to the input data. |
| crossentropy | The cross-entropy operation computes the cross-entropy loss between network predictions and binary or one-hot encoded targets for single-label and multi-label classification tasks. |
| indexcrossentropy | The index cross-entropy operation computes the cross-entropy loss between network predictions and targets specified as integer class indices for single-label classification tasks. |
| l1loss | The L1 loss operation computes the L1 loss given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean absolute error (MAE). |
| l2loss | The L2 loss operation computes the L2 loss (based on the squared L2 norm) given network predictions and target values. When the Reduction option is "sum" and the NormalizationFactor option is "batch-size", the computed value is known as the mean squared error (MSE). |
| huber | The Huber operation computes the Huber loss between network predictions and target values for regression tasks. When the TransitionPoint option is 1, this is also known as smooth L1 loss. |
| ctc | The CTC operation computes the connectionist temporal classification (CTC) loss between unaligned sequences. |
| mse | The half mean squared error operation computes the half mean squared error loss between network predictions and target values for regression tasks. |
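For example, to train a regression network using the L1 loss from the table directly, you can pass the function handle to trainnet:

```matlab
net = trainnet(X,T,layers,@l1loss,options);
```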
See Also
trainnet | trainingOptions | dlnetwork | crossentropy | indexcrossentropy | l1loss | l2loss