MATLAB Answers

RL Toolbox: Proximal Policy Optimisation

Robert Gordon on 8 Aug 2019
I just wanted to ask if anyone is aware of a proximal policy optimisation (PPO) reinforcement learning implementation available for the MATLAB RL Toolbox. I know that you can create a custom agent class, but I wanted to see if anyone else has implemented it before?
Thanks!


Answers (1)

Emmanouil Tzorakoleftherakis
Hi Robert,
Reinforcement Learning Toolbox in R2019b has a PPO implementation for discrete action spaces. Future releases will include continuous action spaces as well.
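For reference, a rough sketch of a discrete-action PPO setup, using one of the predefined environments and the representation objects that appear later in this thread (exact class names, options, and network sizes here are placeholders and vary by release):
% Minimal discrete-action PPO sketch (assumes the predefined cart-pole environment)
env = rlPredefinedEnv('CartPole-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
numObs = obsInfo.Dimension(1);
% Critic: state-value (V) network -- observations in, one scalar value out
criticNet = [
    imageInputLayer([numObs 1 1],'Normalization','none','Name','state')
    fullyConnectedLayer(64,'Name','fc1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(1,'Name','value')];
critic = rlValueRepresentation(criticNet,obsInfo,'Observation',{'state'});
% Actor: one output per discrete action
actorNet = [
    imageInputLayer([numObs 1 1],'Normalization','none','Name','state')
    fullyConnectedLayer(64,'Name','afc1')
    reluLayer('Name','arelu1')
    fullyConnectedLayer(numel(actInfo.Elements),'Name','action')];
actor = rlStochasticActorRepresentation(actorNet,obsInfo,actInfo,'Observation',{'state'});
agent = rlPPOAgent(actor,critic);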
I hope this helps.

  6 Comments

Emmanouil Tzorakoleftherakis
Hello,
Change your critic to something like the following and training should work:
statePath = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','S')
    fullyConnectedLayer(200,'Name','CriticStateFC1') %128
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(300,'Name','CriticStateFC2') %200
    reluLayer('Name','CriticRelu2')
    fullyConnectedLayer(400,'Name','CriticStateFC3')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','CriticPreOutput')]; %200
criticNetwork = layerGraph(statePath);
There were two issues previously: a) PPO uses a V (state-value) network for the critic, but you had created what was essentially a Q network (it had an action input channel), and b) you don't need a regressionLayer at the end of the critic; that was only there to pass the analysis check in Deep Network Designer.
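For context, a V-type critic is wrapped as an observation-only value representation, whereas a Q-type critic (as used by DDPG) also takes an action channel. A rough sketch following the layer names above (the option values, the qNetwork variable, and the 'A' input name are illustrative, not from the original code):
% V-type critic for PPO: observations in, scalar value out (no action channel)
criticOpts = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
critic = rlValueRepresentation(criticNetwork,observationInfo, ...
    'Observation',{'S'},criticOpts);
% A Q-type critic (DDPG-style) would instead take both channels, e.g.:
% critic = rlQValueRepresentation(qNetwork,observationInfo,actionInfo, ...
%     'Observation',{'S'},'Action',{'A'},criticOpts);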
Camilo Manrique on 26 Mar 2020
It worked indeed, you are right. I completely forgot that PPO uses a state-value function, rather than the Q-value function used with DDPG. Thank you very much for your help.
Weihao Yuan on 22 Aug 2020
Hi Emmanouil, I encountered a similar problem when applying PPO to the ACC model from the DDPG example.
Environment
mdl = 'rlACCMdl';
open_system(mdl)
agentblk = [mdl '/RL Agent'];
% create the observation info
observationInfo = rlNumericSpec([3 1],'LowerLimit',-inf*ones(3,1),'UpperLimit',inf*ones(3,1));
observationInfo.Name = 'observations';
observationInfo.Description = 'information on velocity error and ego velocity';
% action Info
actionInfo = rlNumericSpec([1 1],'LowerLimit',-3,'UpperLimit',2);
actionInfo.Name = 'acceleration';
% define environment
env = rlSimulinkEnv(mdl,agentblk,observationInfo,actionInfo);
Critic
predefinedWeightsandBiases = false;
if predefinedWeightsandBiases
    load('PredefinedWeightsAndBiases.mat');
else
    createNetworkWeights;
end
criticNetwork = [
    imageInputLayer([numObs 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(200,'Name','CriticFC1', ...
        'Weights',weights.criticFC1, ...
        'Bias',bias.criticFC1)
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(100,'Name','CriticFC2', ...
        'Weights',weights.criticFC2, ...
        'Bias',bias.criticFC2)
    reluLayer('Name','CriticRelu2')
    fullyConnectedLayer(1,'Name','CriticOutput', ...
        'Weights',weights.criticOut, ...
        'Bias',bias.criticOut)];
criticOptions = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1,'L2RegularizationFactor',1e-4);
critic = rlValueRepresentation(criticNetwork,observationInfo, ...
    'Observation',{'observation'},criticOptions);
Actor
% observation path layers (3 by 1 input and a 2 by 1 output)
actorNetwork = [
    imageInputLayer([3 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(2,'Name','infc')];
% path layers for the mean value (2 by 1 input and 2 by 1 output)
% using a scalingLayer to scale the range
meanPath = [
    tanhLayer('Name','tanh')
    scalingLayer('Name','ActorScaling','Scale',2.5,'Bias',-0.5)];
% path layers for the variance (2 by 1 input and output)
% using a softplusLayer to keep it nonnegative
variancePath = softplusLayer('Name','Softplus');
% concatenate the two paths (along dimension #3) to form a single (4 by 1) output layer
outLayer = concatenationLayer(3,2,'Name','gaussPars');
% add layers to network object
net = layerGraph(actorNetwork);
net = addLayers(net,meanPath);
net = addLayers(net,variancePath);
net = addLayers(net,outLayer);
% connect layers
net = connectLayers(net,'infc','tanh/in');               % output of inPath to meanPath input
net = connectLayers(net,'infc','Softplus/in');           % output of inPath to variancePath input
net = connectLayers(net,'ActorScaling','gaussPars/in1'); % output of meanPath to gaussPars input #1
net = connectLayers(net,'Softplus','gaussPars/in2');     % output of variancePath to gaussPars input #2
% plot network
plot(net)
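(For completeness, the remaining wrapping steps, not shown above, would look roughly like the following sketch; the option values here are placeholders rather than the exact settings used.)
% Sketch of wrapping the Gaussian actor and building the PPO agent
actorOptions = rlRepresentationOptions('LearnRate',1e-4,'GradientThreshold',1);
actor = rlStochasticActorRepresentation(net,observationInfo,actionInfo, ...
    'Observation',{'observation'},actorOptions);
agentOptions = rlPPOAgentOptions('SampleTime',0.1,'ExperienceHorizon',128);
agent = rlPPOAgent(actor,critic,agentOptions);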
However, the agent stopped training at the 50th episode with the following error:
Error
Error using rl.env.AbstractEnv/simWithPolicy (line 70)
An error occurred while simulating "rlACCMdl" with the agent "agent".
Error in rl.task.SeriesTrainTask/runImpl (line 33)
[varargout{1},varargout{2}] = simWithPolicy(this.Env,this.Agent,simOpts);
Error in rl.task.Task/run (line 21)
[varargout{1:nargout}] = runImpl(this);
Error in rl.task.TaskSpec/internal_run (line 159)
[varargout{1:nargout}] = run(task);
Error in rl.task.TaskSpec/runDirect (line 163)
[this.Outputs{1:getNumOutputs(this)}] = internal_run(this);
Error in rl.task.TaskSpec/runScalarTask (line 187)
runDirect(this);
Error in rl.task.TaskSpec/run (line 69)
runScalarTask(task);
Error in rl.train.SeriesTrainer/run (line 24)
run(seriestaskspec);
Error in rl.train.TrainingManager/train (line 291)
run(trainer);
Error in rl.train.TrainingManager/run (line 160)
train(this);
Error in rl.agent.AbstractAgent/train (line 54)
TrainingStatistics = run(trainMgr);
Caused by:
Error using rl.env.SimulinkEnvWithAgent>localHandleSimoutErrors (line 689)
Invalid input argument type or size such as observation, reward, isdone or loggedSignals.
Error using rl.env.SimulinkEnvWithAgent>localHandleSimoutErrors (line 689)
Standard deviation must be nonnegative. Ensure your representation always outputs nonnegative values for outputs that correspond to the standard deviation.
I tried to find the cause of this bug but failed. I would really appreciate it if you could look into it for me. Thanks a lot.
