rlMBPOAgent
Description
A model-based policy optimization (MBPO) agent uses a model-based, online, off-policy reinforcement learning method. An MBPO agent contains an internal model of the environment, which it uses to generate additional experiences without interacting with the environment. The action space can be either discrete or continuous, depending on the base agent.
During training, the MBPO agent generates real experiences by interacting with the environment. These experiences are used to train the internal environment model, which is used to generate additional experiences. The training algorithm then uses both the real and generated experiences to update the agent policy.
Note
MBPO agents do not support recurrent networks.
Creation
Description
agent = rlMBPOAgent(baseAgent,envModel) creates a model-based policy optimization agent with default options and sets the BaseAgent and EnvModel properties.
agent = rlMBPOAgent(___,agentOptions) creates a model-based policy optimization agent using the specified options and sets the AgentOptions property.
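For example, assuming you have already constructed a base off-policy agent baseAgent and a neural network environment model envModel (as shown in the example below), a minimal creation sketch is:
agent = rlMBPOAgent(baseAgent,envModel);
To use nondefault options, create and modify an options object first. The transition optimizer settings here are illustrative values only.
agentOptions = rlMBPOAgentOptions;
agentOptions.TransitionOptimizerOptions = rlOptimizerOptions(LearnRate=1e-4);
agent = rlMBPOAgent(baseAgent,envModel,agentOptions);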
Properties
BaseAgent
— Base reinforcement learning agent
rlDQNAgent
| rlDDPGAgent
| rlTD3Agent
| rlSACAgent
Base reinforcement learning agent, specified as an off-policy agent object.
For environments with a discrete action space, specify a DQN agent using an
rlDQNAgent
object.
For environments with a continuous action space, use one of the following agent objects.
rlDDPGAgent — DDPG agent
rlTD3Agent — TD3 agent
rlSACAgent — SAC agent
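For example, for a discrete action space you might create a default DQN base agent directly from the environment specifications. This is a minimal sketch; the predefined "CartPole-Discrete" environment is only an illustrative choice.
% Discrete-action environment and its specifications
env = rlPredefinedEnv("CartPole-Discrete");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

% Default DQN agent to use as the MBPO base agent
baseAgent = rlDQNAgent(obsInfo,actInfo);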
EnvModel
— Environment model
rlNeuralNetworkEnvironment
Environment model, specified as an rlNeuralNetworkEnvironment
object. This environment contains transition
functions, a reward function, and an is-done function.
AgentOptions
— Agent options
rlMBPOAgentOptions
object
Agent options, specified as an rlMBPOAgentOptions
object.
RolloutHorizon
— Current roll-out horizon value
positive integer
Current roll-out horizon value, specified as a positive integer. For more
information on setting the initial horizon value and the horizon update method, see
rlMBPOAgentOptions
.
ModelExperienceBuffer
— Model experience buffer
rlReplayMemory
object
Model experience buffer, specified as an rlReplayMemory
object. During training, the agent stores each of its generated experiences
(S,A,R,S',D)
in a buffer. Here:
S is the current observation of the environment.
A is the action taken by the agent.
R is the reward for taking action A.
S' is the next observation after taking action A.
D is the is-done signal after taking action A.
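As a sketch of how such experiences are stored, you can create and populate an rlReplayMemory buffer yourself, assuming obsInfo and actInfo are the observation and action specifications. Treat the constructor arguments and experience field names below as assumptions to verify against the rlReplayMemory documentation; they follow the (S,A,R,S',D) tuple.
% Buffer sized for 10,000 experiences (assumed constructor signature)
buffer = rlReplayMemory(obsInfo,actInfo,10000);

% One experience following the (S,A,R,S',D) tuple (field names assumed)
exp.Observation = {rand(obsInfo.Dimension)};
exp.Action = {rand(actInfo.Dimension)};
exp.Reward = 1;
exp.NextObservation = {rand(obsInfo.Dimension)};
exp.IsDone = 0;
append(buffer,exp);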
UseExplorationPolicy
— Option to use exploration policy
true
| false
Option to use exploration policy when selecting actions, specified as one of the following logical values.
true — Use the base agent exploration policy when selecting actions.
false — Use the base agent greedy policy when selecting actions.
The initial value of UseExplorationPolicy
matches the value
specified in BaseAgent
. If you change the value of
UseExplorationPolicy
in either the base agent or the MBPO agent,
the same value is used for the other agent.
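For instance, assuming agent is an MBPO agent created as in the example below, you can switch to the greedy policy before evaluating the agent.
% Use the base agent greedy policy when selecting actions.
agent.UseExplorationPolicy = false;

% getAction now returns the greedy action for a given observation.
obsInfo = getObservationInfo(agent);
action = getAction(agent,{rand(obsInfo.Dimension)});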
ObservationInfo
— Observation specifications
specification object | array of specification objects
This property is read-only.
Observation specifications, specified as an rlFiniteSetSpec
or rlNumericSpec
object or an array containing a mix of such objects. Each element in the array defines
the properties of an environment observation channel, such as its dimensions, data type,
and name.
The value of ObservationInfo
matches the corresponding value
specified in BaseAgent
.
ActionInfo
— Action specification
rlFiniteSetSpec
object | rlNumericSpec
object
This property is read-only.
Action specifications, specified either as an rlFiniteSetSpec
(for discrete action spaces) or rlNumericSpec
(for continuous action spaces) object. This object defines the properties of the
environment action channel, such as its dimensions, data type, and name.
Note
Only one action channel is allowed.
The value of ActionInfo
matches the corresponding value
specified in BaseAgent
.
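You can retrieve both specifications from an existing agent, for example:
% Specifications stored in the agent (these match BaseAgent)
obsInfo = getObservationInfo(agent);
actInfo = getActionInfo(agent);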
SampleTime
— Sample time of agent
1
(default) | positive scalar | -1
Sample time of agent, specified as a positive scalar or as -1
. Setting this
parameter to -1
allows for event-based simulations.
Within a Simulink® environment, the RL Agent block in which the agent is specified executes every SampleTime seconds of simulation time. If SampleTime
is -1
, the
block inherits the sample time from its parent subsystem.
Within a MATLAB® environment, the agent is executed every time the environment advances. In
this case, SampleTime
is the time interval between consecutive
elements in the output experience returned by sim
or
train
. If
SampleTime
is -1
, the time interval between
consecutive elements in the returned output experience reflects the timing of the event
that triggers the agent execution.
Example: SampleTime=-1
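A minimal sketch, assuming the property can be assigned directly on an existing agent object:
% Inspect the current sample time (default is 1).
Ts = agent.SampleTime;

% Use event-based execution; in Simulink, the RL Agent block then
% inherits its sample time from the parent subsystem.
agent.SampleTime = -1;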
Object Functions
Examples
Create MBPO Agent
Create an environment interface and extract observation and action specifications.
env = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
Create a base off-policy agent. For this example, use a SAC agent.
agentOpts = rlSACAgentOptions;
agentOpts.MiniBatchSize = 256;
initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);
baseagent = rlSACAgent(obsInfo,actInfo,initOpts,agentOpts);
Check your agent with a random input observation.
getAction(baseagent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
{[-7.2875]}
The neural network environment uses a function approximator object to approximate the environment transition function. The function approximator object uses one or more neural networks as its approximation model. To account for modeling uncertainty, you can specify multiple transition models. For this example, create a single transition model.
Create a neural network to use as the approximation model within the transition function object. Define each network path as an array of layer objects. Specify names for the input and output layers, so you can later explicitly associate them with the appropriate channels.
% Observation and action paths
obsPath = featureInputLayer(obsInfo.Dimension(1),Name="obsIn");
actionPath = featureInputLayer(actInfo.Dimension(1),Name="actIn");

% Common path: concatenate along dimension 1
commonPath = [concatenationLayer(1,2,Name="concat")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(obsInfo.Dimension(1),Name="nextObsOut")];

% Add layers to layerGraph object
transNet = layerGraph(obsPath);
transNet = addLayers(transNet,actionPath);
transNet = addLayers(transNet,commonPath);

% Connect layers
transNet = connectLayers(transNet,"obsIn","concat/in1");
transNet = connectLayers(transNet,"actIn","concat/in2");

% Convert to dlnetwork object
transNet = dlnetwork(transNet);

% Display number of weights
summary(transNet)
   Initialized: true
   Number of learnables: 4.8k
   Inputs:
      1   'obsIn'   4 features
      2   'actIn'   1 features
Create the transition function approximator object.
transitionFcnAppx = rlContinuousDeterministicTransitionFunction( ...
    transNet,obsInfo,actInfo, ...
    ObservationInputNames="obsIn", ...
    ActionInputNames="actIn", ...
    NextObservationOutputNames="nextObsOut");
Create a neural network to use as a reward model for the reward function approximator object.
% Observation and action paths
actionPath = featureInputLayer(actInfo.Dimension(1),Name="actIn");
nextObsPath = featureInputLayer(obsInfo.Dimension(1),Name="nextObsIn");

% Common path: concatenate along dimension 1
commonPath = [concatenationLayer(1,2,Name="concat")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(1)];

% Add layers to layerGraph object
rewardNet = layerGraph(nextObsPath);
rewardNet = addLayers(rewardNet,actionPath);
rewardNet = addLayers(rewardNet,commonPath);

% Connect layers
rewardNet = connectLayers(rewardNet,"nextObsIn","concat/in1");
rewardNet = connectLayers(rewardNet,"actIn","concat/in2");

% Convert to dlnetwork object
rewardNet = dlnetwork(rewardNet);

% Display number of weights
summary(rewardNet)
   Initialized: true
   Number of learnables: 8.8k
   Inputs:
      1   'nextObsIn'   4 features
      2   'actIn'       1 features
Create the reward function approximator object.
rewardFcnAppx = rlContinuousDeterministicRewardFunction( ...
    rewardNet,obsInfo,actInfo, ...
    ActionInputNames="actIn", ...
    NextObservationInputNames="nextObsIn");
Create a neural network to use as an is-done model for the is-done function approximator object.
% Define main path
net = [featureInputLayer(obsInfo.Dimension(1),Name="nextObsIn")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(2)
    softmaxLayer(Name="isdoneOut")];

% Convert to layerGraph object
isDoneNet = layerGraph(net);

% Convert to dlnetwork object
isDoneNet = dlnetwork(isDoneNet);

% Display number of weights
summary(isDoneNet)
   Initialized: true
   Number of learnables: 4.6k
   Inputs:
      1   'nextObsIn'   4 features
Create the is-done function approximator object.
isdoneFcnAppx = rlIsDoneFunction(isDoneNet,obsInfo,actInfo, ...
    NextObservationInputNames="nextObsIn");
Create the neural network environment using the observation and action specifications and the three function approximator objects.
generativeEnv = rlNeuralNetworkEnvironment( ...
    obsInfo,actInfo, ...
    transitionFcnAppx,rewardFcnAppx,isdoneFcnAppx);
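To account for modeling uncertainty, you can instead supply an array of transition function approximators. The following sketch assumes a second approximator, transitionFcnAppx2, built the same way as transitionFcnAppx but from a separately initialized network; the name is hypothetical and used only for illustration.
% Hypothetical two-model ensemble of transition approximators
generativeEnvEnsemble = rlNeuralNetworkEnvironment( ...
    obsInfo,actInfo, ...
    [transitionFcnAppx,transitionFcnAppx2], ...
    rewardFcnAppx,isdoneFcnAppx);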
Specify options for creating an MBPO agent. Specify the optimizer options for the transition network and use default values for all other options.
MBPOAgentOpts = rlMBPOAgentOptions;
MBPOAgentOpts.TransitionOptimizerOptions = rlOptimizerOptions( ...
    LearnRate=1e-4, ...
    GradientThreshold=1.0);
Create the MBPO agent.
agent = rlMBPOAgent(baseagent,generativeEnv,MBPOAgentOpts);
Check your agent with a random input observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
{[7.8658]}
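You can then train the agent against the real environment using rlTrainingOptions and train. The option values below are placeholders to adjust for your task.
% Illustrative training options; tune these values for your task.
trainOpts = rlTrainingOptions( ...
    MaxEpisodes=200, ...
    MaxStepsPerEpisode=500, ...
    StopTrainingCriteria="AverageReward", ...
    StopTrainingValue=470);

% Train the MBPO agent. The agent collects real experiences from env and
% augments them with experiences generated by its internal model.
trainResult = train(agent,env,trainOpts);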
Version History
Introduced in R2022a