
Train Reinforcement Learning Agent for Simple Contextual Bandit Problem

This example shows how to solve a contextual bandit problem [1] using reinforcement learning by training DQN and Q agents. For more information on these agents, see Deep Q-Network (DQN) Agent and Q-Learning Agent.

In bandit problems, the environment has no dynamics (that is, the state is constant and there are no state transitions), so the reward depends only on the current action and, for contextual bandits, on the constant observation. In these problems, the observation is also called the context. Equivalently, in a contextual bandit problem the agent receives the initial observation (context), selects an action, receives a reward, and the episode terminates.

Since neither rewards nor observations are influenced by previous actions or observations, the environment does not evolve along the time dimension, and there is no sequential decision making. The problem then reduces to finding the action that maximizes the expected immediate reward (given the context, if present). Single-armed bandit problems are just special cases of multi-armed bandit problems in which the action is a scalar instead of a vector.
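Written out, with $d$ denoting the is-done flag and $\gamma$ a discount factor (notation introduced here for illustration only), the usual Q-learning target loses its bootstrapped term because $d=1$ after the single step. The action value is then just the expected immediate reward, and the optimal action is its maximizer:

$$
Q(s,a) = \mathbb{E}\left[\, r + \gamma\,(1-d)\max_{a'} Q(s',a') \,\middle|\, s,a \right] = \mathbb{E}[\, r \mid s,a \,] \quad (d=1),
\qquad
a^\ast(s) = \arg\max_{a} \mathbb{E}[\, r \mid s,a \,]
$$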

The following figure shows how multi-armed bandits and contextual bandits are special cases of reinforcement learning problems.

Contextual bandits can be used for various applications such as hyperparameter tuning, recommender systems, medical treatment, and 5G communication.

Supervised learning problems can also be recast as contextual bandit problems. For example, a classification problem can be recast as a contextual bandit problem in which the observation $x$ is an element to be classified, the action $a$ is the agent's guess of the class to which $x$ belongs, and the reward indicates whether the guess is correct. Similarly, a regression problem in which a function $y=f(x)$ is approximated by a function $a=\hat{f}(x,\theta)$ of the parameter vector $\theta$ can be recast as a contextual bandit problem in which the observation $x$ is an element of the domain of $f$, the action $a$ is the agent's guess of the value $y$, and the reward indicates how close $a$ is to $y$ (for example, the reward could be $-(y-a)^2$). Because the reward (which can only be a scalar) intrinsically carries less information than the true class of $x$ (for classification) or the value $y=f(x)$ (for regression), you can generally expect training to take considerably longer in the reinforcement learning setting than in the corresponding supervised learning setting.
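A minimal sketch of the two reward signals described above. The handles `classificationReward` and `regressionReward` are hypothetical names used only for illustration, not toolbox functions. Note that the classification reward reports only whether the guess was right, never which class was correct, which is the information loss mentioned above.

```matlab
% Hypothetical reward functions for recasting supervised problems as
% contextual bandits (illustrative names, not part of any toolbox).

% Classification: reward 1 if the guessed class a matches the true class y.
classificationReward = @(y,a) double(a == y);

% Regression: negative squared error between the guess a and the target y.
regressionReward = @(y,a) -(y - a).^2;

% True class 2, agent guesses class 3: the reward is 0.
r1 = classificationReward(2,3);

% Target y = f(x) = 1.5, agent guesses 1: the reward is -0.25.
r2 = regressionReward(1.5,1);
```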

Fix Random Seed Generator to Improve Reproducibility

The example code may involve computation of random numbers at various stages such as initialization of the agent, creation of the actor and critic, resetting the environment during simulations, initializing the environment state, generating observations (for stochastic environments), generating exploration actions, and sampling minibatches of experiences for learning. Fixing the random number stream preserves the sequence of the random numbers every time you run the code and improves reproducibility of results. You will fix the random number stream at various locations in the example.

Fix the random number stream with seed 0 and the Mersenne Twister algorithm. For more information on random number generation, see rng.

previousRngState = rng(0,"twister");

The output previousRngState is a structure that contains information about the previous state of the stream. You will restore the state at the end of the example.
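To see what fixing the seed buys you, the short sketch below (variable names are illustrative) draws three numbers, re-seeds with the same seed and algorithm, draws again, and then restores the saved generator state, mirroring what the example does at its end.

```matlab
% Save the current generator settings, then seed with 0 and Mersenne Twister.
savedState = rng(0,"twister");
firstDraw = rand(1,3);

% Re-seeding with the same seed and algorithm reproduces the same sequence.
rng(0,"twister");
secondDraw = rand(1,3);

% Restore the generator settings that were active before this block ran.
rng(savedState);
```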

Environment

The contextual bandit environment in this example is defined as follows:

  • Observation (discrete): {1, 2}

The context (initial observation) is sampled randomly.

$$\Pr(s=1) = 0.5, \qquad \Pr(s=2) = 0.5$$

  • Action (discrete): {1, 2, 3}

  • Reward:

Rewards in this environment are stochastic. The reward distribution for each observation-action pair is defined as follows.

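1. $s=1,\ a=1$: $\Pr(r=5 \mid s=1,a=1)=0.3$, $\Pr(r=2 \mid s=1,a=1)=0.7$
2. $s=1,\ a=2$: $\Pr(r=10 \mid s=1,a=2)=0.1$, $\Pr(r=1 \mid s=1,a=2)=0.9$
3. $s=1,\ a=3$: $\Pr(r=3.5 \mid s=1,a=3)=1$
4. $s=2,\ a=1$: $\Pr(r=10 \mid s=2,a=1)=0.2$, $\Pr(r=2 \mid s=2,a=1)=0.8$
5. $s=2,\ a=2$: $\Pr(r=3 \mid s=2,a=2)=1$
6. $s=2,\ a=3$: $\Pr(r=5 \mid s=2,a=3)=0.5$, $\Pr(r=0.5 \mid s=2,a=3)=0.5$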

Note that the agent does not know these distributions.

  • Is-Done signal: Since this is a contextual bandit problem, each episode has only one step. Hence, the Is-Done signal is always 1.
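A minimal, toolbox-free sketch of the reward model above, assuming the cell-array layout shown (rows index $s$, columns index $a$). `sampleReward` is a hypothetical helper, not the code in ToyContextualBanditEnvironment. It simulates many one-step episodes per pair and averages the rewards, which is, in effect, what any learner has to do because the agent never sees the distributions themselves.

```matlab
% Reward values and probabilities for each (s,a) pair, transcribed from
% the distributions listed above. Rows index s, columns index a.
vals  = {[5 2],  [10 1], 3.5;
         [10 2], 3,      [5 0.5]};
probs = {[0.3 0.7], [0.1 0.9], 1;
         [0.2 0.8], 1,         [0.5 0.5]};

% Draw one reward: the outcome index is 1 plus the number of cumulative
% thresholds (all but the last) that a uniform draw meets or exceeds.
sampleReward = @(s,a) vals{s,a}(1 + sum(rand >= cumsum(probs{s,a}(1:end-1))));

rng(0,"twister");            % reproducible sketch
numSamples = 2e4;            % episodes per (s,a) pair (assumed)
empiricalMean = zeros(2,3);
for s = 1:2
    for a = 1:3
        total = 0;
        for k = 1:numSamples
            % One episode: fixed context s, action a, one reward, done.
            total = total + sampleReward(s,a);
        end
        empiricalMean(s,a) = total/numSamples;
    end
end
disp(empiricalMean)
```

With 20,000 episodes per pair, the averages land close to the expected rewards computed later in this example, and the two deterministic pairs reproduce their rewards exactly.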

Create Environment Object

The contextual bandit environment is implemented in the file ToyContextualBanditEnvironment, located in this example folder. For more information on how to implement a custom environment using the class template, see Create Custom Environment from Class Template.

Display the environment class. Note how the rewards are calculated in the environment step function, and how the observation (context) remains constant at its initial value.

type("ToyContextualBanditEnvironment.m")
classdef ToyContextualBanditEnvironment < rl.env.MATLABEnvironment
    
    %% Properties (set properties' attributes accordingly)    
    properties
        % Initialize state
        State = zeros(1,1)
    end
    
    properties(Access = protected)
        % Initialize internal flag to indicate episode termination.
        IsDone = false        
    end

    %% Necessary Methods
    methods              
        % Constructor method creates an instance of the environment.
        % Change class name and constructor name accordingly.
        function this = ToyContextualBanditEnvironment()
            % Initialize Observation settings
            % Observation = {s1, s2}, discrete
            obsInfo = rlFiniteSetSpec([1 2]);
            
            % Initialize Action settings
            % Action = {a1, a2, a3}, discrete
            actInfo = rlFiniteSetSpec([1 2 3]);            
            
            % Implement built-in functions of RL env
            this = this@rl.env.MATLABEnvironment(obsInfo,actInfo);
        end
        
        % Apply system dynamics and simulate the environment
        % with the given action for one step.
        function [Observation,Reward,IsDone,aux] = step(this,Action)
            aux = [];

            % The action doesn't affect the next state
            % in a contextual bandit problem.
            Observation = this.State;
            
            % Get reward            
            if this.State == 1
                if Action == 1
                    % E(reward) = 2.9
                    if rand < 0.3
                        Reward = 5;
                    else
                        Reward = 2;
                    end
                elseif Action == 2
                    % E(reward) = 1.9
                    if rand < 0.1
                        Reward = 10;
                    else
                        Reward = 1;
                    end
                elseif Action == 3
                    % E(reward) = 3.5
                    Reward = 3.5;                    
                end
            elseif this.State == 2
                if Action == 1
                    % E(reward) = 3.6
                    if rand < 0.2
                        Reward = 10;
                    else
                        Reward = 2;
                    end
                elseif Action == 2
                    % E(reward) = 3.0
                    Reward = 3.0;
                elseif Action == 3
                    % E(reward) = 2.75
                    if rand < 0.5
                        Reward = 5;
                    else
                        Reward = 0.5;
                    end
                end
            end
            
            % Get IsDone.
            IsDone = true;
            this.IsDone = IsDone;
            % (Optional) Use notifyEnvUpdated to signal 
            % that the environment has been updated 
            % (e.g. to update visualization).
            notifyEnvUpdated(this);
        end
        
        % Reset environment to initial state 
        % and output initial observation.
        function InitialObservation = reset(this)
            % Pr(s1) = 0.5, Pr(s2) = 0.5
            InitialObservation = randi(2);
            this.State = InitialObservation;

            % (Optional) Use notifyEnvUpdated to signal 
            % that the environment has been updated 
            % (e.g. to update visualization).
            notifyEnvUpdated(this);
        end
    end
    %% Optional Methods (set methods' attributes accordingly)
    methods               
        % (Optional) Visualization method
        function plot(this)            
            % Update the visualization
            envUpdatedCallback(this)
        end
        
        % (Optional) Properties validation through set methods
        function set.State(this,state)
            mustBeMember(state,[1,2])
            this.State = state;
            notifyEnvUpdated(this);
        end
    end    
end

Create the environment object.

env = ToyContextualBanditEnvironment;

Get observation and action specification objects.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Create a DQN Agent

Create a DQN agent option object. For more information, see rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions(...
    UseDoubleDQN = false, ...    
    TargetSmoothFactor = 1, ...
    TargetUpdateFrequency = 4, ...     
    MiniBatchSize = 64,...
    MaxMiniBatchPerEpoch = 2);
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.0005;
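How much exploration remains at the end of training depends on the decay rate. The sketch below assumes the multiplicative update `Epsilon = Epsilon*(1 - EpsilonDecay)` applied once per training step while `Epsilon` exceeds `EpsilonMin`, with assumed default values `Epsilon = 1` and `EpsilonMin = 0.01`. Check the rlDQNAgentOptions documentation for your release before relying on these details.

```matlab
% Epsilon schedule sketch, assuming the multiplicative update
%   Epsilon <- Epsilon*(1 - EpsilonDecay)   (while Epsilon > EpsilonMin)
% applied once per training step, with one step per episode here.
epsilon      = 1;        % assumed initial value
epsilonDecay = 0.0005;   % value set in agentOpts above
epsilonMin   = 0.01;     % assumed floor
numSteps     = 3000;     % 3000 episodes x 1 step each

epsHistory = zeros(1,numSteps);
for k = 1:numSteps
    if epsilon > epsilonMin
        epsilon = epsilon*(1 - epsilonDecay);
    end
    epsHistory(k) = epsilon;
end
fprintf("Epsilon after %d steps: %.4f\n", numSteps, epsilon)
```

Under these assumptions, $(1-0.0005)^{3000} \approx e^{-1.5} \approx 0.22$, so the agent still explores in roughly one episode out of five at the end of the 3000-episode run.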

To create an agent with a default network structure in which each hidden layer has 16 neurons, use rlAgentInitializationOptions.

initOpts = rlAgentInitializationOptions(NumHiddenUnit = 16);

Create a DQN agent. For more information, see rlDQNAgent.

dqnAgent = rlDQNAgent(obsInfo, actInfo, initOpts, agentOpts);

Train the DQN Agent

To train the agent, first specify the training options. For this example, use the following options:

  • Train for 3000 episodes.

  • Since this is a contextual bandit problem, and each episode has only one step, set MaxStepsPerEpisode to 1.

For more information, see rlTrainingOptions.

trainOpts = rlTrainingOptions(...
    MaxEpisodes=3000, ...
    MaxStepsPerEpisode=1, ...
    Verbose=false, ...
    Plots="training-progress",...
    StopTrainingCriteria="None",...
    StopTrainingValue="None"); 

Train the agent using the train function. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % Train the agent
    trainingStats = train(dqnAgent,env,trainOpts);
else
    % Load the pretrained agent for the example
    load("ToyContextualBanditDQNAgent.mat","dqnAgent")
end

Validate the DQN Agent

For this example, you know the distribution of the rewards, and you can compute the optimal actions. Validate the agent's performance by comparing these optimal actions with the actions selected by the agent.

First, compute the true expected rewards from the true reward distributions.

1. The expected reward of each action at s=1 is as follows.

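$$
\begin{aligned}
a=1:&\quad \mathbb{E}[r \mid s=1,a=1] = 0.3\cdot 5 + 0.7\cdot 2 = 2.9\\
a=2:&\quad \mathbb{E}[r \mid s=1,a=2] = 0.1\cdot 10 + 0.9\cdot 1 = 1.9\\
a=3:&\quad \mathbb{E}[r \mid s=1,a=3] = 3.5
\end{aligned}
$$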

Hence, the optimal action is 3 when s=1.

2. The expected reward of each action at s=2 is as follows.

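$$
\begin{aligned}
a=1:&\quad \mathbb{E}[r \mid s=2,a=1] = 0.2\cdot 10 + 0.8\cdot 2 = 3.6\\
a=2:&\quad \mathbb{E}[r \mid s=2,a=2] = 3.0\\
a=3:&\quad \mathbb{E}[r \mid s=2,a=3] = 0.5\cdot 5 + 0.5\cdot 0.5 = 2.75
\end{aligned}
$$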

Hence, the optimal action is 1 when s=2.

With enough sampling, the Q-value estimates of the trained agent should approach the true expected rewards.

Collect the true expected rewards in the ExpectedRewards variable.

ExpectedRewards = zeros(2,3);
ExpectedRewards(1,1) = 0.3*5 + 0.7*2;
ExpectedRewards(1,2) = 0.1*10 + 0.9*1;
ExpectedRewards(1,3) = 3.5;
ExpectedRewards(2,1) = 0.2*10 + 0.8*2;
ExpectedRewards(2,2) = 3.0;
ExpectedRewards(2,3) = 0.5*5 + 0.5*0.5;
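The optimal actions stated above can also be read off programmatically by taking the row-wise maximum of the expected-reward table. This sketch recomputes ExpectedRewards in a single statement so that it runs on its own.

```matlab
% Expected reward table: rows are contexts s = 1,2; columns are actions a = 1,2,3.
ExpectedRewards = [0.3*5 + 0.7*2,  0.1*10 + 0.9*1, 3.5;
                   0.2*10 + 0.8*2, 3.0,            0.5*5 + 0.5*0.5];

% Row-wise maximum: best expected reward and the greedy (optimal) action per context.
[bestReward, optimalAction] = max(ExpectedRewards, [], 2);
disp(table((1:2)', optimalAction, bestReward, ...
    VariableNames=["s","OptimalAction","ExpectedReward"]))
```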

Visualize the true expected rewards using the function localPlotQvalues defined at the end of the example.

localPlotQvalues(ExpectedRewards,"True Expected Rewards")

[Figure: True Expected Rewards]

Now, validate whether the DQN agent learns the optimal behavior. Use getAction to return the agent action given an input observation.

If the state is 1, the optimal action is 3.

observation = 1;
getAction(dqnAgent,observation)
ans = 1x1 cell array
    {[3]}

The agent selects the optimal action.

If the state is 2, the optimal action is 1.

observation = 2;
getAction(dqnAgent,observation)
ans = 1x1 cell array
    {[1]}

The agent selects the optimal action. Thus, the DQN agent has learned the optimal behavior.

Next, compare the learned Q-value function with the true expected rewards.

Use getCritic to extract the critic from the trained agent, and getValue to return the Q-values of all three actions for a given observation (using the learned value function).

dqnCritic = getCritic(dqnAgent);
qValues = zeros(2,3);
for s = 1:2
    qValues(s,:) = getValue(dqnCritic, {s});
end

Visualize the Q values for the DQN agent.

localPlotQvalues(qValues, "Q values of DQN agent")

[Figure: Q values of DQN agent]

The learned Q-values are close to the true expected rewards computed above.

Create a Q-Learning Agent

Ensure reproducibility by fixing the random number generator seed.

rng(0,"twister")

For this example, use a table as the approximation model for the critic.

Create a table using the observation and action specifications from the environment.

qTable = rlTable(obsInfo, actInfo);

Create an rlQValueFunction critic.

critic = rlQValueFunction(qTable, obsInfo, actInfo);

To set the agent exploration options, create an rlQAgentOptions object and set its epsilon-greedy exploration parameters.

opt = rlQAgentOptions;
opt.EpsilonGreedyExploration.Epsilon = 1;
opt.EpsilonGreedyExploration.EpsilonDecay = 0.0005;

Create a Q agent. For more information, see rlQAgent.

qAgent = rlQAgent(critic,opt);
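For intuition about what the Q agent estimates, here is a toolbox-free sketch of tabular Q-learning on the same bandit. Because each episode terminates after one step, the Q-learning target reduces to the observed reward, so each update moves Q(s,a) toward r. Assumptions: the same epsilon schedule as in the options above, a 1/N(s,a) step size (so each entry is the running mean of its observed rewards), and the hypothetical `sampleReward` helper. rlQAgent itself updates its table with its own optimizer and learning rate, so its numbers will differ.

```matlab
% Tabular Q-learning sketch for the one-step bandit (no toolbox needed).
% Each episode ends after one step, so the target r + gamma*max(Q(s',:))*(1 - isDone)
% reduces to r.
vals  = {[5 2],  [10 1], 3.5;
         [10 2], 3,      [5 0.5]};
probs = {[0.3 0.7], [0.1 0.9], 1;
         [0.2 0.8], 1,         [0.5 0.5]};
% Hypothetical helper: sample one reward for context s and action a.
sampleReward = @(s,a) vals{s,a}(1 + sum(rand >= cumsum(probs{s,a}(1:end-1))));

rng(0,"twister");
Q = zeros(2,3);          % Q-table: rows = contexts, columns = actions
N = zeros(2,3);          % visit counts per (s,a)
epsilon = 1; epsilonDecay = 0.0005; epsilonMin = 0.01;
numEpisodes = 3000;

for episode = 1:numEpisodes
    s = randi(2);                        % reset: sample the context
    if rand < epsilon
        a = randi(3);                    % explore: uniform random action
    else
        [~, a] = max(Q(s,:));            % exploit: greedy action
    end
    r = sampleReward(s,a);               % step: one reward, episode done
    N(s,a) = N(s,a) + 1;
    alpha = 1/N(s,a);                    % sample-average step size (assumed)
    Q(s,a) = Q(s,a) + alpha*(r - Q(s,a));
    if epsilon > epsilonMin
        epsilon = epsilon*(1 - epsilonDecay);
    end
end
disp(Q)
```

With a 1/N step size, any visited pair with a deterministic reward, such as (s=1, a=3), holds exactly that reward from its first visit onward.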

Train the Q-Learning Agent

To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(qAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load("ToyContextualBanditQAgent.mat","qAgent")
end

Validate the Q-Learning Agent

When the state is 1, the optimal action is 3.

observation = 1;
getAction(qAgent,observation)
ans = 1x1 cell array
    {[3]}

The agent selects the optimal action.

When the state is 2, the optimal action is 1.

observation = 2;
getAction(qAgent,observation)
ans = 1x1 cell array
    {[1]}

The agent selects the optimal action.

Next, compare the learned Q-value function with the true expected rewards.

Use getCritic to extract the critic from the trained agent, and getValue to return the Q-value of each observation-action pair (using the learned value function).

qCritic = getCritic(qAgent);
qValues = zeros(2,3);
for s = 1:2
    for a = 1:3
        qValues(s,a) = getValue(qCritic, {s}, {a});
    end
end

Visualize the Q values for the Q-learning agent.

localPlotQvalues(qValues, "Q values of Q agent")

[Figure: Q values of Q agent]

The learned Q-values are close to the true expected rewards. The Q-values for deterministic rewards, Q(s=1, a=3) and Q(s=2, a=2), are the same as the true expected rewards.

Note that the corresponding Q-values learned by the DQN agent, while close, are not identical to the true values. This happens because the DQN agent uses a neural network instead of a table as its function approximation model.

Restore the random number stream using the information stored in previousRngState.

rng(previousRngState);

Local Function

function localPlotQvalues(QValues, titleText)
    % Visualize Q values 
    figure;
    imagesc(QValues,[1,4])
    colormap("autumn")
    title(titleText)
    colorbar
    set(gca,"Xtick",1:3,"XTickLabel",{"a=1", "a=2", "a=3"})
    set(gca,"Ytick",1:2,"YTickLabel",{"s=1", "s=2"})

    % Plot values on the image
    x = repmat(1:size(QValues,2), size(QValues,1), 1);
    y = repmat(1:size(QValues,1), size(QValues,2), 1)';
    QValuesStr = num2cell(QValues);
    QValuesStr = cellfun(@num2str, QValuesStr, UniformOutput=false);
    text(x(:), y(:), QValuesStr, HorizontalAlignment = "Center")
end

Reference

[1] Sutton, Richard S., and Andrew G. Barto. Reinforcement Learning: An Introduction. Second edition. Adaptive Computation and Machine Learning Series. Cambridge, Massachusetts: The MIT Press, 2018.
