rlDQNAgent

Deep Q-network reinforcement learning agent

Description

The deep Q-network (DQN) algorithm is a model-free, online, off-policy reinforcement learning method. A DQN agent is a value-based reinforcement learning agent that trains a critic to estimate the return or future rewards. DQN is a variant of Q-learning.
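As an informal illustration of the quantity the critic learns to approximate, the following sketch computes the standard one-step DQN regression target from hypothetical values. It is a conceptual sketch only, not the toolbox implementation, and the numbers are made up.

% Conceptual sketch (hypothetical values), not the toolbox internals.
% The critic Q(s,a) is regressed toward y = r + gamma*max over a' of Qtarget(s',a').
gamma = 0.99;                    % discount factor
r = 1;                           % reward received for taking action a in state s
QtargetNext = [0.4 1.2];         % target-critic values Qtarget(s',a') for two discrete actions
y = r + gamma*max(QtargetNext)   % regression target for Q(s,a)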

For more information, see Deep Q-Network Agents.

For more information on the different types of reinforcement learning agents, see Reinforcement Learning Agents.

Creation

Description


agent = rlDQNAgent(critic,agentOptions) creates a DQN agent with the specified critic network and sets the AgentOptions property.

Input Arguments


Critic network representation, specified as an rlQValueRepresentation object. For more information on creating critic representations, see Create Policy and Value Function Representations.

Your critic representation can use a recurrent neural network as its function approximator. However, only the multi-output Q-value function representation supports recurrent neural networks. For an example, see Create DQN Agent with Recurrent Neural Network.

Properties


Agent options, specified as an rlDQNAgentOptions object.

Experience buffer, specified as an ExperienceBuffer object. During training the agent stores each of its experiences (S,A,R,S') in a buffer. Here:

  • S is the current observation of the environment.

  • A is the action taken by the agent.

  • R is the reward for taking action A.

  • S' is the next observation after taking action A.

For more information on how the agent samples experience from the buffer during training, see Deep Q-Network Agents.
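As a rough illustration (hypothetical field names and sizes, not the toolbox implementation), a replay buffer stores such tuples and training draws uniformly random mini-batches from it:

% Conceptual sketch of an experience buffer; field names and sizes are hypothetical.
bufferLength = 1000;
buffer(bufferLength) = struct('S',[],'A',[],'R',[],'Snext',[]);  % preallocate tuple storage
% ... during training, the agent appends one tuple (S,A,R,S') per environment step ...
miniBatchSize = 32;
idx = randi(bufferLength,[miniBatchSize 1]);   % uniform random sample of buffer indices
miniBatch = buffer(idx);                       % mini-batch used for one critic update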

Object Functions

train - Train a reinforcement learning agent within a specified environment
sim - Simulate a trained reinforcement learning agent within a specified environment
getActor - Get actor representation from reinforcement learning agent
setActor - Set actor representation of reinforcement learning agent
getCritic - Get critic representation from reinforcement learning agent
setCritic - Set critic representation of reinforcement learning agent
generatePolicyFunction - Create function that evaluates trained policy of reinforcement learning agent
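For example, assuming agent is a DQN agent such as the one created in the Examples section below, a brief usage sketch of some of these functions is:

critic = getCritic(agent);          % extract the current critic representation
agent  = setCritic(agent,critic);   % reassign a (possibly modified) critic
generatePolicyFunction(agent);      % generate a standalone policy evaluation function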

Examples


Create an environment interface and obtain its observation and action specifications. For this example, load the predefined environment used for the discrete cart-pole system.

% load predefined environment
env = rlPredefinedEnv("CartPole-Discrete");

% get observation and action specification info
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Create a critic representation.

% create a critic network to be used as underlying approximator
statePath = [
    imageInputLayer([4 1 1], 'Normalization', 'none', 'Name', 'state')
    fullyConnectedLayer(24, 'Name', 'CriticStateFC1')
    reluLayer('Name', 'CriticRelu1')
    fullyConnectedLayer(24, 'Name', 'CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1], 'Normalization', 'none', 'Name', 'action')
    fullyConnectedLayer(24, 'Name', 'CriticActionFC1')];
commonPath = [
    additionLayer(2,'Name', 'add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1, 'Name', 'output')];
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);    
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');

% set some options for the critic
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);

% create the critic based on the network approximator
critic = rlQValueRepresentation(criticNetwork,obsInfo,actInfo,...
    'Observation',{'state'},'Action',{'action'},criticOpts);

Specify agent options, and create a DQN agent using the environment and critic.

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetUpdateMethod',"periodic", ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);

agent = rlDQNAgent(critic,agentOpts)
agent = 
  rlDQNAgent with properties:

        AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    ExperienceBuffer: [1x1 rl.util.ExperienceBuffer]

To check your agent, use getAction to return the action from a random observation.

getAction(agent,{rand(4,1)})
ans = 10

You can now test and train the agent against the environment.
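For instance, a minimal training and simulation sketch might look as follows; the training option values are illustrative and not tuned for this environment.

% illustrative training options (not tuned for cart-pole)
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',500, ...
    'MaxStepsPerEpisode',500, ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',480);

% train the agent against the environment
trainingStats = train(agent,env,trainOpts);

% simulate the trained agent for one episode
simOpts = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOpts);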

Create DQN Agent with Recurrent Neural Network

Create an environment and obtain observation and action information.

env = rlPredefinedEnv('CartPole-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
numObs = obsInfo.Dimension(1);
numDiscreteAct = numel(actInfo.Elements);

Create a recurrent deep neural network for your critic. To create a recurrent neural network, use a sequenceInputLayer as the input layer and include an lstmLayer as one of the other network layers.

For DQN agents, only the multi-output Q-value function representation supports recurrent neural networks.

criticNetwork = [
    sequenceInputLayer(numObs,'Normalization','none','Name','state')
    fullyConnectedLayer(50, 'Name', 'CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    lstmLayer(20,'OutputMode','sequence','Name','CriticLSTM')
    fullyConnectedLayer(20,'Name','CriticStateFC2')
    reluLayer('Name','CriticRelu2')
    fullyConnectedLayer(numDiscreteAct,'Name','output')];

Create a representation for your critic using the recurrent neural network.

criticOptions = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
critic = rlQValueRepresentation(criticNetwork,obsInfo,actInfo,...
    'Observation',{'state'},criticOptions);

Specify options for creating the DQN agent. To use a recurrent neural network, you must specify a SequenceLength greater than 1.

agentOptions = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetSmoothFactor',5e-3, ...
    'ExperienceBufferLength',1e6, ...
    'SequenceLength',20);
agentOptions.EpsilonGreedyExploration.EpsilonDecay = 1e-4;
agent = rlDQNAgent(critic,agentOptions);
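As in the previous example, you can check the agent by returning the action for a random observation.

getAction(agent,{rand(numObs,1)})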

Introduced in R2019a