trainFromData
Syntax
Description
tfdStats = trainFromData(___,tfdOpts) also specifies nondefault training options using the rlTrainingFromDataOptions object tfdOpts.
tfdStats = trainFromData(___,logger=lgr) logs training data using the FileLogger object lgr.
Examples
Train Agent from Data Collected by Training Another Agent
To collect training data, first, create an environment.
env = rlPredefinedEnv("CartPole-Discrete");
Create a built-in PPO agent with default networks.
agent1 = rlPPOAgent( ...
    getObservationInfo(env), ...
    getActionInfo(env));
Create a FileLogger object.
flgr = rlDataLogger;
To log the experiences on disk, assign an appropriate logger function to the logger object. This function is automatically called by the training loop at the end of each episode, and is defined at the end of the example.
flgr.EpisodeFinishedFcn = @myEpisodeFinishedFcn;
Define a training options object to train agent1 for no more than 100 episodes, without visualizing any training progress.
tOpts = rlTrainingOptions(MaxEpisodes=100,Plots="none");
Train agent1, logging the experience data.
train(agent1,env,tOpts,Logger=flgr);
At the end of this training, files containing experience data for each episode are saved in the logs folder.
Note that the only purpose of training agent1 is to collect experience data from the environment. Collecting experiences by simulating the environment in closed loop with a controller (using a for loop), or by recording the observations that result from random actions, would accomplish the same result.
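The random-action alternative mentioned above can be sketched as follows. This is only an illustration, not part of the original example: it assumes that the predefined environment supports reset and step, the step cap maxSteps is a hypothetical value, and wrapping observations and actions in cells is an assumption chosen to mirror the logged data shown later in the example.

```matlab
% Sketch: collect one episode of experiences using random actions.
% Requires Reinforcement Learning Toolbox.
env = rlPredefinedEnv("CartPole-Discrete");
actInfo = getActionInfo(env);
actions = actInfo.Elements;      % finite set of admissible actions

maxSteps = 50;                   % hypothetical cap on the episode length
obs = reset(env);                % initial observation
experiences = struct([]);        % filled one step at a time

for k = 1:maxSteps
    % Pick one admissible action uniformly at random.
    action = actions(randi(numel(actions)));
    [nextObs,reward,isDone] = step(env,action);

    % Store the step using the five field names trainFromData expects.
    % Storing observations and actions in cells is an assumption.
    experiences(k,1).Observation     = {obs};
    experiences(k,1).Action          = {action};
    experiences(k,1).Reward          = reward;
    experiences(k,1).NextObservation = {nextObs};
    experiences(k,1).IsDone          = isDone;

    obs = nextObs;
    if isDone
        break
    end
end

fprintf("Collected %d experiences.\n",numel(experiences));
```

To reuse such data with trainFromData, each episode would still need to be saved to a file that the datastore read function knows how to load.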
To allow the trainFromData function to read the experience data stored in the logs folder, create a read function that, given a file name, returns the respective experience structure. For this example, the myReadFcn function is defined at the end of the example.
Check that the function can successfully retrieve data from an episode.
cd logs
exp = myReadFcn("loggedData002")
exp=22×1 struct array with fields:
NextObservation
Observation
Action
Reward
IsDone
size(cell2mat([exp.Action]))
ans = 1×2
1 22
cd ..
Create a FileDatastore object using fileDatastore. Pass as arguments the name of the folder where the files are stored and the read function. The read function is called automatically when the datastore is accessed for reading and is defined at the end of the example.
fds = fileDatastore("./logs", "ReadFcn", @myReadFcn);
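Before starting a long training run, it can help to confirm that the datastore really returns experience arrays. The sketch below runs in base MATLAB on a synthetic episode file. The folder name, file name, field values, and the simplified file layout (a single variable named experiences) are all assumptions made for illustration and differ from the layout written by FileLogger.

```matlab
% Sketch: read every file in a datastore and count the experiences.
logDir = fullfile(tempdir,"tfdSketchLogs");   % hypothetical scratch folder
if ~isfolder(logDir)
    mkdir(logDir);
end

% Build a synthetic 3-step episode (made-up values, CartPole-like sizes).
numSteps = 3;
for k = numSteps:-1:1
    ex(k).Observation     = {zeros(4,1)};
    ex(k).Action          = {10};
    ex(k).Reward          = 1;
    ex(k).NextObservation = {zeros(4,1)};
    ex(k).IsDone          = (k == numSteps);
end
experiences = ex(:);                          % column structure array
save(fullfile(logDir,"episode001.mat"),"experiences");

% Read function: load the variable "experiences" from a MAT-file.
readFcn = @(fileName) getfield(load(fileName,"experiences"),"experiences");
fds = fileDatastore(logDir,"ReadFcn",readFcn,"FileExtensions",".mat");

% Iterate over the files, then rewind the datastore for later use.
numExperiences = 0;
while hasdata(fds)
    episode = read(fds);
    numExperiences = numExperiences + numel(episode);
end
reset(fds);

fprintf("Datastore holds %d experiences.\n",numExperiences);
```

Calling reset at the end returns the datastore to its first file, so a later reader starts from the beginning.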
Create a built-in DQN agent with default networks to be trained from the collected dataset.
agent2 = rlDQNAgent( ...
    getObservationInfo(env), ...
    getActionInfo(env));
Define an options object to train agent2 from data for 50 epochs. Each epoch contains 100 learning steps.
tfdOpts = rlTrainingFromDataOptions(MaxEpochs=50, NumStepsPerEpoch=100);
To train agent2 from data, use trainFromData. Pass the FileDatastore object fds as the second input argument.
trainFromData(agent2,fds,tfdOpts);
Here, the estimated Q-value seems to grow indefinitely over time. This often happens during offline training because the agent updates its estimated Q-value based on the current estimated Q-value, without using any environment feedback. To prevent the Q-value from becoming increasingly large (and inaccurate) over time, stop the training earlier or use data regularizer options such as rlConservativeQLearningOptions (for DQN or SAC agents) or rlBehaviorCloningRegularizerOptions (for DDPG, TD3 or SAC agents).
In general, the Q-value calculated as above for an agent trained offline is not necessarily indicative of the performance of the agent within an environment. Therefore, best practice is to validate the agent within an environment after offline training.
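One way to carry out such a validation is to simulate the agent in the environment and inspect the episode rewards. The following is a minimal sketch, assuming Reinforcement Learning Toolbox is available. To stay self-contained, it creates a fresh, untrained DQN agent as a stand-in; in the example above you would pass agent2 instead. The values of MaxSteps and NumSimulations are arbitrary illustrative choices.

```matlab
% Sketch: validate an agent by simulating it in the environment.
env = rlPredefinedEnv("CartPole-Discrete");

% Stand-in agent (untrained). Replace with the agent trained offline.
agent = rlDQNAgent(getObservationInfo(env),getActionInfo(env));

% Run a few episodes, each capped at a fixed number of steps.
simOpts = rlSimulationOptions(MaxSteps=500,NumSimulations=5);
experiences = sim(env,agent,simOpts);

% The Reward field of each episode is a timeseries; sum its samples.
episodeRewards = arrayfun(@(e) sum(e.Reward.Data),experiences);

fprintf("Mean reward over %d episodes: %g\n", ...
    numel(episodeRewards),mean(episodeRewards));
```

Comparing these rewards against those obtained by the data-collecting policy gives a more direct measure of performance than the Q-value estimate.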
Support Functions
The data logging function. This function is automatically called by the training loop at the end of each episode, and must return a structure containing the data to log, such as experiences, simulation information, or initial observations. Here, data is a structure that contains the following fields:
EpisodeCount — current episode number
Environment — environment object
Agent — agent object
Experience — structure array containing the experiences. Each element of this array corresponds to a step and is a structure containing the fields NextObservation, Observation, Action, Reward, and IsDone.
EpisodeInfo — structure containing the fields CumulativeReward, StepsTaken and InitialObservation.
SimulationInfo — contains simulation information from the episode. For MATLAB® environments this is a structure with the field SimulationError, and for Simulink® environments it is a Simulink.SimulationOutput object.
function dataToLog = myEpisodeFinishedFcn(data)
    dataToLog.Experience = data.Experience;
end
For more information on logging data on disk, see FileLogger.
The data store read function. This function is automatically called by the training loop when the data store is accessed for reading. It must take a filename and return the experience structure array. Each element of this array corresponds to a step and is a structure containing the fields NextObservation, Observation, Action, Reward, and IsDone.
function experiences = myReadFcn(fileName)
    if contains(fileName,"loggedData")
        data = load(fileName);
        experiences = data.episodeData.Experience{1};
    else
        experiences = [];
    end
end
Input Arguments
agent
— Off-policy agent
rlDQNAgent object | rlDDPGAgent object | rlTD3Agent object | rlSACAgent object | rlMBPOAgent object
Off-policy agent to train, specified as a reinforcement learning agent object, such as an rlSACAgent object.
Note
trainFromData updates the agent as training progresses. For more information on how to preserve the original agent, how to save an agent during training, and on the state of agent after training, see the notes and the tips section in train. For more information about handle objects, see Handle Object Behavior.
For more information about how to create and configure agents for reinforcement learning, see Reinforcement Learning Agents.
dataStore
— Data store
FileDatastore object
Data store, specified as a FileDatastore object. The function specified in the ReadFcn property of dataStore must return a structure array of experiences with the Observation, Action, Reward, NextObservation, and IsDone fields. The dimensions of the arrays in Observation and NextObservation in each experience must be the same as the dimensions specified in the ObservationInfo of agent. The dimension of the array in Action must be the same as the dimension specified in the ActionInfo of agent. The Reward and IsDone fields must contain scalar values. For more information, see fileDatastore.
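The requirements above can be checked programmatically before training. The sketch below runs in base MATLAB on a synthetic experience. The expected dimensions are hard-coded stand-ins for the Dimension values of the agent's observation and action specifications, and storing observations and actions in cells is an assumption that mirrors the logged data in the example.

```matlab
% Sketch: check that an experience matches the expected layout.
requiredFields = {'Observation','Action','Reward','NextObservation','IsDone'};

% Stand-ins for the dimensions in the agent's ObservationInfo and ActionInfo.
obsDim = [4 1];
actDim = [1 1];

% Synthetic experience with made-up values.
ex.Observation     = {zeros(4,1)};
ex.Action          = {10};
ex.Reward          = 1;
ex.NextObservation = {zeros(4,1)};
ex.IsDone          = false;

% All five fields must be present.
hasAllFields = all(isfield(ex,requiredFields));

% Observation and action arrays must match the expected dimensions.
dimsMatch = isequal(size(ex.Observation{1}),obsDim) && ...
            isequal(size(ex.NextObservation{1}),obsDim) && ...
            isequal(size(ex.Action{1}),actDim);

% Reward and IsDone must be scalars.
scalarsOk = isscalar(ex.Reward) && isscalar(ex.IsDone);

isValid = hasAllFields && dimsMatch && scalarsOk;
fprintf("Experience valid: %d\n",isValid);
```

Running a check like this on the output of the read function for one or two files catches most layout problems early.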
tfdOpts
— Training from data parameters and options
rlTrainingFromDataOptions object
Training from data parameters and options, specified as an rlTrainingFromDataOptions object. Use this argument to specify parameters and options such as:
Number of epochs
Number of steps for each epoch
Criteria for saving candidate agents
How to display training progress
Note
trainFromData does not support parallel computing.
For details, see rlTrainingFromDataOptions.
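As a brief illustration, the sketch below creates an options object with the two properties used in the example and then modifies one of them with dot notation. It assumes Reinforcement Learning Toolbox is available, and the numeric values are arbitrary.

```matlab
% Sketch: create and adjust training-from-data options.
tfdOpts = rlTrainingFromDataOptions(MaxEpochs=50,NumStepsPerEpoch=100);

% Properties can be changed after creation using dot notation.
tfdOpts.MaxEpochs = 20;

% Upper bound on the number of learning steps for this configuration.
maxLearningSteps = tfdOpts.MaxEpochs*tfdOpts.NumStepsPerEpoch;
fprintf("At most %d learning steps.\n",maxLearningSteps);
```

The product of MaxEpochs and NumStepsPerEpoch gives a quick estimate of how long an offline run can last.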
lgr
— Logger object
FileLogger object | MonitorLogger object
Logger object, specified either as a FileLogger or as a MonitorLogger object. For more information on reinforcement logger objects, see rlDataLogger.
Output Arguments
tfdStats
— Training results
rlTrainingFromDataResult object
Training results, returned as an rlTrainingFromDataResult object, which has the following properties:
EpochIndex
— Epoch numbers
[1;2;…;N]
Epoch numbers, returned as the column vector [1;2;…;N], where N is the number of epochs in the training run. This vector is useful if you want to plot the evolution of other quantities from epoch to epoch.
EpochSteps
— Number of steps in each epoch
column vector
Number of steps in each epoch, returned as a column vector of length N. Each entry contains the number of steps in the corresponding epoch.
TotalSteps
— Total number of steps
column vector
Total number of agent steps in training, returned as a column vector of length N. Each entry contains the cumulative sum of the entries in EpochSteps up to that point.
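The relationship between EpochIndex, EpochSteps, and TotalSteps can be illustrated with a few lines of base MATLAB. The values below are made up and stand in for the properties of a real results object.

```matlab
% Sketch: how TotalSteps relates to EpochSteps (synthetic values).
epochSteps = [100; 100; 100];          % stand-in for tfdStats.EpochSteps
N = numel(epochSteps);                 % number of epochs
epochIndex = (1:N)';                   % stand-in for tfdStats.EpochIndex

% Each entry of TotalSteps is the running sum of EpochSteps.
totalSteps = cumsum(epochSteps);

disp(table(epochIndex,epochSteps,totalSteps));
```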
QValue
— Q-value estimates for each epoch
column vector
Q-value estimates for each epoch, returned as a column vector of length N. Each element is the average Q-value of the policy, over the observations specified in the QValueObservations property of tfdOpts, evaluated at the end of the epoch, and using the policy parameters at the end of the epoch.
Note
During offline training, the agent updates its estimated Q-value based on the current estimated Q-value (without any environment feedback). As a result, the estimated Q-value can become inaccurate (and often increasingly large) over time. To prevent the Q-value from growing indefinitely, stop the training earlier or use data regularizer options. For more information, see rlBehaviorCloningRegularizerOptions and rlConservativeQLearningOptions.
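As a rough sketch of the regularizer approach, the code below attaches conservative Q-learning options to a DQN agent before offline training. It assumes Reinforcement Learning Toolbox is available and that the agent options expose a BatchDataRegularizerOptions property. That property name is an assumption not confirmed by this page, so check the rlDQNAgentOptions reference for your release before relying on it.

```matlab
% Sketch: enable a conservative Q-learning regularizer on a DQN agent.
env = rlPredefinedEnv("CartPole-Discrete");
agent = rlDQNAgent(getObservationInfo(env),getActionInfo(env));

% Create the regularizer options with default settings.
cqlOpts = rlConservativeQLearningOptions;

% Assumption: the agent options expose BatchDataRegularizerOptions.
agent.AgentOptions.BatchDataRegularizerOptions = cqlOpts;

disp(class(agent.AgentOptions.BatchDataRegularizerOptions));
```

The agent configured this way can then be passed to trainFromData in place of an unregularized agent.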
Note
The Q-value calculated as above for an agent trained offline is not necessarily indicative of the performance of the agent within an environment. Therefore, it is good practice to validate the agent within an environment after offline training.
TrainingOptions
— Training options set
rlTrainingFromDataOptions object
Training options set, returned as an rlTrainingFromDataOptions object.
Version History
Introduced in R2023a