Define Custom Learning Rate Schedule
When you train a neural network using the trainnet function, the LearnRateSchedule option of the trainingOptions function provides several ways to customize the learning rate schedule. It supports built-in schedules such as "piecewise" and "warmup". You can also use learning rate schedule objects that you can customize further, such as piecewiseLearnRate and warmupLearnRate objects. If these built-in options do not provide the functionality that you need, you can specify the learning rate schedule as a function of the epoch number using a function handle.
If you need additional flexibility, for example, if you want to use a learning rate schedule that changes the learning rate between iterations or that requires updating and maintaining a state, then you can define your own custom learning rate schedule object using this example as a guide.
To define a custom learning rate schedule, you can use the template provided in this example, which takes you through these steps:
Name the schedule — Give the schedule a name so that you can use it in MATLAB®.
Declare the schedule properties (optional) — Specify the properties of the schedule.
Create the constructor function (optional) — Specify how to construct the schedule and initialize its properties.
Create the update function — Specify how the schedule calculates the learning rate.
This example shows how to define a time-based decay learning rate schedule and use it to train a neural network. A time-based decay learning rate schedule object updates the learning rate every iteration using a decay rule.
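The time-based decay learning rate schedule uses this formula to calculate the learning rate:
$$\alpha_k = \frac{\alpha_0}{1 + \lambda\,(k - 1)}$$
where:
αk is the learning rate at iteration k.
k is the iteration number.
α0 is the base learning rate, specified by the InitialLearnRate option of the trainingOptions function.
λ is the decay value.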
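As a worked example of the formula, assume a base learning rate α0 = 0.01 (the learning rate that the training log later in this example reports at iteration 1) and a decay value λ = 0.01. The learning rate at iteration 50 is then:

$$\alpha_{50} = \frac{0.01}{1 + 0.01\,(50 - 1)} = \frac{0.01}{1.49} \approx 0.0067114$$

This value matches the LearnRate column of the training log at iteration 50.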
Custom Learning Rate Schedule Template
Copy the custom learning rate schedule template into a new file in MATLAB. This template gives the structure of a schedule class definition. It outlines:
The optional properties block for the schedule properties.
The optional schedule constructor function.
The update function.
classdef myLearnRateSchedule < deep.LearnRateSchedule
    properties
        % (Optional) Schedule properties.
        % Declare schedule properties here.
    end
    methods
        function schedule = myLearnRateSchedule()
            % (Optional) Create a myLearnRateSchedule.
            % This function must have the same name as the class.
            % Define schedule constructor function here.
        end
        function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,epoch)
            %UPDATE Update learning rate schedule
            % Define schedule update function here.
        end
    end
end
Name Schedule and Specify Superclass
First, give the schedule a name. In the first line of the class file, replace the
existing name myLearnRateSchedule with
timeBasedDecayLearnRate.
classdef timeBasedDecayLearnRate < deep.LearnRateSchedule
    ...
end
Next, rename the myLearnRateSchedule constructor function (the
first function in the methods section) so that it has the same name
as the schedule.
methods
    function schedule = timeBasedDecayLearnRate()
        ...
    end
    ...
end
Save the Schedule
Save the schedule class file in a new file named
timeBasedDecayLearnRate.m. The file name must match the
schedule name. To use the schedule, you must save the file in the current folder or
in a folder on the MATLAB path.
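To confirm that MATLAB can find the class after you save it, you can query it with the exist function, which returns 8 when a name resolves to a class definition. This is an optional convenience check, not a required step of defining the schedule.

```matlab
% Returns 8 if timeBasedDecayLearnRate.m is found as a class definition
% in the current folder or on the MATLAB path, and 0 if it is not found.
status = exist("timeBasedDecayLearnRate","class");
if status ~= 8
    warning("timeBasedDecayLearnRate is not in the current folder or on the MATLAB path.")
end
```
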
Declare Properties
Declare the schedule properties in the properties section.
By default, custom learning rate schedules have these properties. Do not declare these
properties in the properties section.
| Property | Description |
|---|---|
| FrequencyUnit | How often the schedule updates the learning rate, specified as "epoch" or "iteration". |
| NumSteps | Number of steps the learning rate schedule takes before it is complete, specified as a positive integer or Inf. For learning rate schedules that continue indefinitely (also known as infinite learning rate schedules), this property is Inf. |
A time-based decay learning rate schedule requires one additional property: the decay
value. Declare the decay value in the properties
block.
properties
    % Schedule properties
    Decay
end
Create Constructor Function
Create the function that constructs the schedule and initializes the schedule properties. Specify any variables required to create the schedule as inputs to the constructor function.
The time-based decay learning rate schedule constructor function requires one argument
(the decay). Specify one input argument named decay in the
timeBasedDecayLearnRate function that corresponds to the decay.
Add a comment to the top of the function that explains the syntax of the
function.
function schedule = timeBasedDecayLearnRate(decay)
    % timeBasedDecayLearnRate Time-based decay learning rate
    % schedule
    %   schedule = timeBasedDecayLearnRate(decay) creates a
    %   time-based decay learning rate schedule with the specified
    %   decay.
    ...
end
Initialize Schedule Properties
Initialize the schedule properties in the constructor function. Replace the
comment % Define schedule constructor function here with code
that initializes the schedule properties.
Because the time-based decay learning rate schedule updates the learning rate each iteration, set the FrequencyUnit property to "iteration".
Because the time-based decay learning rate schedule is infinite, set the NumSteps property to Inf.
Set the schedule Decay property to the decay argument.
% Set schedule properties.
schedule.FrequencyUnit = "iteration";
schedule.NumSteps = Inf;
schedule.Decay = decay;
View the completed constructor function.
function schedule = timeBasedDecayLearnRate(decay)
    % timeBasedDecayLearnRate Time-based decay learning rate
    % schedule
    %   schedule = timeBasedDecayLearnRate(decay) creates a
    %   time-based decay learning rate schedule with the specified
    %   decay.
    % Set schedule properties.
    schedule.FrequencyUnit = "iteration";
    schedule.NumSteps = Inf;
    schedule.Decay = decay;
end
With this constructor function, the command timeBasedDecayLearnRate(0.01) creates a time-based decay learning rate schedule with a decay value of 0.01.
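A minimal sketch of calling the constructor, assuming the class file is saved as described earlier. It creates two schedules with different decay values and reads back the properties that the constructor initializes. The variable names slowSchedule and fastSchedule are illustrative only.

```matlab
% Create two schedules with different decay values.
slowSchedule = timeBasedDecayLearnRate(0.01);
fastSchedule = timeBasedDecayLearnRate(0.1);

% Inspect the properties that the constructor initialized.
slowSchedule.FrequencyUnit
slowSchedule.NumSteps
fastSchedule.Decay
```

A larger decay value makes the learning rate shrink faster as the iteration number grows.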
Create Update Function
Create the function that updates the learning rate.
Create a function named update that updates the learning rate
schedule properties and also returns the calculated learning rate value.
The update function has the syntax [schedule,learnRate] = update(schedule,initialLearnRate,iteration,epoch), where:
schedule is an instance of the learning rate schedule.
learnRate is the calculated learning rate value.
initialLearnRate is the initial learning rate.
iteration is the iteration number.
epoch is the epoch number.
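The time-based decay learning rate schedule uses this formula to calculate the learning rate:
$$\alpha_k = \frac{\alpha_0}{1 + \lambda\,(k - 1)}$$
where:
αk is the learning rate at iteration k.
k is the iteration number.
α0 is the base learning rate, specified by the InitialLearnRate option of the trainingOptions function.
λ is the decay value, stored in the Decay property of the schedule.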
Implement this operation in update. The schedule does not require updating any state values, so the output schedule is unchanged. Because a time-based decay learning rate schedule does not require the epoch number, the syntax of update for this schedule is [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~). Because the time-based decay learning rate schedule is infinite, there is no need to update the IsDone property.
function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
    % UPDATE Update learning rate schedule
    %   [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
    %   calculates the learning rate for the specified iteration
    %   and also returns the updated schedule object.
    % Calculate learning rate.
    decay = schedule.Decay;
    learnRate = initialLearnRate / (1 + decay*(iteration-1));
end
Completed Learning Rate Schedule
View the completed learning rate schedule class file.
classdef timeBasedDecayLearnRate < deep.LearnRateSchedule
    % timeBasedDecayLearnRate Time-based decay learning rate schedule
    properties
        % Schedule properties
        Decay
    end
    methods
        function schedule = timeBasedDecayLearnRate(decay)
            % timeBasedDecayLearnRate Time-based decay learning rate
            % schedule
            %   schedule = timeBasedDecayLearnRate(decay) creates a
            %   time-based decay learning rate schedule with the specified
            %   decay.
            % Set schedule properties.
            schedule.FrequencyUnit = "iteration";
            schedule.NumSteps = Inf;
            schedule.Decay = decay;
        end
        function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
            % UPDATE Update learning rate schedule
            %   [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
            %   calculates the learning rate for the specified iteration
            %   and also returns the updated schedule object.
            % Calculate learning rate.
            decay = schedule.Decay;
            learnRate = initialLearnRate / (1 + decay*(iteration-1));
        end
    end
end
Train Using Custom Learning Rate Schedule Object
You can use a custom learning rate schedule object in the same way as any other learning rate schedule object in the trainingOptions function. This example shows how to create and train a network for digit classification using the time-based decay learning rate schedule object that you defined earlier.
Load the example training data.
load DigitsDataTrain
Create a layer array.
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];
Create an instance of a time-based decay learning rate schedule object with a decay value of 0.01.
schedule = timeBasedDecayLearnRate(0.01)
schedule =
timeBasedDecayLearnRate with properties:
Decay: 0.0100
FrequencyUnit: "iteration"
NumSteps: Inf
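Because update is a public method of the schedule, you can call it directly to preview the learning rates before training. This sketch assumes a base learning rate of 0.01, which matches the learning rate that the training log reports at iteration 1. The update method of this schedule ignores its epoch input, so the value 1 is passed only as a placeholder.

```matlab
% Preview the learning rates that the schedule produces for a few
% iterations, without training a network.
initialLearnRate = 0.01;        % assumed to match the InitialLearnRate training option
iterations = [1 50 100 390];    % iteration numbers that appear in the training log
learnRates = zeros(size(iterations));
for i = 1:numel(iterations)
    % The schedule is stateless, so the returned schedule is discarded.
    [~,learnRates(i)] = update(schedule,initialLearnRate,iterations(i),1);
end
learnRates
```

These values agree with the LearnRate column of the training log, for example 0.0067114 at iteration 50.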
Specify the training options. To train using the learning rate schedule object, set the LearnRateSchedule training option to the object.
options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    LearnRateSchedule=schedule, ...
    Metrics="accuracy");
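The schedule receives α0 through the initialLearnRate input of update, so changing the InitialLearnRate training option rescales the whole schedule without editing the class. A sketch of this variation, where the value 0.05 and the variable name optionsHighLR are illustrative only:

```matlab
% Same schedule object (decay 0.01), but with a larger base learning rate.
% The learning rate at iteration k becomes 0.05/(1 + 0.01*(k-1)).
optionsHighLR = trainingOptions("sgdm", ...
    InitialLearnRate=0.05, ...
    MaxEpochs=10, ...
    LearnRateSchedule=schedule, ...
    Metrics="accuracy");
```
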
Train the neural network using the trainnet function. For classification, use index cross-entropy loss. By default, the trainnet function uses a GPU if one is available; otherwise, it uses the CPU. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.
To get information about the training, such as the learning rate value for each iteration, use the second output of the trainnet function.
[net,info] = trainnet(XTrain,labelsTrain,layers,"index-crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    TrainingAccuracy
_________ _____ ___________ _________ ____________ ________________
1 1 00:00:01 0.01 2.5434 10.938
50 2 00:00:04 0.0067114 0.36951 89.062
100 3 00:00:07 0.0050251 0.18465 94.531
150 4 00:00:10 0.0040161 0.092117 98.438
200 6 00:00:13 0.0033445 0.079264 99.219
250 7 00:00:17 0.0028653 0.062432 99.219
300 8 00:00:20 0.0025063 0.033868 100
350 9 00:00:24 0.0022272 0.049939 100
390 10 00:00:26 0.002045 0.047319 100
Training stopped: Max epochs completed
Extract the learning rate information from the training information and visualize it in a plot.
figure
plot(info.TrainingHistory.LearnRate)
ylim([0 inf])
xlabel("Iteration")
ylabel("Learning Rate")
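To confirm that training used the schedule as intended, you can compare the recorded learning rates with the formula. This sketch assumes that info.TrainingHistory contains one row per iteration with Iteration and LearnRate variables (the same columns as the training log) and that the base learning rate was 0.01, the default InitialLearnRate for the "sgdm" solver.

```matlab
% Recorded iteration numbers and learning rates from training.
history = info.TrainingHistory;
k = history.Iteration;

% Learning rates predicted by the time-based decay formula.
alpha0 = 0.01;   % assumed InitialLearnRate
expected = alpha0 ./ (1 + schedule.Decay*(k - 1));

% Largest absolute difference between recorded and predicted values.
maxDifference = max(abs(history.LearnRate - expected))

% Overlay the predicted curve on the recorded learning rate plot.
hold on
plot(k,expected,"--")
hold off
legend("Recorded","Formula")
```
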

Test the neural network using the testnet function. For single-label classification, evaluate the accuracy. The accuracy is the percentage of correct predictions. By default, the testnet function uses a GPU if one is available. To select the execution environment manually, use the ExecutionEnvironment argument of the testnet function.
load DigitsDataTest
classNames = categories(labelsTest);
accuracy = testnet(net,XTest,labelsTest,"accuracy")
accuracy = 97.7000
See Also
trainingOptions | trainnet | dlnetwork | piecewiseLearnRate | warmupLearnRate | polynomialLearnRate | exponentialLearnRate | cosineLearnRate | cyclicalLearnRate