Why does my custom SAC agent behave differently from the built-in SAC agent?
I had to implement a custom SAC agent, which I did using MATLAB deep learning automatic differentiation. However, when I compared it with the MATLAB built-in SAC agent on the same task with exactly the same hyperparameters, the custom SAC agent failed to complete the task while the built-in agent succeeded.
Here is the training progress of the built-in agent:

This is the training progress of the custom SAC agent (along with the losses):

Here is the code for the custom SAC agent and the training script:
1. Implementation of the custom SAC agent
classdef MySACAgent < rl.agent.CustomAgent
    properties
        %networks
        actor
        critic1
        critic2
        critic_target1
        critic_target2
        log_alpha%entropy weight(log transformed)
        %training options
        options%Agent options
        %optimizers
        actorOptimizer
        criticOptimizer_1
        criticOptimizer_2
        entWgtOptimizer
        %experience buffers
        obsBuffer
        actionBuffer
        rewardBuffer
        nextObsBuffer
        isDoneBuffer
        rlExpBuffer
        bufferIdx
        bufferLen
        %loss to record
        cLoss
        aLoss
        eLoss
    end
    properties(Access = private)
        Ts
        counter
        numObs
        numAct
    end
    methods
        %constructor
        function obj = MySACAgent(numObs,numAct,obsInfo,actInfo,hid_dim,Ts,options)
            % options' field:MaxBufferLen WarmUpSteps MiniBatchSize 
            % LearningFrequency EntropyLossWeight DiscountFactor
            % OptimizerOptions(cell) PolicyUpdateFrequency TargetEntropy
            % TargetUpdateFrequency TargetSmoothFactor
            % base_seed NumGradientStepsPerUpdate
            %OptimizerOptions(for actor&critic)
            % (required) Call the abstract class constructor.
            rng(options.base_seed);%set random seed
            obj = obj@rl.agent.CustomAgent();
            obj.ObservationInfo = obsInfo;
            obj.ActionInfo = actInfo;
            % obj.SampleTime = Ts;%explicitly assigned for simulink
            obj.Ts = Ts;
            %create networks
            if isempty(hid_dim)
                hid_dim = 256;
            end
            obj.actor = CreateActor(obj,numObs,numAct,hid_dim,obsInfo,actInfo);
            [obj.critic1,obj.critic2,obj.critic_target1,obj.critic_target2] = CreateCritic(obj,numObs,numAct,hid_dim,obsInfo,actInfo);
            obj.options = options;
            assert(options.WarmUpSteps>options.MiniBatchSize,...
                'options.WarmUpSteps must be greater than options.MiniBatchSize');
            %set optimizers
            obj.actorOptimizer = rlOptimizer(options.OptimizerOptions{1});
            obj.criticOptimizer_1 = rlOptimizer(options.OptimizerOptions{2});
            obj.criticOptimizer_2 = rlOptimizer(options.OptimizerOptions{3});
            obj.entWgtOptimizer = rlOptimizer(options.OptimizerOptions{4});
            obj.cLoss=0;
            obj.aLoss=0;
            obj.eLoss=0;
            % (optional) Cache the number of observations and actions.
            obj.numObs = numObs;
            obj.numAct = numAct;
            % (optional) Initialize buffer and counter.
            resetImpl(obj);
            % obj.rlExpBuffer = rlReplayMemory(obsInfo,actInfo,options.MaxBufferLen);
        end
        function resetImpl(obj)
            % (Optional) Define how the agent is reset before training/
            resetBuffer(obj);
            obj.counter = 0;
            obj.bufferLen=0;
            obj.bufferIdx = 0;%base 0
            obj.log_alpha = dlarray(log(obj.options.EntropyLossWeight));
        end
        function resetBuffer(obj)
            % Reinitialize observation buffer. Allocate as dlarray to
            % support automatic differentiation with dlfeval and
            % dlgradient.
            %format:CBT
            obj.obsBuffer = dlarray(...
                zeros(obj.numObs,obj.options.MaxBufferLen),'CB');
            % Reinitialize action buffer with valid actions.
            obj.actionBuffer = dlarray(...
                zeros(obj.numAct,obj.options.MaxBufferLen),'CB');
            % Reinitialize reward buffer.
            obj.rewardBuffer = dlarray(zeros(1,obj.options.MaxBufferLen),'CB');
            % Reinitialize nextState buffer.
            obj.nextObsBuffer = dlarray(...
                zeros(obj.numObs,obj.options.MaxBufferLen),'CB');
            % Reinitialize mask buffer.
            obj.isDoneBuffer = dlarray(zeros(1,obj.options.MaxBufferLen),'CB');
        end
        %Create networks
        %Actor
        function actor = CreateActor(obj,numObs,numAct,hid_dim,obsInfo,actInfo)
            % Create the actor network layers.
            commonPath = [
                featureInputLayer(numObs,Name="obsInLyr")
                fullyConnectedLayer(hid_dim)
                layerNormalizationLayer
                reluLayer
                fullyConnectedLayer(hid_dim)
                layerNormalizationLayer
                reluLayer(Name="comPathOutLyr")
                ];
            meanPath = [
                fullyConnectedLayer(numAct,Name="meanOutLyr")
                ];
            stdPath = [
                fullyConnectedLayer(numAct,Name="stdInLyr")
                softplusLayer(Name="stdOutLyr")
                ];
            % Connect the layers.
            actorNetwork = layerGraph(commonPath);
            actorNetwork = addLayers(actorNetwork,meanPath);
            actorNetwork = addLayers(actorNetwork,stdPath);
            actorNetwork = connectLayers(actorNetwork,"comPathOutLyr","meanOutLyr/in");
            actorNetwork = connectLayers(actorNetwork,"comPathOutLyr","stdInLyr/in");
            actordlnet = dlnetwork(actorNetwork);
            actor = initialize(actordlnet);
        end
        %Critic
        function [critic1,critic2,critic_target1,critic_target2] = CreateCritic(obj,numObs,numAct,hid_dim,obsInfo,actInfo)
            % Define the network layers.
            criticNet = [
                featureInputLayer(numObs+numAct,Name="obsInLyr")%input:[obs act]
                fullyConnectedLayer(hid_dim)
                layerNormalizationLayer
                reluLayer
                fullyConnectedLayer(hid_dim)
                layerNormalizationLayer
                reluLayer
                fullyConnectedLayer(1,Name="QValueOutLyr")
                ];
            % Connect the layers.
            criticNet = layerGraph(criticNet);
            criticDLnet = dlnetwork(criticNet,'Initialize',false);
            critic1 = initialize(criticDLnet);
            critic2 = initialize(criticDLnet);%c1 and c2 get different initializations
            critic_target1 = initialize(criticDLnet);
            critic_target1.Learnables = critic1.Learnables;
            critic_target1.State = critic1.State;
            critic_target2 = initialize(criticDLnet);
            critic_target2.Learnables = critic2.Learnables;
            critic_target2.State = critic2.State;
        end
        function logP = logProbBoundedAction(obj,boundedAction,mu,sigma)
            %used to calculate log probability for tanh(gaussian)
            %validated, nothing wrong with this function
            eps=1e-10;
            logP = sum(log(1/sqrt(2*pi)./sigma.*exp(-0.5*(0.5*...
                log((1+boundedAction+eps)./(1-boundedAction+eps))-mu).^2./sigma.^2).*1./(1-boundedAction.^2+eps)),1);
        end
        %loss functions
        function [vLoss_1, vLoss_2, criticGrad_1, criticGrad_2] = criticLoss(obj,batchExperiences,c1,c2)
            batchObs = batchExperiences{1};
            batchAction = batchExperiences{2};
            batchReward = batchExperiences{3};
            batchNextObs = batchExperiences{4};
            batchIsDone = batchExperiences{5};
            batchSize = size(batchObs,2);
            gamma = obj.options.DiscountFactor;
            y = dlarray(zeros(1,batchSize));%CB(C=1)
            y = y + batchReward;
            actionNext = getActionWithExploration_dlarray(obj,batchNextObs);%CB
            actionNext = actionNext{1};
            Qt1=predict(obj.critic_target1,cat(1,batchNextObs,actionNext));%CB(C=1)
            Qt2=predict(obj.critic_target2,cat(1,batchNextObs,actionNext));%CB(C=1)
            [mu,sigma] = predict(obj.actor,batchNextObs);%CB:numAct*batch
            next_action = tanh(mu + sigma.*randn(size(sigma)));
            logP = logProbBoundedAction(obj,next_action,mu,sigma);
            y = y + (1 - batchIsDone).*(gamma*(min(cat(1,Qt1,Qt2),[],1) - exp(obj.log_alpha)*logP));
            critic_input = cat(1,batchObs,batchAction);
            Q1 = forward(c1,critic_input);
            Q2 = forward(c2,critic_input);
            vLoss_1 = 1/2*mean((y - Q1).^2,'all');
            vLoss_2 = 1/2*mean((y - Q2).^2,'all');
            criticGrad_1 = dlgradient(vLoss_1,c1.Learnables);
            criticGrad_2 = dlgradient(vLoss_2,c2.Learnables);
        end
        function [aLoss,actorGrad] = actorLoss(obj,batchExperiences,actor)
            batchObs = batchExperiences{1};
            batchSize = size(batchObs,2);
            [mu,sigma] = forward(actor,batchObs);%CB:numAct*batch
            curr_action = tanh(mu + sigma.*randn(size(sigma)));%reparameterization
            critic_input = cat(1,batchObs,curr_action);
            Q1=forward(obj.critic1,critic_input);%CB(C=1)
            Q2=forward(obj.critic2,critic_input);%CB(C=1)
            logP = logProbBoundedAction(obj,curr_action,mu,sigma);
            aLoss = mean(-min(cat(1,Q1,Q2),[],1) + exp(obj.log_alpha) * logP,'all');
            actorGrad= dlgradient(aLoss,actor.Learnables);
        end
        function [eLoss,entGrad] = entropyLoss(obj,batchExperiences,logAlpha)
            batchObs = batchExperiences{1};
            [mu,sigma] = predict(obj.actor,batchObs);%CB:numAct*batch
            curr_action = tanh(mu + sigma.*randn(size(sigma)));
            ent = mean(-logProbBoundedAction(obj,curr_action,mu,sigma));
            eLoss = exp(logAlpha) * (ent - obj.options.TargetEntropy);
            entGrad = dlgradient(eLoss,logAlpha);
        end
    end
    methods(Access=protected)
        %return SampleTime
        function ts = getSampleTime_(obj)
            ts = obj.Ts;
        end
        %get action without exploration
        function action = getActionImpl(obj,obs)
            %obs:dlarray CB
            if ~isa(obs,'dlarray')
                if isa(obs,'cell')
                    obs = dlarray(obs{1},'CB');
                else
                    obs = dlarray(obs,'CB');
                end
            end
            [mu,~] = predict(obj.actor,obs);
            mu = extractdata(mu);
            action = {tanh(mu)};
        end
        %get action with exploration
        function action = getActionWithExplorationImpl(obj,obs)
            %obs:dlarray CT
            if ~isa(obs,'dlarray') || size(obs,1)~=obj.numObs
                obs = dlarray(randn(obj.numObs,1),'CB');
            end
            [mu,sigma] = predict(obj.actor,obs);
            mu = extractdata(mu);
            sigma = extractdata(sigma);
            action = {tanh(mu + sigma .* randn(size(sigma)))};
        end
        function action = getActionWithExploration_dlarray(obj,obs)
            [mu,sigma] = predict(obj.actor,obs);
            action = {tanh(mu + sigma .* randn(size(sigma)))};
        end
        %learning
        function action = learnImpl(obj,Experience)
            % Extract data from experience.
            obs = Experience{1};
            action = Experience{2};
            reward = Experience{3};
            nextObs = Experience{4};
            isDone = logical(Experience{5});
            obj.obsBuffer(:,obj.bufferIdx+1,:) = obs{1};
            obj.actionBuffer(:,obj.bufferIdx+1,:) = action{1};
            obj.rewardBuffer(:,obj.bufferIdx+1) = reward;
            obj.nextObsBuffer(:,obj.bufferIdx+1,:) = nextObs{1};
            obj.isDoneBuffer(:,obj.bufferIdx+1) = isDone;
            obj.bufferLen = max(obj.bufferLen,obj.bufferIdx+1);
            obj.bufferIdx = mod(obj.bufferIdx+1,obj.options.MaxBufferLen);
            if obj.bufferLen>=max(obj.options.WarmUpSteps,obj.options.MiniBatchSize)
                obj.counter = obj.counter + 1;
                if (obj.options.LearningFrequency==-1 && isDone) || ...
                        (obj.options.LearningFrequency>0 && mod(obj.counter,obj.options.LearningFrequency)==0)
                    for gstep = 1:obj.options.NumGradientStepsPerUpdate
                        %sample batch
                        batchSize = obj.options.MiniBatchSize;
                        batchInd = randperm(obj.bufferLen,batchSize);
                        batchExperience = {
                            obj.obsBuffer(:,batchInd,:),...
                            obj.actionBuffer(:,batchInd,:),...
                            obj.rewardBuffer(:,batchInd),...
                            obj.nextObsBuffer(:,batchInd,:),...
                            obj.isDoneBuffer(:,batchInd)
                            };
                        %update the parameters of each critic
                        [cLoss1,cLoss2,criticGrad_1,criticGrad_2] = dlfeval(@(x,c1,c2)obj.criticLoss(x,c1,c2),batchExperience,obj.critic1,obj.critic2);
                        obj.cLoss = min(extractdata(cLoss1),extractdata(cLoss2));
                        [obj.critic1.Learnables.Value,obj.criticOptimizer_1] = update(obj.criticOptimizer_1,obj.critic1.Learnables.Value,criticGrad_1.Value);
                        [obj.critic2.Learnables.Value,obj.criticOptimizer_2] = update(obj.criticOptimizer_2,obj.critic2.Learnables.Value,criticGrad_2.Value);
                        if (mod(obj.counter,obj.options.PolicyUpdateFrequency)==0 && obj.options.LearningFrequency==-1) ||...
                                (mod(obj.counter,obj.options.LearningFrequency * obj.options.PolicyUpdateFrequency)==0 ...
                                && obj.options.LearningFrequency>0)
                            %update the parameters of actor
                            [aloss,actorGrad] = dlfeval(...
                                @(x,actor)obj.actorLoss(x,actor),...
                                batchExperience,obj.actor);
                            obj.aLoss = extractdata(aloss);
                            [obj.actor.Learnables.Value,obj.actorOptimizer] = update(obj.actorOptimizer,obj.actor.Learnables.Value,actorGrad.Value);
                            %update the entropy weight
                            [eloss,entGrad] = dlfeval(@(x,alpha)obj.entropyLoss(x,alpha),batchExperience,obj.log_alpha);
                            obj.eLoss = extractdata(eloss);
                            % disp(obj.alpha)
                            [obj.log_alpha,obj.entWgtOptimizer] = update(obj.entWgtOptimizer,{obj.log_alpha},{entGrad});
                            obj.log_alpha = obj.log_alpha{1};
                        end
                        %update critic targets
                        %1
                        critic1_params = obj.critic1.Learnables.Value;%cell array network params
                        critic_target1_params = obj.critic_target1.Learnables.Value;
                        for i=1:size(critic1_params,1)
                            obj.critic_target1.Learnables.Value{i} = obj.options.TargetSmoothFactor * critic1_params{i}...
                                + (1 - obj.options.TargetSmoothFactor) * critic_target1_params{i};
                        end
                        %2
                        critic2_params = obj.critic2.Learnables.Value;%cell array network params
                        critic_target2_params = obj.critic_target2.Learnables.Value;
                        for i=1:size(critic2_params,1)
                            obj.critic_target2.Learnables.Value{i} = obj.options.TargetSmoothFactor * critic2_params{i}...
                                + (1 - obj.options.TargetSmoothFactor) * critic_target2_params{i};
                        end
                        % end
                    end
                end
            end
            action = getActionWithExplorationImpl(obj,nextObs{1});
        end
    end
end
2. Configuration of the 'options' property (same as those used for the built-in SAC agent)
    options.MaxBufferLen = 1e4;
    options.WarmUpSteps = 1000;
    options.MiniBatchSize = 256;
    options.LearningFrequency = -1;%when -1: train after each episode
    options.EntropyLossWeight = 1;
    options.DiscountFactor = 0.99;
    options.PolicyUpdateFrequency = 1;
    options.TargetEntropy = -2;
    options.TargetUpdateFrequency = 1;
    options.TargetSmoothFactor = 1e-3;
    options.NumGradientStepsPerUpdate = 10;
    %optimizerOptions: actor critic1 critic2 entWgt(alpha) 
    %encoder decoder
    options.OptimizerOptions = {
        rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
        rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
        rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
        rlOptimizerOptions("Algorithm","adam",'LearnRate',3e-4),...
        rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3),...
        rlOptimizerOptions("Algorithm","adam","GradientThreshold",1,'LearnRate',1e-3)};
    options.base_seed=940;
3. Training
clc;
clear;
close all;
run('init_car_params.m');
%create RL env
numObs = 4;  % vx vy r beta_user
numAct = 2;   % st_angle_ref rw_omega_ref
obsInfo = rlNumericSpec([numObs 1]);
actInfo = rlNumericSpec([numAct 1]);
actInfo.LowerLimit = -1;
actInfo.UpperLimit = 1;
mdl = "prius_sm_model";
blk = mdl + "/RL Agent";
env = rlSimulinkEnv(mdl,blk,obsInfo,actInfo);
params=struct('rw_radius',rw_radius,'a',a,'b',b,'init_vx',init_vx,'init_yaw_rate',init_yaw_rate);
env.ResetFcn = @(in) PriusResetFcn(in,params,mdl);
Ts = 1/10;
Tf = 5;
%create actor
rnd_seed=940;
algorithm = 'MySAC';
switch algorithm
    case 'SAC'
        agent = createNetworks(rnd_seed,numObs,numAct,obsInfo,actInfo,Ts);
    case 'MySAC'
        hid_dim = 256;
        options=getDWMLAgentOptions();
        agent = MySACAgent(numObs,numAct,obsInfo,actInfo,hid_dim,Ts,options);
end
%%
%train agent
close all
maxEpisodes = 6000;
maxSteps = floor(Tf/Ts);
useParallel = false;
run_idx=9;
saveAgentDir = ['savedAgents/',algorithm,'/',num2str(run_idx)];
switch algorithm
    case 'SAC'
        trainOpts = rlTrainingOptions(...
            MaxEpisodes=maxEpisodes, ...
            MaxStepsPerEpisode=maxSteps, ...
            ScoreAveragingWindowLength=100, ...  
            Plots="training-progress", ...
            StopTrainingCriteria="AverageReward", ...
            UseParallel=useParallel,...
            SaveAgentCriteria='EpisodeReward',...
            SaveAgentValue=35,...
            SaveAgentDirectory=saveAgentDir);
            % SaveAgentCriteria='EpisodeFrequency',...
            % SaveAgentValue=1,...
    case 'MySAC'
        trainOpts = rlTrainingOptions(...
            MaxEpisodes=maxEpisodes, ...
            MaxStepsPerEpisode=maxSteps, ...
            ScoreAveragingWindowLength=100, ...  
            Plots="training-progress", ...
            StopTrainingCriteria="AverageReward", ...
            UseParallel=useParallel,...
            SaveAgentCriteria='EpisodeReward',...
            SaveAgentValue=35,...
            SaveAgentDirectory=saveAgentDir);
end
set_param(mdl,"FastRestart","off");%for random initialization
if trainOpts.UseParallel
    % Disable visualization in Simscape Mechanics Explorer
    set_param(mdl, SimMechanicsOpenEditorOnUpdate="off");
    save_system(mdl);
else
    % Enable visualization in Simscape Mechanics Explorer
    set_param(mdl, SimMechanicsOpenEditorOnUpdate="on");
end
%load training data
monitor = trainingProgressMonitor();
logger = rlDataLogger(monitor);
logger.EpisodeFinishedFcn = @myEpisodeLoggingFcn;
doTraining = true;
if doTraining
    trainResult = train(agent,env,trainOpts,Logger=logger); 
end
% %logger callback used for MySACAgent
function dataToLog = myEpisodeLoggingFcn(data)
    dataToLog.criticLoss = data.Agent.cLoss;
    dataToLog.actorLoss = data.Agent.aLoss;
    dataToLog.entLoss = data.Agent.eLoss;
    % dataToLog.denoiseLoss = data.Agent.dnLoss;
end
In the Simulink environment used here, the action output by the Agent block (in [-1, 1]) is denormalized before being fed into the environment.
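For reference, a linear map is the usual way to denormalize a tanh-bounded action into physical units. The limits below are hypothetical placeholders, not values taken from prius_sm_model; the example only shows the mapping and its inverse:

```matlab
% Hypothetical physical limits for the two actions (NOT from prius_sm_model):
% row 1 = steering angle reference, row 2 = wheel speed reference.
lo = [-0.5; -50];
hi = [ 0.5;  50];

% Linear map from the agent's [-1, 1] range to [lo, hi], and its inverse.
denormalize = @(u) lo + (u + 1) ./ 2 .* (hi - lo);
normalize   = @(x) 2 .* (x - lo) ./ (hi - lo) - 1;

u = [0.2; -1];          % example agent output (after tanh)
x = denormalize(u);     % physical command that would be sent to the plant
disp(x.');
```

Keeping the mapping and its inverse in one place makes it easy to confirm that both agents see exactly the same action scaling.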
I think possible causes of the problem include:
1. Wrong implementation of the critic loss. As the training progress shows, the critic loss seemed to diverge. This is unlikely to be caused by the hyperparameters (batch size, learning rate, or target update frequency), because the same values worked well for the built-in agent, so the critic loss implementation is the more likely culprit.
2. Wrong implementation of the replay buffer. I implemented the replay buffer as a circular queue and sample it uniformly to get batches of training data. From the comparison of the training progress above, the custom SAC agent did explore high-reward states (around 30) but failed to exploit them, so I suspect there is still a problem with my replay buffer.
3. Broken gradient flow. Learning relies on MATLAB deep learning automatic differentiation. Perhaps part of my implementation violates the rules of automatic differentiation, which broke the gradient flow during the forward computation or backpropagation and led to wrong results.
4. Gradient steps (update frequency). In the current implementation, NumGradientStepsPerUpdate gradient steps are executed after each episode. In each gradient step, the critics, the actor, and the entropy weight are each updated once. I am not sure whether this implementation gets the update frequency right.
5. It could also be a normalization problem, but I am not sure.
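Related to point 1: logProbBoundedAction multiplies the Gaussian density by 1/(1-a^2) before taking the log, and recovers the pre-tanh value from the bounded action via 0.5*log((1+a)/(1-a)). Once |u| is large enough that tanh(u) rounds to ±1 in double precision, the recovered value no longer matches u, and the exp(...) term can then underflow so that the log returns -Inf. A common alternative is to keep the pre-tanh sample u and work in log space throughout, using the identity log(1 - tanh(u)^2) = 2(log 2 - u - softplus(-2u)). This is a sketch with made-up actor outputs and plain numeric arrays; it only uses elementwise operations and a sum over the action dimension, so it should carry over to dlarray inputs, but that has not been checked here:

```matlab
% Hypothetical actor outputs: numAct = 2 actions (rows), batch of B = 3 (columns).
mu    = [0.1, -0.3, 2.0; 0.0, 0.5, -1.0];
sigma = [0.5,  0.2, 0.1; 1.0, 0.3,  0.05];

% Reparameterized sample: keep the pre-tanh value u instead of recovering it via atanh.
u = mu + sigma .* randn(size(mu));
a = tanh(u);                                   % bounded action in (-1, 1)

% Numerically stable softplus(x) = log(1 + exp(x)).
softplus = @(x) max(x, 0) + log(1 + exp(-abs(x)));

% log N(u; mu, sigma), computed directly in log space (no exp followed by log).
logGauss = -0.5 .* ((u - mu) ./ sigma).^2 - log(sigma) - 0.5 .* log(2*pi);

% Change-of-variables term log(1 - tanh(u)^2), without forming 1 - a.^2.
logJac = 2 .* (log(2) - u - softplus(-2 .* u));

% log pi(a|s): sum over the action dimension -> 1-by-B row.
logP = sum(logGauss - logJac, 1);
disp(logP);
```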
I plan to debug 3 first.
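Before looking at gradient flow, one detail in the listing may be worth checking: getActionWithExplorationImpl replaces any input that is not already a dlarray (or that has the wrong number of rows) with dlarray(randn(obj.numObs,1),'CB'). learnImpl requests the next action with nextObs{1}, which the toolbox normally delivers as a plain numeric array, so if that holds, the exploratory actions during training are computed from random noise rather than from the actual state, while getActionImpl (used without exploration) converts its input properly. A minimal sketch of converting the observation instead, assuming a single observation channel; the variable names and values are illustrative, and the dlarray/predict step is only shown as a comment because it needs Deep Learning Toolbox:

```matlab
% Example input in the form the agent may receive it: a cell array holding
% one numeric observation (single observation channel assumed).
numObs = 4;
obsIn  = {[12.0; 0.1; 0.05; 0.02]};   % hypothetical vx, vy, r, beta_user

% Unwrap the cell, if any, instead of discarding the input.
if iscell(obsIn)
    obsNum = obsIn{1};
else
    obsNum = obsIn;
end

% Force a numObs-by-1 column; reshape throws an error if the element count
% is wrong, so a size mismatch is reported rather than silently replaced.
obsCol = reshape(double(obsNum), numObs, 1);

% Inside the agent, the converted observation would then feed the actor:
%   [mu, sigma] = predict(obj.actor, dlarray(obsCol, 'CB'));
%   action = {tanh(extractdata(mu) + extractdata(sigma) .* randn(size(mu)))};
disp(obsCol.');
```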
Please read the code and help find potential causes of the gap between the custom SAC agent and the built-in one.
Finally, I am actually trying to extend the SAC algorithm into a more complex framework. I chose not to inherit from the built-in SAC agent (rlSACAgent). Would you recommend doing my development by inheriting from it instead?
  Vincent
 on 15 Aug 2025
For the Simulink model, did you use a Rate Transition block between your custom agent and the state, reward, and done signals? I wrote a custom TD3 agent and had to use a Rate Transition block in the Simulink model. My TD3 agent performs poorly, so I don't know whether the problem is my code or the Rate Transition block.
Thanks
Accepted Answer
  Kaustab Pal
 on 29 Aug 2024
        Hi @一凡
Upon reviewing your critic loss implementation, I'd like to offer some insights. 
1. While the overall structure appears sound, there might be subtle dimension mismatches that could affect performance. Make sure all operations are elementwise and that the dimensions of your tensors align across the whole computation.
- Pay particular attention to the “batchIsDone” variable and verify that it broadcasts correctly against the other tensors.
- Also consider using MATLAB's “bsxfun” or implicit broadcasting syntax deliberately to guarantee proper dimension handling. Such small details can significantly affect the stability and effectiveness of learning.
- To diagnose further, use MATLAB's debugging tools to inspect tensor shapes at each step of the loss calculation. This can help pinpoint unexpected dimension conflicts.
- Additionally, consider adding “assert” statements that check tensor dimensions explicitly, which can catch issues early in training.
2. Your circular queue implementation seems reasonable. However, make sure you are not overwriting experiences too quickly.
- Consider increasing the buffer size or adjusting the sampling strategy. You might also implement prioritized experience replay to focus on more important transitions.
3. Regarding gradient flow, make sure that all operations in the forward pass are differentiable functions.
4. In your current implementation, you update the networks after each episode, which might lead to instability. Consider updating after every N steps instead. Also verify that the target networks are updated correctly and at the right frequency.
5. Lastly, normalizing the rewards often leads to more stable training, so that is worth trying as well.
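Regarding point 1 and the batchIsDone broadcasting, it may help to write the soft Bellman target out with explicit 1-by-B row vectors. One further detail in the question's criticLoss: Qt1 and Qt2 are evaluated at actionNext, but logP is evaluated at next_action, which comes from a second, independent randn draw. In the usual SAC target, the same sampled next action a' ~ π(·|s') is used in both the min-Q term and the log π(a'|s') term. The numbers below are made up; in the agent they would come from predict on the target critics and the actor for one shared sample:

```matlab
% Hypothetical minibatch of B = 3 transitions stored as 1-by-B rows, the same
% layout as rewardBuffer/isDoneBuffer in the question. All values are made up.
gamma    = 0.99;                % DiscountFactor
alpha    = 0.2;                 % entropy weight exp(log_alpha), hypothetical
r        = [1.0, 0.5, -0.2];    % rewards
isDone   = [0,   0,    1  ];    % 1 marks a terminal transition
% One shared next-action sample a' per transition is assumed to have produced
% BOTH the target-critic values and the log-probability below.
q1Next   = [10,  4,    7  ];    % critic_target1 at (s', a')
q2Next   = [9,   5,    6  ];    % critic_target2 at the same (s', a')
logPNext = [-1.5, -0.5, -2.0];  % log pi(a' | s') for that same a'

% Soft value of the next state: clipped double-Q minus the entropy term.
softV = min(q1Next, q2Next) - alpha .* logPNext;

% Soft Bellman target; (1 - isDone) stops bootstrapping at terminal states.
y = r + gamma .* (1 - isDone) .* softV;

% Guard against implicit expansion silently producing a B-by-B matrix
% (which happens if one operand is a column and another a row).
assert(isequal(size(y), [1, numel(r)]));
disp(y);
```

Computing a' once and passing both it and its log-probability into the target removes the sample mismatch, and the size check catches the row/column mix-up that implicit expansion would otherwise hide.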
Regarding your final question about inheriting from “rlSACAgent”: If you're planning to extend the SAC algorithm significantly, creating a custom implementation as you've done gives you more flexibility. However, if your extensions are minor, inheriting from “rlSACAgent” could save you time and reduce the chances of implementation errors.
Please refer to the following official documentation to learn more about these functions:
- rlSACAgent: https://www.mathworks.com/help/reinforcement-learning/ref/rl.agent.rlsacagent.html
- bsxfun: https://www.mathworks.com/help/matlab/ref/bsxfun.html
- assert: https://www.mathworks.com/help/sltest/ref/assert.html
I hope these suggestions resolve your query.
More Answers (1)
  一凡
 on 2 Sep 2024
  Wenxuan
 on 21 Aug 2025
Hi @Vincent
I am not certain about the specific issue you encountered with the TD3 algorithm, but I can share some experience in tuning custom agents. First, it is essential to ensure correctness in the implementation, including both the reinforcement learning principles and the accuracy of information transfer during execution. In my experience, when custom agents and MATLAB’s built-in agents use identical parameter settings, the performance of the custom implementation is often weaker. However, with appropriate parameter adjustments, the performance of a custom agent can gradually approach that of the built-in agent.
It should be noted that such parameter tuning typically increases the complexity of the custom agent. For example, improving performance may require adding more neurons or deeper layers in the neural network, increasing the learning frequency, or refining the reward function design—ideally in an adaptive manner that evolves with different training stages. In simple scenarios, the performance gap between custom agents and MATLAB’s built-in agents is not significant, but in complex problem settings, this gap becomes much more pronounced.
Furthermore, I have found that the learning frequency has a particularly strong impact on the performance of custom agents. Therefore, I recommend prioritizing adjustments to the learning frequency, followed by refining the reward function design, and finally optimizing the neural network architecture.
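Since "learning frequency" can be read in two directions, it may help to count updates under the schedule used in the question's learnImpl: there, LearningFrequency is the number of environment steps between learning calls (so a larger value means fewer updates), and PolicyUpdateFrequency further thins out the actor and entropy-weight updates. The sketch below replays that schedule for hypothetical values:

```matlab
% Replays the LearningFrequency > 0 branch of the question's learnImpl schedule.
% All settings are hypothetical example values.
learningFrequency     = 2;    % learn every 2 environment steps
policyUpdateFrequency = 2;    % actor/alpha update on every 2nd learning call
numGradientSteps      = 1;    % NumGradientStepsPerUpdate
numSteps              = 100;  % environment steps after warm-up

criticUpdates = 0;
actorUpdates  = 0;
for counter = 1:numSteps
    if mod(counter, learningFrequency) == 0
        for g = 1:numGradientSteps
            criticUpdates = criticUpdates + 1;   % both critics updated once
            if mod(counter, learningFrequency * policyUpdateFrequency) == 0
                actorUpdates = actorUpdates + 1; % actor and entropy weight
            end
        end
    end
end
fprintf('critic updates: %d, actor updates: %d\n', criticUpdates, actorUpdates);
```

With the question's own settings (LearningFrequency = -1, so learning runs only when isDone is true; NumGradientStepsPerUpdate = 10; PolicyUpdateFrequency = 1; 50 steps per episode), the same bookkeeping gives 10 critic and 10 actor updates per episode, i.e. 0.2 gradient steps per environment step, assuming isDone is true on the last step of each episode.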
I hope this can be helpful to you.
  Vincent
 on 21 Aug 2025
Thanks for the response. When you say increasing the learning frequency, do you mean performing actor updates less often? For example, instead of updating every 2 time steps, updating every 10-20 time steps? I also wanted to ask: did you use a Rate Transition block performing a zero-order hold between your custom agent and the state, reward, and done signals? To get my custom agent to compile, I had to add a Rate Transition block, which performs a zero-order hold, and I suspect it may also be a reason for the poor learning of my custom agent.
