Define Custom Deep Learning Metric Object
Note
This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.
In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.
How To Decide Which Metric Type To Use
If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.
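As a minimal sketch of this step, the following code passes a custom metric object to trainingOptions. It assumes that a class named myMetric (the placeholder name from the template in this topic) is on the MATLAB path and that its constructor can be called without input arguments. The solver and plot settings are illustrative choices only.

```matlab
% Assumption: myMetric is a custom metric class on the MATLAB path whose
% constructor can be called without input arguments.
metric = myMetric;

% Specify the metric object using the Metrics name-value argument.
% The "adam" solver and the training progress plot are example settings.
options = trainingOptions("adam", ...
    Metrics=metric, ...
    Plots="training-progress");
```

The software then reports the metric under the name stored in its Name property.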
Metric Template
To define a custom metric, use this class definition template as a starting point. For an example that shows how to use this template to create a custom metric, see Define Custom Metric Object.
The template outlines how to specify these aspects of the class definition:
The properties block for public metric properties. This block must contain the Name property.
The properties block for private metric properties. This block is optional.
The metric constructor function.
The optional initialize function.
The required reset, update, aggregate, and evaluate functions.
For information about when the software calls each function, see Function Call Order.
    classdef myMetric < deep.Metric

        properties
            % (Required) Metric name.
            Name

            % Declare public metric properties here.

            % Any code can access these properties. Include here any properties
            % that you want to access or edit outside of the class.
        end

        properties (Access = private)
            % (Optional) Metric properties.

            % Declare private metric properties here.

            % Only members of the defining class can access these properties.
            % Include here properties that you do not want to edit outside
            % the class.
        end

        methods
            function metric = myMetric(args)
                % Create a myMetric object.
                % This function must have the same name as the class.

                % Define metric construction function here.
            end

            function metric = initialize(metric,batchY,batchT)
                % (Optional) Initialize metric.
                %
                % Use this function to initialize variables and run validation
                % checks.
                %
                % Inputs:
                %           metric - Metric to initialize
                %           batchY - Mini-batch of predictions
                %           batchT - Mini-batch of targets
                %
                % Output:
                %           metric - Initialized metric
                %
                % For networks with multiple outputs, replace batchY with
                % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
                % where N is the number of network outputs. To create a metric
                % that supports any number of network outputs, replace batchY
                % and batchT with varargin.

                % Define metric initialization function here.
            end

            function metric = reset(metric)
                % Reset metric properties.
                %
                % Use this function to reset the metric properties between
                % iterations.
                %
                % Input:
                %           metric - Metric containing properties to reset
                %
                % Output:
                %           metric - Metric with reset properties

                % Define metric reset function here.
            end

            function metric = update(metric,batchY,batchT)
                % Update metric properties.
                %
                % Use this function to update metric properties that you use to
                % compute the final metric value.
                %
                % Inputs:
                %           metric - Metric containing properties to update
                %           batchY - Mini-batch of predictions
                %           batchT - Mini-batch of targets
                %
                % Output:
                %           metric - Metric with updated properties
                %
                % For networks with multiple outputs, replace batchY with
                % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
                % where N is the number of network outputs. To create a metric
                % that supports any number of network outputs, replace batchY
                % and batchT with varargin.

                % Define metric update function here.
            end

            function metric = aggregate(metric,metric2)
                % Aggregate metric properties.
                %
                % Use this function to define how to aggregate properties from
                % multiple instances of the same metric object during parallel
                % training.
                %
                % Inputs:
                %           metric  - Metric containing properties to aggregate
                %           metric2 - Metric containing properties to aggregate
                %
                % Output:
                %           metric - Metric with aggregated properties
                %

                % Define metric aggregation function here.
            end

            function val = evaluate(metric)
                % Evaluate metric properties.
                %
                % Use this function to define how to use the metric properties
                % to compute the final metric value.
                %
                % Input:
                %           metric - Metric containing properties to use to
                %                    evaluate the metric value
                %
                % Output:
                %           val - Evaluated metric value
                %
                % To return multiple metric values, replace val with val1,...
                % valN.

                % Define metric evaluation function here.
            end
        end
    end
Metric Properties
Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.
properties: Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.
properties (Access = private): Only members of the defining class can access the property.
Public Properties
Declare public metric properties in the properties section of the class definition. These properties have public access, which means any code can access the values. By default, custom metrics have the NetworkOutput public property with the default value [] and the Maximize public property with the default value []. The NetworkOutput property defines which network output to apply the metric to. The Maximize property is a flag that specifies whether the optimal value for the metric occurs when the metric is maximized (1 or true) or when the metric is minimized (0 or false).
You must define the Name property in this block. The Name property controls the name of the metric in any plots or command-line output.
Private Properties
Declare private metric properties in the properties (Access = private) section of the class definition. These properties have private access, which means only members of the defining class can access these properties. For example, the class functions can access private properties. If the metric has no private properties, then you can omit this properties section.
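For illustration only, this sketch shows the two property blocks for a hypothetical mean absolute error metric. The private property names AbsErrorSum and NumElements are assumptions made for this example and are not part of Deep Learning Toolbox.

```matlab
% Fragment of a hypothetical maeMetric class definition (property blocks
% only). Place these blocks inside the classdef, before the methods block.
properties
    % (Required) Metric name shown in plots and command-line output.
    Name
end

properties (Access = private)
    % Running totals (illustrative names) that the class functions
    % accumulate across mini-batches.
    AbsErrorSum = 0   % Sum of absolute errors seen so far
    NumElements = 0   % Number of target elements seen so far
end
```

Keeping the running totals private means that only the class functions can change them, while Name stays editable from outside the class.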
Constructor Function
The constructor function creates the metric and initializes the metric properties. The constructor function must take as input any variables that you need to compute the metric. This function must have the same name as the class.
To use any properties as name-value arguments, you must set them in the constructor function. The constructor of every metric must accept Name as an optional argument.
Tip
To use the NetworkOutput property as a name-value argument, you must set the property in the constructor function.
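One possible way to write the constructor is to declare the name-value arguments in an arguments block and copy them into the metric properties. This sketch assumes a hypothetical class named maeMetric. The default values shown are illustrative.

```matlab
% Constructor for a hypothetical maeMetric class. Place this function
% inside the methods block of the class definition.
function metric = maeMetric(args)
    % Declare the optional name-value arguments and their defaults.
    arguments
        args.Name = "MAE"
        args.NetworkOutput = []
        args.Maximize = false   % Lower error is better for this example.
    end

    % Copy the arguments into the properties so that callers can set
    % them when they create the metric.
    metric.Name = args.Name;
    metric.NetworkOutput = args.NetworkOutput;
    metric.Maximize = args.Maximize;
end
```

With a constructor like this, a call such as maeMetric(Name="ValidationMAE") would create a metric that is reported under that name.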
Initialization Function
The initialize function is an optional function that the software calls after reading the first batch of data. You can use this function to initialize variables and run validation checks.
The initialize function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.
    metric = initialize(metric,batchY,batchT)
Example initialize Function
This code shows an example of an initialize function that checks that you are using the metric for a network with a single output and therefore only one set of batch predictions and targets.
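    function metric = initialize(metric,varargin)
        % varargin holds the mini-batch predictions followed by the targets.
        % A network with a single output passes exactly one of each.
        if numel(varargin) ~= 2
            error("Metric not supported for networks with multiple outputs.")
        end
    end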
Reset Function
The reset function resets the metric properties. The software calls this function before each iteration. For more information, see Function Call Order.
The reset function must have this syntax.
    metric = reset(metric)
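As a sketch, a reset function for a hypothetical metric that accumulates a running error sum and an element count could return both totals to zero. The property names AbsErrorSum and NumElements are assumptions for this example.

```matlab
% Reset function for a hypothetical mean absolute error metric. Place this
% function inside the methods block of the class definition.
function metric = reset(metric)
    % Clear the running totals (illustrative property names) so that the
    % next iteration starts from an empty state.
    metric.AbsErrorSum = 0;
    metric.NumElements = 0;
end
```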
Update Function
The update function updates the metric properties that you use to compute the metric value. The software calls update during each training and validation mini-batch. For more information, see Function Call Order.
The update function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.
    metric = update(metric,batchY,batchT)
For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.
When using the metric with trainnet and the targets are categorical arrays, if the loss function is "index-crossentropy", then the software automatically converts the targets to numeric class indices and passes them to the metric. For other loss functions, the software converts the targets to one-hot encoded vectors and passes them to the metric.
When using the metric with testnet and the targets are categorical arrays, if the specified metrics include "index-crossentropy" but do not include "crossentropy", then the software converts the targets to numeric class indices and passes them to the metric. Otherwise, the software converts the targets to one-hot encoded vectors and passes them to the metric.
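To make the update step concrete, this sketch shows an update function for a hypothetical mean absolute error metric on a single-output regression network. It assumes that batchY and batchT are numeric arrays of the same size, so the categorical target conversions described above do not apply. The property names AbsErrorSum and NumElements are illustrative.

```matlab
% Update function for a hypothetical mean absolute error metric. Place this
% function inside the methods block of the class definition.
function metric = update(metric,batchY,batchT)
    % Assumption: batchY and batchT have the same size.
    absError = abs(batchY - batchT);

    % Add the contribution of this mini-batch to the running totals.
    metric.AbsErrorSum = metric.AbsErrorSum + sum(absError,"all");
    metric.NumElements = metric.NumElements + numel(batchT);
end
```

Storing a sum and a count, rather than a per-batch mean, lets the metric combine mini-batches of different sizes without bias.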
Aggregation Function
The aggregate function specifies how to combine properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. For more information, see Function Call Order.
The aggregate function must have this syntax, where the metric2 input is another instance of the metric. To ensure that your function always produces the same results, make sure that aggregate is an associative function.
    metric = aggregate(metric,metric2)
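As an illustration, an aggregate function for a hypothetical metric that stores a running error sum and an element count can add the totals of the two instances. Addition is associative, so the result does not depend on how the software groups the instances. The property names are assumptions for this example.

```matlab
% Aggregate function for a hypothetical mean absolute error metric. Place
% this function inside the methods block of the class definition.
function metric = aggregate(metric,metric2)
    % Class functions can read the private properties of another
    % instance of the same class.
    metric.AbsErrorSum = metric.AbsErrorSum + metric2.AbsErrorSum;
    metric.NumElements = metric.NumElements + metric2.NumElements;
end
```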
Evaluation Function
The evaluate function specifies how to compute the metric value. In most cases, the final metric value is a function of the metric properties.
For the training data, the software calls evaluate at the end of each mini-batch. For the validation data, the software calls evaluate after all of the data passes through the network. Therefore, the software computes the metric for each batch of training data but for all of the validation data. For more information, see Function Call Order.
The evaluate function must have this syntax, where M is the number of metrics to return.
    [val,...,valM] = evaluate(metric)
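For example, a hypothetical mean absolute error metric that stores a running error sum and an element count could compute its value as the ratio of the two. The property names AbsErrorSum and NumElements are assumptions for this sketch.

```matlab
% Evaluate function for a hypothetical mean absolute error metric. Place
% this function inside the methods block of the class definition.
function val = evaluate(metric)
    % Divide the accumulated absolute error by the number of elements.
    % If no data has been processed, NumElements is 0 and val is NaN.
    val = metric.AbsErrorSum / metric.NumElements;
end
```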
Function Call Order
The order in which the software calls the initialize, reset, update, aggregate, and evaluate functions depends on where in the training loop the software is. The first function the software calls is initialize. The software calls initialize after it reads the first batch of data.
The order in which the software calls the remaining functions depends on whether the data is training or validation data.
Training data: For each mini-batch, the software calls reset, then update, and then evaluate. Therefore, the software returns the metric value for each training mini-batch, where each batch is equivalent to a single training iteration.
Validation data: For each mini-batch, the software calls update only. The software calls evaluate after all of the validation data passes through the network. Therefore, the software returns the metric value for the whole validation set (full-batch). This behavior is equivalent to a validation iteration. The software calls reset before the first validation mini-batch.
This diagram illustrates the difference between how the software computes the metric for the training and validation data.
Note
When you train a network using the L-BFGS solver, the software processes all of the data in a single batch. This behavior is equivalent to a single mini-batch with all of the observations.
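The training and validation call orders described above can be imitated with a small standalone script. This sketch does not use Deep Learning Toolbox: a struct stands in for the metric object, the three function handles are hypothetical stand-ins for reset, update, and evaluate, and the data values are made up.

```matlab
% Stand-ins for the metric functions (hypothetical; a real metric is a
% deep.Metric subclass). The state is a running error sum and a count.
resetFcn    = @() struct("Sum",0,"Count",0);
updateFcn   = @(m,y,t) struct("Sum",m.Sum + sum(abs(y - t),"all"), ...
                              "Count",m.Count + numel(t));
evaluateFcn = @(m) m.Sum / m.Count;

% Three mini-batches of predictions and targets (made-up values).
batchesY = {[1 2], [3 5], [4 4]};
batchesT = {[1 1], [3 3], [2 4]};

% Training order: reset, update, and evaluate for every mini-batch.
trainingValues = zeros(1,numel(batchesY));
for i = 1:numel(batchesY)
    m = resetFcn();
    m = updateFcn(m,batchesY{i},batchesT{i});
    trainingValues(i) = evaluateFcn(m);
end

% Validation order: reset once, update for every mini-batch, and
% evaluate once after all of the data.
m = resetFcn();
for i = 1:numel(batchesY)
    m = updateFcn(m,batchesY{i},batchesT{i});
end
validationValue = evaluateFcn(m);
```

The training loop produces one value per mini-batch, whereas the validation loop produces a single value for all of the data.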
Aggregate Data
The aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. Finally, the software calls evaluate to obtain the metric value for the whole training mini-batch.
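The following standalone sketch imitates this sequence for one mini-batch split into two subsets. It does not use Deep Learning Toolbox or a parallel pool: a struct stands in for each metric instance, the function handles are hypothetical stand-ins for update, aggregate, and evaluate, and the data values are made up.

```matlab
% Stand-ins for the metric functions (hypothetical; real metrics are
% deep.Metric objects). The state is a running error sum and a count.
newState     = @() struct("Sum",0,"Count",0);
updateFcn    = @(m,y,t) struct("Sum",m.Sum + sum(abs(y - t),"all"), ...
                               "Count",m.Count + numel(t));
aggregateFcn = @(m1,m2) struct("Sum",m1.Sum + m2.Sum, ...
                               "Count",m1.Count + m2.Count);
evaluateFcn  = @(m) m.Sum / m.Count;

% One training mini-batch (made-up values) split into two subsets.
Y = [1 2 3 5];
T = [1 1 3 3];
m1 = updateFcn(newState(),Y(1:2),T(1:2));   % First subset
m2 = updateFcn(newState(),Y(3:4),T(3:4));   % Second subset

% Consolidate the subsets, then evaluate once for the whole mini-batch.
valueParallel = evaluateFcn(aggregateFcn(m1,m2));

% Reference: the same mini-batch processed without splitting.
valueSerial = evaluateFcn(updateFcn(newState(),Y,T));
```

Because the state holds sums rather than averages, the aggregated value matches the value computed from the unsplit mini-batch.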
See Also
trainingOptions | trainnet | dlnetwork