focalCrossEntropy
Syntax
dlY = focalCrossEntropy(dlX,targets)
dlY = focalCrossEntropy(___,Name=Value)
Description
dlY = focalCrossEntropy(dlX,targets) computes the focal cross-entropy loss between network predictions and target values for single-label and multi-label classification tasks. In single-label tasks, the classes are mutually exclusive. The focal cross-entropy loss places more weight on poorly classified training samples and down-weights well-classified samples. The focal cross-entropy loss is computed as the average logarithmic loss divided by the number of non-zero targets.
dlY = focalCrossEntropy(___,Name=Value) specifies options using one or more name-value arguments in addition to the input arguments in previous syntaxes. For example, ClassificationMode="multilabel" computes the focal cross-entropy loss for a multi-label classification task.
Examples
Compute Focal Cross-Entropy Loss Using Formatted dlarray
Create the input classification data as 32 observations of random variables belonging to 10 classes or categories.
numCategories = 10;
observations = 32;
X = rand(numCategories,observations);
Create a formatted deep learning array that has a data format with the labels 'C' and 'B'.
dlX = dlarray(X,'CB');
Use the softmax function to set all values in the input data to values between 0 and 1 that sum to 1 over all channels. The values specify the probability that each observation belongs to a particular category.
dlX = softmax(dlX);
Create the target data as an unformatted deep learning array that holds the correct category for each observation in dlX. Assign every observation to the second category by using one-hot encoded target vectors.
targets = dlarray(zeros(numCategories,observations));
targets(2,:) = 1;
Compute the focal cross-entropy loss between each prediction and its target. Because Reduction is 'none', dlY contains one loss value for each observation.
dlY = focalCrossEntropy(dlX,targets,'Reduction','none');
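To see what each entry of dlY represents, the following sketch recomputes a per-observation focal loss on plain numeric arrays, without dlarray. It assumes the standard single-label focal loss form, -Alpha*(1 - p)^Gamma*log(p) for the target category, with the default Alpha = 0.25 and Gamma = 2. The helper focalPerObs and the small Y and T matrices are hypothetical, and the values are not guaranteed to match focalCrossEntropy exactly.

```matlab
% Sketch: per-observation single-label focal loss on plain numeric arrays.
% Assumes the standard focal loss form; not the toolbox implementation.
alpha = 0.25;   % balancing parameter (default Alpha)
gamma = 2;      % focusing parameter (default Gamma)

% Hypothetical helper: sums the focal loss over the channel (first) dimension.
% max(Y,eps) guards against log(0) when a probability is exactly zero.
focalPerObs = @(Y,T) -sum(alpha .* (1 - Y).^gamma .* T .* log(max(Y,eps)), 1);

% Two observations (columns), three categories (rows).
Y = [0.7 0.2;
     0.2 0.5;
     0.1 0.3];
T = [1 0;
     0 1;
     0 0];

lossPerObs = focalPerObs(Y,T)   % 1-by-2 row vector, one loss per observation
```

The first observation is predicted confidently (p = 0.7), so its loss is several times smaller than that of the second observation (p = 0.5).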
Compute Average Focal Cross-Entropy Loss Using Unformatted dlarray
Create the input classification data as 32 observations of random variables belonging to 10 classes or categories.
numCategories = 10;
observations = 32;
X = rand(numCategories,observations);
Create an unformatted deep learning array.
dlX = dlarray(X);
Use the softmax function to set all values in the input data to values between 0 and 1 that sum to 1 over all channels. The values specify the probability that each observation belongs to a particular category.
dlX = softmax(dlX,'DataFormat','CB');
Create the target data. Assign every observation to the second category by using one-hot encoded target vectors.
targets = zeros(numCategories,observations);
targets(2,:) = 1;
Compute the average focal cross-entropy loss between the predictions and the targets.
dlY = focalCrossEntropy(dlX,targets,'DataFormat','CB')
dlY = 
  1x1 dlarray

    0.4769
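As a rough sanity check on the magnitude of this value: with 10 categories and inputs drawn from rand, the softmax outputs are all close to 1/10. Assuming the standard single-label focal loss form with the default Alpha = 0.25 and Gamma = 2, a target-category probability of exactly p = 0.1 gives a per-observation loss of about 0.466, in the same range as the result above. Because X is random, the exact output varies from run to run.

$$
-\alpha (1 - p)^{\gamma} \log p = -0.25 \times (1 - 0.1)^{2} \times \log(0.1) \approx 0.25 \times 0.81 \times 2.3026 \approx 0.466
$$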
Compute Average Focal Cross-Entropy Loss for Multi-Label Classification
Create the input classification data as 32 observations of random variables belonging to 10 classes or categories.
numCategories = 10;
observations = 32;
X = rand(numCategories,observations);
Create a formatted deep learning array that has a data format with the labels 'C' and 'B'.
dlX = dlarray(X,'CB');
Use the sigmoid function to map each value in the input data to a value between 0 and 1. Unlike softmax, sigmoid operates on each value independently, so the values need not sum to 1 over the channels. The values specify the probability that each observation belongs to each category.
dlX = sigmoid(dlX);
Create the target data, which holds the correct categories for each observation in dlX. Assign every observation to both the second and the sixth category, so each target vector has two nonzero entries (multi-hot encoding).
targets = zeros(numCategories,observations);
targets(2,:) = 1;
targets(6,:) = 1;
Compute the average focal cross-entropy loss between the predictions and the targets. Set the 'ClassificationMode' value to 'multilabel' for multi-label classification.
dlY = focalCrossEntropy(dlX,targets,'ClassificationMode','multilabel')
dlY = 
  1x1 dlarray

    2.4362
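A common way to write a multi-label (independent-category) focal loss treats each channel as a separate binary problem, weighting positive targets by Alpha and negative targets by 1 - Alpha. The sketch below assumes that form on plain numeric arrays; the helpers logc and focalMultiPerObs are hypothetical, and the toolbox implementation may differ in detail.

```matlab
% Sketch: per-observation multi-label focal loss on plain numeric arrays.
% Assumes the common per-channel binary focal loss form (not confirmed
% as the focalCrossEntropy implementation).
alpha = 0.25;
gamma = 2;
logc = @(Z) log(max(Z,eps));   % clamp to avoid log(0)

% Positive targets are weighted by alpha, negative targets by (1 - alpha).
focalMultiPerObs = @(Y,T) -sum( ...
    alpha       .* (1 - Y).^gamma .* T       .* logc(Y) + ...
    (1 - alpha) .* Y.^gamma       .* (1 - T) .* logc(1 - Y), 1);

% One observation, three independent categories (sigmoid outputs).
Y = [0.9; 0.2; 0.6];
T = [1;   0;   1];

loss = focalMultiPerObs(Y,T)
```

Under this assumed form, categories whose target is 0 also contribute to the loss through the (1 - Alpha) term.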
Input Arguments
dlX — Predictions
dlarray | numeric array
Predictions, specified as a dlarray with or without dimension labels or a numeric array. When dlX is not a formatted dlarray, you must specify the dimension format using the DataFormat name-value argument. If dlX is a numeric array, targets must be a dlarray.
Data Types: single | double
targets — Target classification labels
dlarray | numeric array
Target classification labels, specified as a formatted or unformatted dlarray or a numeric array.
If targets is a formatted dlarray, its dimension format must be the same as the format of dlX, or the same as the DataFormat name-value argument if dlX is unformatted.
If targets is an unformatted dlarray or a numeric array, the size of targets must exactly match the size of dlX. The format of dlX or the value of DataFormat is implicitly applied to targets.
Data Types: single | double
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Example: ClassificationMode="multilabel",DataFormat="CB" evaluates the focal cross-entropy loss for multi-label classification tasks and specifies the dimension order of the input data as "CB".
Before R2021a, use commas to separate each name and value, and enclose Name in quotes.
Example: "ClassificationMode","multilabel","DataFormat","CB" evaluates the focal cross-entropy loss for multi-label classification tasks and specifies the dimension order of the input data as "CB".
Gamma — Focusing parameter
2 (default) | positive real number
Focusing parameter of the focal loss function, specified as a positive real number. Increasing the value of Gamma increases the sensitivity of the network to misclassified observations.
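In the standard focal loss form (assumed here, not confirmed as the toolbox implementation), Gamma enters through the modulating factor (1 - p)^Gamma, where p is the predicted probability of the true category. The sketch compares how much of the plain cross-entropy term survives for a well-classified observation (p = 0.9) and a poorly classified one (p = 0.1) as Gamma grows; the variable names are illustrative.

```matlab
% Sketch: fraction of the plain cross-entropy kept by the modulating factor.
% Assumes the standard focal loss form; values are illustrative only.
gammas = [0 1 2 5];
pWell = 0.9;   % well-classified observation
pPoor = 0.1;   % poorly classified observation

keepWell = (1 - pWell).^gammas;   % approximately 1, 0.1, 0.01, 1e-5
keepPoor = (1 - pPoor).^gammas;   % approximately 1, 0.9, 0.81, 0.59049

% Relative emphasis on the poorly classified observation grows with Gamma.
emphasis = keepPoor ./ keepWell
```

With Gamma = 0 the factor is 1 for every observation, so the weighting reduces to a scaled cross-entropy; each increase in Gamma shifts more of the loss toward the poorly classified observation.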
Alpha — Balancing parameter
0.25 (default) | positive real number
Balancing parameter of the focal loss function, specified as a positive real number. The Alpha value scales the loss function linearly and is typically set to 0.25. If you decrease Alpha, increase Gamma.
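Assuming the standard single-label focal loss form, -Alpha*(1 - p)^Gamma*log(p), Alpha multiplies every loss term, so changing it rescales the loss without changing the relative weight of individual observations. A minimal check on plain numeric values; the names focal and p are illustrative:

```matlab
% Sketch: Alpha scales the single-label focal loss linearly (assumed form).
gamma = 2;
p = [0.9 0.6 0.1];   % predicted probability of the true category

focal = @(alpha) -alpha .* (1 - p).^gamma .* log(p);

lossDefault = focal(0.25);
lossHalf    = focal(0.125);

ratio = lossHalf ./ lossDefault   % every element is 0.5
```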
Reduction — Type of output loss
"mean" (default) | "none"
Type of output loss, specified as one of the following:
- "mean" — Average of the output loss over all predictions. The function computes the average of the loss values computed for each prediction in input dlX and returns the average loss as an unformatted dlarray. Observations with all-zero target values along the channel dimension are excluded from the average.
- "none" — Output loss for each prediction. The function returns the loss value for each observation in dlX. Observations whose target values are all zero along the channel dimension are included in the output. If dlX is a formatted dlarray, the output dlY is a formatted dlarray with the same dimension labels as dlX. If dlX is an unformatted dlarray, the output dlY is an unformatted dlarray.
Data Types: char | string
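The following sketch illustrates the difference between the two reductions on plain numeric arrays, including one observation whose targets are all zero. It assumes the standard single-label focal loss form, -Alpha*(1 - p)^Gamma*log(p) summed over the channel dimension, with the default Alpha and Gamma, and it averages over the observations that have at least one nonzero target, following the description of "mean" above. The variable names are illustrative.

```matlab
% Sketch: "none" vs. "mean" reduction (assumed single-label focal loss form).
alpha = 0.25;
gamma = 2;
Y = [0.7 0.2 0.5;
     0.3 0.8 0.5];
T = [1 0 0;
     0 1 0];   % third observation has all-zero targets

% "none": one loss value per observation; the all-zero observation gives 0.
lossNone = -sum(alpha .* (1 - Y).^gamma .* T .* log(max(Y,eps)), 1)

% "mean": average only over observations with at least one nonzero target.
hasTarget = any(T ~= 0, 1);
lossMean = sum(lossNone(hasTarget)) / nnz(hasTarget)
```

Dividing by nnz(hasTarget) rather than by the total number of observations keeps observations without targets from pulling the average toward zero.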
ClassificationMode — Type of classification task
"single-label" (default) | "multilabel"
Type of classification task, specified as one of these values:
- "single-label" — Each observation in the predictions dlX is exclusively assigned to one category (single-label classification).
- "multilabel" — Each observation in the predictions dlX can be assigned to one or more independent categories (multilabel classification).
Note
To select the classification mode for binary classification, consider the output layer of the network:
- If the final layer has an output size of one, such as with a sigmoid layer, use "multilabel".
- If the final layer has an output size of two, such as with a softmax layer, use "single-label".
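For binary classification, the size of targets must match the size of dlX, so the target layout follows the choice of output layer. The sketch below uses a hypothetical vector t of class indices (0 or 1) to build targets for each case. Placing class 1 in the second row of the one-hot targets is an assumption; it must match the channel order of your network's softmax output.

```matlab
% Sketch: laying out binary targets for each classification mode.
% t holds the class (0 or 1) of each observation; names are illustrative.
t = [1 0 0 1];

% Final layer with output size one (sigmoid): 1-by-N targets, "multilabel".
targetsMultilabel = t;

% Final layer with output size two (softmax): 2-by-N one-hot targets,
% "single-label". Row 1 marks class 0 and row 2 marks class 1.
targetsSingleLabel = [1 - t; t];
```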
DataFormat — Dimension order of unformatted data
char vector | string
Dimension order of unformatted input data, specified as a character vector or string FMT that provides a label for each dimension of the data.
Each character in FMT must be one of the following:
- "S" — Spatial
- "C" — Channel
- "B" — Batch (for example, samples and observations)
- "T" — Time (for example, sequences)
- "U" — Unspecified
You can specify multiple dimensions labeled "S" or "U". You can use the labels "C", "B", and "T" at most once.
You must specify DataFormat when the input data dlX is not a formatted dlarray.
Example: DataFormat="SSCB"
Data Types: char | string
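The labeling rules above can be expressed as a short check. The sketch below is a hypothetical helper, not a toolbox function: it verifies that every character is one of "S", "C", "B", "T", or "U", and that "C", "B", and "T" each appear at most once.

```matlab
% Hypothetical helper: checks a dimension format against the rules above.
% Not a toolbox function; shown only to make the rules concrete.
isValidFormat = @(fmt) all(ismember(char(fmt), 'SCBTU')) && ...
    sum(char(fmt) == 'C') <= 1 && ...
    sum(char(fmt) == 'B') <= 1 && ...
    sum(char(fmt) == 'T') <= 1;

isValidFormat("SSCB")   % true: repeated "S" is allowed
isValidFormat("CCB")    % false: "C" appears twice
isValidFormat("SSXB")   % false: "X" is not a valid label
```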
Output Arguments
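dlY — Focal cross-entropy loss
dlarray scalar
Focal cross-entropy loss, returned as a dlarray scalar without dimension labels. The output dlY has the same underlying data type as the input dlX. When Reduction is "none", dlY instead contains the loss for each observation, as described for the Reduction argument.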
Version History
Introduced in R2020b
R2023b: TargetCategories is not recommended
TargetCategories is not recommended. Use ClassificationMode instead. To update your code, replace all instances of TargetCategories="exclusive" with ClassificationMode="single-label" and all instances of TargetCategories="independent" with ClassificationMode="multilabel". There are no differences between the arguments that require additional updates to your code. The default behavior of the focalCrossEntropy function remains the same.
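Because the update is a one-to-one replacement of names and values, it can be applied mechanically. The sketch below shows a hypothetical helper, migrateLine, that rewrites a line of code held as text by using strrep. It handles only the Name=Value spellings listed above, so review other spellings, such as the comma-separated "TargetCategories","exclusive" form, by hand.

```matlab
% Hypothetical migration helper: rewrites TargetCategories name-value pairs.
% Only the Name=Value spellings from the release note are handled.
migrateLine = @(s) strrep(strrep(s, ...
    'TargetCategories="exclusive"',   'ClassificationMode="single-label"'), ...
    'TargetCategories="independent"', 'ClassificationMode="multilabel"');

oldCode = 'dlY = focalCrossEntropy(dlX,targets,TargetCategories="independent");';
newCode = migrateLine(oldCode)
```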
See Also
softmax (Deep Learning Toolbox) | sigmoid (Deep Learning Toolbox) | crossentropy (Deep Learning Toolbox) | mse (Deep Learning Toolbox)
Topics
- Lidar 3-D Object Detection Using PointPillars Deep Learning (Lidar Toolbox)