
focalCrossEntropy

Compute focal cross-entropy loss

Since R2020b

Description

dlY = focalCrossEntropy(dlX,targets) computes the focal cross-entropy loss between network predictions and target values for single-label and multi-label classification tasks. In single-label classification, the classes are mutually exclusive. The focal cross-entropy loss gives more weight to poorly classified training samples and down-weights well-classified samples. By default, the loss is averaged over the observations whose target values are not all zero.
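For reference, the focal loss introduced by Lin et al. for dense object detection can be written as shown below. This is a sketch, not a statement of MATLAB's internal algorithm. It assumes that Y_{ni} is the predicted probability of class i for observation n, T_{ni} is the corresponding target, K is the number of classes, N is the number of observations whose targets are not all zero, α is Alpha, and γ is Gamma. With α = 0.25 and γ = 2, these expressions are consistent with the example outputs on this page, but treat the exact normalization as an assumption.

```latex
% Single-label classification (mutually exclusive classes)
\mathrm{loss} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{i=1}^{K}
  \alpha\, T_{ni}\,(1-Y_{ni})^{\gamma}\,\ln Y_{ni}

% Multilabel classification (independent classes)
\mathrm{loss} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{i=1}^{K}
  \Big[\alpha\, T_{ni}\,(1-Y_{ni})^{\gamma}\ln Y_{ni}
     + (1-\alpha)\,(1-T_{ni})\,Y_{ni}^{\gamma}\ln(1-Y_{ni})\Big]
```

The factor (1-Y_{ni})^{\gamma} approaches 0 as the predicted probability of the true class approaches 1, which is how well-classified samples are down-weighted.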


dlY = focalCrossEntropy(dlX,targets,"DataFormat",FMT) also specifies the dimension format FMT when dlX is not a formatted dlarray.


dlY = focalCrossEntropy(___,Name=Value) specifies options using one or more name-value arguments in addition to the input arguments in previous syntaxes. For example, ClassificationMode="multilabel" computes the focal cross-entropy loss for a multi-label classification task.


Examples


Create the input classification data as 32 observations of random variables belonging to 10 classes or categories.

numCategories = 10;
observations = 32;
X = rand(numCategories,observations);

Create a formatted deep learning array that has a data format with the labels 'C' and 'B'.

dlX = dlarray(X,'CB');

Use the softmax function to map the input data to values between 0 and 1 that sum to 1 over all channels. Each value specifies the probability that an observation belongs to a particular category.

dlX = softmax(dlX);

Create the target data as an unformatted deep learning array that holds the correct category for each observation in dlX. Encode the targets as one-hot vectors so that every observation belongs to the second category.

targets = dlarray(zeros(numCategories,observations));
targets(2,:) = 1;

Compute the focal cross-entropy loss between each prediction and its target. Setting 'Reduction' to 'none' returns the loss for each prediction instead of the average.

dlY = focalCrossEntropy(dlX,targets,'Reduction','none');
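As a sanity check on the "none" and "mean" reductions, the sketch below recomputes the average from the per-prediction losses. It assumes, as described for the Reduction argument, that "mean" divides the total loss by the number of observations with at least one nonzero target; in this example, that is every observation. The sketch rebuilds the same kind of data so that it runs on its own.

```matlab
% Rebuild single-label data: 10 categories, 32 observations, true class 2.
numCategories = 10;
observations = 32;
dlX = softmax(dlarray(rand(numCategories,observations),'CB'));
targets = dlarray(zeros(numCategories,observations));
targets(2,:) = 1;

% Per-prediction losses and the default averaged loss.
lossNone = focalCrossEntropy(dlX,targets,'Reduction','none');
lossMean = focalCrossEntropy(dlX,targets);

% Assumption: "mean" = total loss / number of observations with a nonzero
% target. Every observation here has one, so divide by observations.
manualMean = sum(extractdata(lossNone),'all')/observations;

% Expected to be at round-off level if the assumption holds.
difference = abs(manualMean - extractdata(lossMean))
```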

Create the input classification data as 32 observations of random variables belonging to 10 classes or categories.

numCategories = 10;
observations = 32;
X = rand(numCategories,observations);

Create an unformatted deep learning array.

dlX = dlarray(X);

Use the softmax function to map the input data to values between 0 and 1 that sum to 1 over all channels. Each value specifies the probability that an observation belongs to a particular category.

dlX = softmax(dlX,'DataFormat','CB');

Create the target data as a numeric array. Encode the targets as one-hot vectors so that every observation belongs to the second category.

targets = zeros(numCategories,observations);
targets(2,:) = 1;

Compute the average focal cross-entropy loss between the predictions and the targets. Because dlX is unformatted, specify the data format using the 'DataFormat' name-value argument.

dlY = focalCrossEntropy(dlX,targets,'DataFormat','CB')
dlY = 
  1x1 dlarray

    0.4769
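The DataFormat argument only describes the layout of an unformatted array, so attaching the same "CB" format to the dlarray itself should give the same loss. The sketch below compares both approaches on identical data.

```matlab
% Same kind of data as above: 10 categories, 32 observations, true class 2.
numCategories = 10;
observations = 32;
X = rand(numCategories,observations);
targets = zeros(numCategories,observations);
targets(2,:) = 1;

% Option 1: unformatted dlarray plus the DataFormat name-value argument.
dlXUnformatted = softmax(dlarray(X),'DataFormat','CB');
lossUnformatted = focalCrossEntropy(dlXUnformatted,targets,'DataFormat','CB');

% Option 2: attach the "CB" format to the dlarray itself.
dlXFormatted = softmax(dlarray(X,'CB'));
lossFormatted = focalCrossEntropy(dlXFormatted,targets);

% Both calls describe the same data layout, so the losses should agree.
difference = abs(extractdata(lossUnformatted) - extractdata(lossFormatted))
```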

Create the input classification data as 32 observations of random variables belonging to 10 classes or categories.

numCategories = 10;
observations = 32;
X = rand(numCategories,observations);

Create a formatted deep learning array that has a data format with the labels 'C' and 'B'.

dlX = dlarray(X,'CB');

Use the sigmoid function to map all values in the input data to values between 0 and 1. Unlike softmax, the sigmoid outputs do not sum to 1 over the channels; each value independently specifies the probability that the observation belongs to the corresponding category.

dlX = sigmoid(dlX);
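To make the difference between the two activations concrete, the following sketch compares the column sums of softmax and sigmoid outputs on the same uniform random data. Because rand returns values in the open interval (0,1), every sigmoid output is greater than 0.5, so each column of 10 sigmoid outputs sums to more than 5.

```matlab
% Uniform random scores: 10 categories, 32 observations.
X = rand(10,32);

% Softmax couples the channels; sigmoid treats each channel independently.
pSoftmax = extractdata(softmax(dlarray(X,'CB')));
pSigmoid = extractdata(sigmoid(dlarray(X,'CB')));

% Sum the probabilities over the channel (first) dimension.
softmaxSums = sum(pSoftmax,1)   % every entry equals 1 up to round-off
sigmoidSums = sum(pSigmoid,1)   % every entry exceeds 5 for this data
```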

Create the target data, which holds the correct categories for each observation in dlX. Encode the targets as multi-hot vectors so that every observation belongs to both the second and the sixth categories.

targets = zeros(numCategories,observations);
targets(2,:) = 1;
targets(6,:) = 1;

Compute the average focal cross-entropy loss between the predictions and the targets. For multi-label classification, set the 'ClassificationMode' value to 'multilabel'.

dlY = focalCrossEntropy(dlX,targets,'ClassificationMode','multilabel')
dlY = 
  1x1 dlarray

    2.4362
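To see where a value like this comes from, the sketch below recomputes the multilabel loss by hand. It assumes the standard multilabel focal loss of Lin et al. with Alpha = 0.25 and Gamma = 2, and that the average is taken over the observations. Both are assumptions checked against the function output here, not documented behavior.

```matlab
% Multilabel data: 10 categories, 32 observations, classes 2 and 6 present.
numCategories = 10;
observations = 32;
dlX = sigmoid(dlarray(rand(numCategories,observations),'CB'));
targets = zeros(numCategories,observations);
targets([2 6],:) = 1;
dlY = focalCrossEntropy(dlX,targets,'ClassificationMode','multilabel');

% Manual computation under the assumed formula.
alpha = 0.25;   % assumed Alpha value
gamma = 2;      % assumed Gamma value
Y = extractdata(dlX);
T = targets;
positiveTerm = -alpha*T.*(1-Y).^gamma.*log(Y);            % true categories
negativeTerm = -(1-alpha)*(1-T).*Y.^gamma.*log(1-Y);      % absent categories
manualLoss = sum(positiveTerm + negativeTerm,'all')/observations;

% Expected to be small if the assumed formula matches the function.
difference = abs(manualLoss - extractdata(dlY))
```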

Input Arguments


Predictions (dlX), specified as a dlarray with or without dimension labels or a numeric array. When dlX is not a formatted dlarray, you must specify the dimension format using the DataFormat name-value argument. If dlX is a numeric array, targets must be a dlarray.

Data Types: single | double

Target classification labels (targets), specified as a formatted or unformatted dlarray or a numeric array.

If targets is a formatted dlarray, its dimension format must be the same as the format of dlX, or the same as the DataFormat name-value argument if dlX is unformatted.

If targets is an unformatted dlarray or a numeric array, the size of targets must exactly match the size of dlX. The format of dlX or the value of DataFormat is implicitly applied to targets.

Data Types: single | double
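The two ways of passing targets described above can be compared directly. In this sketch (the sizes are illustrative), numeric targets pick up the "CB" format of dlX, while formatted targets must carry the same format, so both calls should return the same loss.

```matlab
% Formatted predictions: 5 classes, 8 observations.
dlX = softmax(dlarray(rand(5,8),'CB'));

% Numeric targets: every observation belongs to the third class.
T = zeros(5,8);
T(3,:) = 1;

% Unformatted targets: the size must match dlX; its "CB" format is applied.
lossNumeric = focalCrossEntropy(dlX,T);

% Formatted targets: the format must match the format of dlX.
lossFormatted = focalCrossEntropy(dlX,dlarray(T,'CB'));

difference = abs(extractdata(lossNumeric) - extractdata(lossFormatted))
```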

Name-Value Arguments

Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

Example: ClassificationMode="multilabel",DataFormat="CB" evaluates the focal cross-entropy loss for multi-label classification tasks and specifies the dimension order of the input data as "CB".

Before R2021a, use commas to separate each name and value, and enclose Name in quotes.

Example: "ClassificationMode","multilabel","DataFormat","CB" evaluates the focal cross-entropy loss for multi-label classification tasks and specifies the dimension order of the input data as "CB".
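The two syntaxes are interchangeable. The sketch below passes the same options both ways on formatted multilabel data (Gamma=3 is an arbitrary illustrative value) and compares the results; the Name=Value form requires R2021a or later.

```matlab
% Multilabel data: 10 categories, 32 observations, classes 2 and 6 present.
dlX = sigmoid(dlarray(rand(10,32),'CB'));
targets = zeros(10,32);
targets([2 6],:) = 1;

% Name=Value syntax (R2021a and later).
lossNew = focalCrossEntropy(dlX,targets,ClassificationMode="multilabel",Gamma=3);

% Comma-separated syntax, required before R2021a.
lossOld = focalCrossEntropy(dlX,targets,"ClassificationMode","multilabel","Gamma",3);

difference = abs(extractdata(lossNew) - extractdata(lossOld))
```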

Focusing parameter of the focal loss function (Gamma), specified as a positive real number. Increasing Gamma increases the sensitivity of the network to misclassified observations.
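The effect of Gamma comes from the modulating factor (1 - p)^Gamma, where p is the predicted probability of the true class. This plain MATLAB sketch does not call focalCrossEntropy; it assumes the standard focal-loss form and compares a well-classified observation (p = 0.9) with a poorly classified one (p = 0.1) for several Gamma values. It ignores Alpha because Alpha scales both observations equally.

```matlab
% Probability assigned to the true class.
p = [0.9 0.1];          % [well classified, poorly classified]
ce = -log(p);           % ordinary cross-entropy for each observation

gammas = [0 1 2 5];
ratio = zeros(size(gammas));
for k = 1:numel(gammas)
    focal = (1 - p).^gammas(k) .* ce;   % focal term without Alpha
    ratio(k) = focal(2)/focal(1);       % hard-to-easy loss ratio
    fprintf('Gamma = %d: hard/easy loss ratio = %.4g\n', gammas(k), ratio(k));
end
```

With Gamma = 0 the ratio is that of ordinary cross-entropy, and each unit increase in Gamma multiplies it by (1 - 0.1)/(1 - 0.9) = 9, so the poorly classified observation dominates the loss more and more.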

Balancing parameter of the focal loss function (Alpha), specified as a positive real number. Alpha scales the loss function linearly and is typically set to 0.25. If you decrease Alpha, increase Gamma.

Type of output loss (Reduction), specified as one of the following:

  • "mean" — Average of the output loss over all predictions. The function computes the loss for each prediction in dlX and returns the average as an unformatted dlarray. Observations whose target values are all zero along the channel dimension are excluded from the average.

  • "none" — Output loss for each prediction. The function returns the loss value for each observation in dlX, including observations whose target values are all zero along the channel dimension. If dlX is a formatted dlarray, the output dlY is a formatted dlarray with the same dimension labels as dlX. If dlX is an unformatted dlarray, dlY is an unformatted dlarray.

Data Types: char | string
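The exclusion rule for "mean" can be seen with a small batch in which one observation has all-zero targets. This sketch assumes that, in single-label mode, such an observation contributes zero loss under "none" and is left out of the denominator under "mean", so the average is the total divided by 3 rather than 4.

```matlab
% 3 classes, 4 observations.
dlX = softmax(dlarray(rand(3,4),'CB'));
T = zeros(3,4);
T(1,1:3) = 1;   % the fourth observation has all-zero targets

lossNone = focalCrossEntropy(dlX,T,'Reduction','none');
lossMean = focalCrossEntropy(dlX,T);   % default "mean"

% Assumption: only the three observations with a nonzero target count.
manualMean = sum(extractdata(lossNone),'all')/3;
difference = abs(manualMean - extractdata(lossMean))
```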

Type of classification task (ClassificationMode), specified as one of these values:

  • "single-label" — Each observation in the predictions dlX is exclusively assigned to one category (single-label classification).

  • "multilabel" — Each observation in the predictions dlX can be assigned to one or more independent categories (multilabel classification).

Note

For binary classification, choose the classification mode according to the output layer of the network:

  • If the final layer has an output size of one, such as with a sigmoid layer, use "multilabel".

  • If the final layer has an output size of two, such as with a softmax layer, use "single-label".
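The two binary setups in the note above can be sketched as follows. The scores and labels are illustrative. Appending a row of zeros before softmax makes the first softmax channel equal to the sigmoid of the scores, so both setups start from the same positive-class probabilities. The two losses come from different classification modes and need not be equal.

```matlab
scores = randn(1,8);                 % raw network scores for 8 observations
labels = [1 0 1 1 0 0 1 0];          % binary ground truth (1 = positive class)

% Output size of one (sigmoid): use "multilabel".
probPositive = sigmoid(dlarray(scores,'CB'));
lossSigmoid = focalCrossEntropy(probPositive,labels, ...
    'ClassificationMode','multilabel');

% Output size of two (softmax): use "single-label" with one-hot targets.
% Row 1 of the softmax output equals sigmoid(scores).
probTwoClass = softmax(dlarray([scores; zeros(1,8)],'CB'));
lossSoftmax = focalCrossEntropy(probTwoClass,[labels; 1-labels], ...
    'ClassificationMode','single-label');
```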

Dimension order of unformatted input data (DataFormat), specified as a character vector or string FMT that provides a label for each dimension of the data. Each character in FMT must be one of the following:

  • "S" — Spatial

  • "C" — Channel

  • "B" — Batch (for example, samples and observations)

  • "T" — Time (for example, sequences)

  • "U" — Unspecified

You can specify multiple dimensions labeled "S" or "U". You can use the labels "C", "B", and "T" at most once.

You must specify DataFormat when the input data dlX is not a formatted dlarray.

Example: DataFormat="SSCB"

Data Types: char | string
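As an illustration of a DataFormat value with spatial dimensions, this sketch applies the loss to hypothetical per-pixel predictions (4-by-4 images, 3 classes, 2 observations). The sizes and the all-class-1 targets are illustrative assumptions, not taken from the examples on this page.

```matlab
% Hypothetical per-pixel scores: height x width x classes x observations.
X = rand(4,4,3,2);

% Unformatted input, so the softmax and the loss both need DataFormat.
dlX = softmax(dlarray(X),'DataFormat','SSCB');

% One-hot targets of the same size: every pixel belongs to class 1.
targets = zeros(4,4,3,2);
targets(:,:,1,:) = 1;

dlY = focalCrossEntropy(dlX,targets,'DataFormat','SSCB')
```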

Output Arguments


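Focal cross-entropy loss, returned as a dlarray. When Reduction is "mean", dlY is an unformatted dlarray scalar. When Reduction is "none", dlY contains the loss for each prediction and is formatted with the same dimension labels as dlX if dlX is a formatted dlarray. The output dlY has the same underlying data type as the input dlX.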
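The data-type rule can be checked by passing single-precision predictions. In this sketch the targets are also created as single to avoid mixing precisions; extractdata exposes the underlying numeric array of the loss.

```matlab
% Single-precision predictions: 10 categories, 32 observations.
dlX = softmax(dlarray(single(rand(10,32)),'CB'));
targets = zeros(10,32,'single');
targets(2,:) = 1;

dlY = focalCrossEntropy(dlX,targets);
outputClass = class(extractdata(dlY))   % underlying type of the loss
```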

Version History

Introduced in R2020b

