Define Custom Deep Learning Metric Object
Note
This topic explains how to define custom deep learning metric
objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more
information, see Define Custom Metric Function.
In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.
How To Decide Which Metric Type To Use
If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a
function handle, then you can define your own custom metric object using this topic as a guide.
After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions
function.
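A minimal sketch of this step follows. It assumes a hypothetical custom metric class named maeMetric (for example, one that computes mean absolute error) saved as maeMetric.m on the MATLAB path. XTrain, TTrain, and net are placeholders for your own training data and network, and "mse" is only an example loss.

```matlab
% maeMetric is a hypothetical custom metric class (maeMetric.m on the path).
metric = maeMetric(Name="MAE");

% Pass the metric object through the Metrics name-value argument.
options = trainingOptions("adam", ...
    Metrics=metric, ...
    Plots="training-progress");

% XTrain, TTrain, and net are placeholders for your data and network.
net = trainnet(XTrain,TTrain,net,"mse",options);
```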
Metric Template
To define a custom metric, use this class definition template as a starting point. For an example that shows how to use this template to create a custom metric, see Define Custom Metric Object.
The template outlines how to specify these aspects of the class definition:
- The properties block for public metric properties. This block must contain the Name property.
- The properties block for private metric properties. This block is optional.
- The metric constructor function.
- The optional initialize function.
- The required reset, update, aggregate, and evaluate functions.
For information about when the software calls each function, see Function Call Order.
classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties
            %
            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %                    evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with val1,...
            % valN.

            % Define metric evaluation function here.
        end
    end
end
Metric Properties
Declare the metric properties in the property sections. You can
specify attributes in the class definition to customize the behavior of properties for specific
purposes. This template defines two property types by setting their Access
attribute. Use the Access attribute to control access to specific class
properties.
- properties: Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.
- properties (Access = private): Only members of the defining class can access the property.
Public Properties
Declare public metric properties in the properties section of
the class definition. These properties have public access, which
means any code can access the values. By default, custom metrics have the
NetworkOutput public property with the default value
[] and the Maximize public property with
the default value []. The NetworkOutput
property defines which network output to apply the metric to. The
Maximize property indicates whether the optimal
value for the metric occurs when the metric is maximized (1 or
true) or when the metric is minimized (0 or
false).
You must define the Name property in this block. The
Name property controls the name of the metric in any plots
or command line output.
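For example, the public properties block of a hypothetical maeMetric class might contain only the required Name property. This sketch assumes that, because NetworkOutput and Maximize are already defined by deep.Metric, the subclass should not declare them again (MATLAB does not let a subclass redefine a property it inherits from its superclass) and instead sets their values in the constructor.

```matlab
properties
    % (Required) Metric name, shown in plots and command-line output.
    Name

    % NetworkOutput and Maximize are inherited from deep.Metric, so they
    % are not declared again here. Set their values in the constructor.
end
```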
Private Properties
Declare private metric properties in the properties (Access =
private) section of the class definition. These properties have
private access, which means only members of the defining
class can access these properties. For example, the class functions can access
private properties. If the metric has no private properties, then you can omit this
properties section.
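Continuing the hypothetical maeMetric example, the private block can hold running totals that the other methods update. The choice to store a sum and a count, rather than a running mean, is an assumption of this sketch. It makes combining several mini-batches or parallel workers exact.

```matlab
properties (Access = private)
    % Running totals for a hypothetical mean absolute error (MAE) metric.
    SumAbsError    % sum of |prediction - target| accumulated so far
    NumElements    % number of elements that contributed to the sum
end
```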
Constructor Function
The constructor function creates the metric and initializes the metric properties. The constructor function must take as input any variables that you need to compute the metric. This function must have the same name as the class.
To use a property as a name-value argument, you must set it in the constructor
function. All metrics must accept Name as an optional name-value argument.
Tip
To use the NetworkOutput property as a name-value
argument, you must set the property in the constructor function.
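The following is a sketch of a constructor for the hypothetical maeMetric class. It uses an arguments block to expose Name and NetworkOutput as name-value arguments. It also sets Maximize to false, on the assumption that a lower error is better. The method belongs inside the methods block of the class file.

```matlab
function metric = maeMetric(args)
    % Create a maeMetric object, for example maeMetric(Name="MAE").
    arguments
        args.Name = "MAE"
        args.NetworkOutput = []
    end

    % Set the required Name property.
    metric.Name = args.Name;

    % Setting NetworkOutput here makes it usable as a name-value argument.
    metric.NetworkOutput = args.NetworkOutput;

    % Lower error is better, so the optimal value is the minimum.
    metric.Maximize = false;
end
```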
Initialization Function
The initialize function is an optional function that the software
calls after reading the first batch of data. You can use this function to initialize
variables and run validation checks.
The initialize function must have this syntax, where the
batchY and batchT inputs represent the
mini-batch predictions and targets, respectively. For networks with multiple outputs,
replace batchY with batchY1,...,batchYN and
batchT with batchT1,...,batchTN, where
N is the number of network outputs. To create a metric that
supports any number of network outputs, replace batchY and
batchT with varargin.
metric = initialize(metric,batchY,batchT)
Example initialize Function
This code shows an example of an initialize function that
checks that you are using the metric for a network with a single output and
therefore only one set of batch predictions and
targets.
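function metric = initialize(metric,varargin)
    % varargin contains the mini-batch predictions and targets for every
    % network output. A single-output network passes exactly two arrays,
    % so nargin equals 3 (metric, batchY, batchT).
    if nargin ~= 3
        error("Metric not supported for networks with multiple outputs.")
    end
end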
Reset Function
The reset function resets the metric properties. The software
calls this function before each iteration. For more information, see Function Call Order.
The reset function must have this
syntax.
metric = reset(metric)
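For the hypothetical maeMetric class with private SumAbsError and NumElements properties, a reset function might zero both totals. Each training iteration, and each pass over the validation data, then starts from a clean state. The method belongs inside the methods block of the class file.

```matlab
function metric = reset(metric)
    % Clear the running totals of the hypothetical MAE metric.
    metric.SumAbsError = 0;
    metric.NumElements = 0;
end
```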
Update Function
The update function updates the metric properties that you use to
compute the metric value. The software calls update for each
training and validation mini-batch. For more information, see Function Call Order.
The update function must have this syntax, where the
batchY and batchT inputs represent the
mini-batch predictions and targets, respectively. For networks with multiple outputs,
replace batchY with batchY1,...,batchYN and
batchT with batchT1,...,batchTN, where
N is the number of network outputs. To create a metric that
supports any number of network outputs, replace batchY and
batchT with varargin.
metric = update(metric,batchY,batchT)
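The following is a possible update function for the hypothetical maeMetric class. It assumes a single-output regression network whose predictions batchY and targets batchT have the same size. The function adds each mini-batch's absolute errors and element count to the running totals instead of overwriting them, because the software can call update several times before it calls evaluate.

```matlab
function metric = update(metric,batchY,batchT)
    % Accumulate totals for the hypothetical MAE metric.
    % Assumes batchY and batchT have the same size.
    absErr = abs(batchY - batchT);
    metric.SumAbsError = metric.SumAbsError + sum(absErr,"all");
    metric.NumElements = metric.NumElements + numel(batchT);
end
```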
For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.
- When you use the metric with trainnet and the targets are categorical arrays, if the loss function is "index-crossentropy", then the software automatically converts the targets to numeric class indices and passes them to the metric. For other loss functions, the software converts the targets to one-hot encoded vectors and passes them to the metric.
- When you use the metric with testnet and the targets are categorical arrays, if the specified metrics include "index-crossentropy" but do not include "crossentropy", then the software converts the targets to numeric class indices and passes them to the metric. Otherwise, the software converts the targets to one-hot encoded vectors and passes them to the metric.
Aggregation Function
The aggregate function specifies how to combine properties from
multiple instances of the same metric object during parallel training. When you train a
network in parallel, the software divides each training mini-batch into smaller subsets.
For each subset, the software then calls update to update the
metric properties, and then calls aggregate to consolidate the
results for the whole mini-batch. For more information, see Function Call Order.
The aggregate function must have this syntax, where the
metric2 input is another instance of the metric. To ensure that
your function always produces the same result regardless of how the software groups
the subsets, make sure that aggregate is an associative
function.
metric = aggregate(metric,metric2)
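For the hypothetical maeMetric class, aggregate can add the totals from metric2 to those in metric. Addition is associative, so the combined totals are the same however the subsets are grouped. A class method can read the private properties of another instance of the same class, which is why metric2.SumAbsError is accessible here.

```matlab
function metric = aggregate(metric,metric2)
    % Combine running totals from two instances of the hypothetical
    % MAE metric by adding them.
    metric.SumAbsError = metric.SumAbsError + metric2.SumAbsError;
    metric.NumElements = metric.NumElements + metric2.NumElements;
end
```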
Evaluation Function
The evaluate function specifies how to compute the metric value.
In most cases, the final metric value is a function of the metric properties.
For the training data, the software calls evaluate at the end of
each mini-batch. For the validation data, the software calls
evaluate after all of the data passes through the network.
Therefore, the software computes the training metric separately for each mini-batch but computes the validation metric once over all
of the validation data. For more information, see Function Call Order.
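For the hypothetical maeMetric class, evaluate divides the accumulated sum by the accumulated count. The totals cover whatever data passed through update since the last reset. The same code therefore yields a per-mini-batch value for training data and a full-set value for validation data.

```matlab
function val = evaluate(metric)
    % Mean absolute error over all elements accumulated since reset.
    val = metric.SumAbsError / metric.NumElements;
end
```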
The evaluate function must have this syntax, where
M is the number of metric values to
return.
[val1,...,valM] = evaluate(metric)
Function Call Order
The order in which the software calls the initialize,
reset, update,
aggregate, and evaluate functions depends
on the current stage of the training loop. The software calls
initialize first, after it reads the first batch of
data.
The order in which the software calls the remaining functions depends on whether the data is training or validation data.
- Training data: For each mini-batch, the software calls reset, then update, and then evaluate. Therefore, the software returns the metric value for each training mini-batch, where each mini-batch is equivalent to a single training iteration.
- Validation data: The software calls reset before the first validation mini-batch, and then calls update for each mini-batch. The software calls evaluate only after all of the validation data passes through the network. Therefore, the software returns a single metric value for the whole validation set (full batch). This behavior is equivalent to a validation iteration.
This diagram illustrates the difference between how the software computes the metric for the training and validation data.

Note
When you train a network using the L-BFGS solver, the software processes all of the data in a single batch. This behavior is equivalent to a single mini-batch with all of the observations.
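The call order described in this section can be illustrated with the following plain MATLAB script. It is a simulation, not the software's implementation. It does not use deep.Metric: a struct stands in for the metric object, and the anonymous functions resetMAE, updateMAE, and evaluateMAE are hypothetical stand-ins for the methods of a mean absolute error metric. The three mini-batches are made-up data.

```matlab
% Hypothetical stand-ins for reset, update, and evaluate of an MAE metric.
resetMAE = @() struct('SumAbsError',0,'NumElements',0);
updateMAE = @(m,Y,T) struct( ...
    'SumAbsError', m.SumAbsError + sum(abs(Y - T),'all'), ...
    'NumElements', m.NumElements + numel(T));
evaluateMAE = @(m) m.SumAbsError / m.NumElements;

% Made-up predictions and targets for three mini-batches.
Ys = {[2 4], [3 8], [5 5 5]};
Ts = {[1 2], [0 3], [1 1 1]};

% Training data: reset, update, and evaluate for every mini-batch.
trainVals = zeros(1,numel(Ys));
for i = 1:numel(Ys)
    m = resetMAE();
    m = updateMAE(m,Ys{i},Ts{i});
    trainVals(i) = evaluateMAE(m);
end

% Validation data: reset once, update for every mini-batch, evaluate once.
m = resetMAE();
for i = 1:numel(Ys)
    m = updateMAE(m,Ys{i},Ts{i});
end
valVal = evaluateMAE(m);

disp(trainVals)   % one value per training mini-batch
disp(valVal)      % a single value for the whole validation set
```

With these numbers, the training values are the per-batch means 1.5, 4, and 4. The validation value is 23/7 (about 3.29), which differs from the mean of the three training values because the batches contain different numbers of elements.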
Aggregate Data
The aggregate function defines how to aggregate properties
from multiple instances of the same metric object during parallel training. When you
train a network in parallel, the software divides each training mini-batch into
smaller subsets. For each subset, the software then calls
update to update the metric properties, and then calls
aggregate to consolidate the results for the whole
mini-batch. Finally, the software calls evaluate to obtain the
metric value for the whole training mini-batch.
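The following script is a plain MATLAB simulation of this process, not the software's implementation. It uses made-up data and hypothetical anonymous functions in place of the methods of a mean absolute error metric, and it splits one mini-batch unevenly across two simulated workers. The script checks that aggregating sums and counts reproduces the single-worker result, whereas averaging the two per-worker metric values does not.

```matlab
% Hypothetical stand-ins for the reset, update, aggregate, and evaluate
% methods of an MAE metric; a struct plays the role of the metric object.
resetMAE = @() struct('SumAbsError',0,'NumElements',0);
updateMAE = @(m,Y,T) struct( ...
    'SumAbsError', m.SumAbsError + sum(abs(Y - T),'all'), ...
    'NumElements', m.NumElements + numel(T));
aggregateMAE = @(m1,m2) struct( ...
    'SumAbsError', m1.SumAbsError + m2.SumAbsError, ...
    'NumElements', m1.NumElements + m2.NumElements);
evaluateMAE = @(m) m.SumAbsError / m.NumElements;

% Made-up predictions and targets for one training mini-batch.
Y = [2 4 3 8 5];
T = [1 2 0 3 1];

% Reference: the whole mini-batch processed by a single worker.
fullVal = evaluateMAE(updateMAE(resetMAE(),Y,T));

% Parallel case: two workers each update on an uneven subset, then the
% results are aggregated before evaluate is called once.
w1 = updateMAE(resetMAE(),Y(1:2),T(1:2));
w2 = updateMAE(resetMAE(),Y(3:5),T(3:5));
splitVal = evaluateMAE(aggregateMAE(w1,w2));

% Incorrect shortcut: averaging per-worker values ignores subset sizes.
naiveVal = mean([evaluateMAE(w1), evaluateMAE(w2)]);

fprintf("full: %g, aggregated: %g, averaged: %g\n",fullVal,splitVal,naiveVal)
```

Here fullVal and splitVal are both 3, while naiveVal is 2.75, because the two simulated workers processed different numbers of elements.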
See Also
trainingOptions | trainnet | dlnetwork
