離散cartpole環境が正常に学習しない
    2 views (last 30 days)
  
       Show older comments
    
「create custom environment from class template」を参考に離散cartpole環境を作成して、強化学習デザイナーにインポートさせてみました。
しかし、学習が安定して収束しませんでした。試行錯誤しましたが、対処法が分かりませんでした。
原因と対処法を教えてください。
classdef matlab < rl.env.MATLABEnvironment
    % MATLAB  Discrete-action cart-pole environment for reinforcement learning.
    %
    %   The environment state is [x; dx; theta; dtheta]. At each step the
    %   agent chooses one of two forces, -MaxForce or +MaxForce, which is
    %   applied to the cart. The dynamics are advanced by one forward-Euler
    %   step of length Ts. An episode ends when |x| exceeds XThreshold or
    %   |theta| exceeds ThetaThresholdRadians.
    %
    %   Reward is RewardForNotFalling on every non-terminal step and
    %   PenaltyForFalling on the terminal step.
    properties
        % Acceleration due to gravity in m/s^2
        Gravity = 9.8
        % Mass of the cart
        MassCart = 1.0
        % Mass of the pole
        MassPole = 0.1
        % Half the length of the pole
        Length = 0.5
        % Max force the input can apply
        MaxForce = 10
        % Sample time
        Ts = 0.02
        % Angle at which to fail the episode
        ThetaThresholdRadians = 12 * pi/180
        % Distance at which to fail the episode
        XThreshold = 2.4
        % Reward each time step the cart-pole is balanced
        RewardForNotFalling = 1
        % Penalty when the cart-pole fails to balance
        PenaltyForFalling = -5
    end
    properties
        % System state [x,dx,theta,dtheta]'
        State = zeros(4,1)
    end
    properties(Access = protected)
        % Internal flag recording whether the last step ended the episode
        IsDone = false
    end
    properties (Transient,Access = private)
        % Handle to the visualizer created by plot (empty until first use)
        Visualizer = []
    end
    methods
        function this = matlab()
            % Construct the environment with a 4x1 numeric observation and
            % a two-element finite action set.
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';
            % Placeholder elements; rescaled by MaxForce just below.
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';
            this = this@rl.env.MATLABEnvironment(ObservationInfo, ActionInfo);
            % Scale the action elements to [-MaxForce, MaxForce].
            updateActionInfo(this);
        end
        function set.State(this,state)
            % Accept any finite real 4-element vector; store it as a double column.
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end
        function set.Length(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Length');
            this.Length = val;
            notifyEnvUpdated(this);
        end
        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end
        function set.MassCart(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassCart');
            this.MassCart = val;
        end
        function set.MassPole(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassPole');
            this.MassPole = val;
        end
        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            % Keep the action set in sync with the new force magnitude.
            updateActionInfo(this);
        end
        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end
        function set.ThetaThresholdRadians(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','ThetaThresholdRadians');
            this.ThetaThresholdRadians = val;
            notifyEnvUpdated(this);
        end
        function set.XThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','XThreshold');
            this.XThreshold = val;
            notifyEnvUpdated(this);
        end
        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end
        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
        function [observation,reward,isdone,loggedSignals] = step(this,action)
            % STEP  Advance the simulation by one sample time.
            %   action        - one of the elements of ActionInfo.Elements
            %   observation   - next state [x; dx; theta; dtheta]
            %   reward        - scalar reward for this transition
            %   isdone        - true when a termination threshold is exceeded
            %   loggedSignals - always empty
            loggedSignals = [];
            % Validate the action and convert it to a force
            force = getForce(this,action);
            % Unpack state vector (x itself does not enter the dynamics)
            state = this.State;
            x_dot = state(2);
            theta = state(3);
            theta_dot = state(4);
            % Apply motion equations
            costheta = cos(theta);
            sintheta = sin(theta);
            totalmass = this.MassCart + this.MassPole;
            polemasslength = this.MassPole*this.Length;
            temp = (force + polemasslength * theta_dot * theta_dot * sintheta) / totalmass;
            thetaacc = (this.Gravity * sintheta - costheta* temp) / (this.Length * (4.0/3.0 - this.MassPole * costheta * costheta / totalmass));
            xacc  = temp - polemasslength * thetaacc * costheta / totalmass;
            % Euler integration
            observation = state + this.Ts.*[x_dot;xacc;theta_dot;thetaacc];
            this.State = observation;
            % Termination is evaluated on the updated state
            x = observation(1);
            theta = observation(3);
            isdone = abs(x) > this.XThreshold || abs(theta) > this.ThetaThresholdRadians;
            % getReward reads IsDone, so it must be stored before the call
            this.IsDone = isdone;
            % Get reward
            reward = getReward(this,x,force);
        end
        function initialState = reset(this)
            % RESET  Start a new episode with the cart at rest at x = 0 and
            % the pole angle drawn uniformly from [-0.05, 0.05] rad.
            % Theta (+- .05 rad)
            T0 = 2*0.05*rand - 0.05;
            % Thetadot
            Td0 = 0;
            % X
            X0 = 0;
            % Xdot
            Xd0 = 0;
            initialState= [X0;Xd0;T0;Td0];
            this.State = initialState;
        end
        function varargout = plot(this)
            % PLOT  Visualize the environment.
            %   Creates the visualizer on first use (or if it was closed),
            %   otherwise brings the existing one to the front. Returns the
            %   visualizer handle when an output is requested.
            if isempty(this.Visualizer) || ~isvalid(this.Visualizer)
                this.Visualizer = rl.env.viz.CartPoleVisualizer(this);
            else
                bringToFront(this.Visualizer);
            end
            if nargout
                varargout{1} = this.Visualizer;
            end
        end
    end
    methods (Access = protected)
        function force = getForce(this,action)
            % Return the force for an action; error if the action is not
            % one of the elements of the action specification.
            if ~ismember(action,this.ActionInfo.Elements)
                error(message('rl:env:CartPoleDiscreteInvalidAction',sprintf('%g',-this.MaxForce),sprintf('%g',this.MaxForce)));
            end
            force = action;
        end
        % Update the action info based on max force
        function updateActionInfo(this)
            % The two actions must be symmetric: -MaxForce and +MaxForce.
            % (A previous [-1 10] multiplier produced {-MaxForce, 10*MaxForce},
            % making one action ten times stronger than the other and
            % contradicting the valid range reported by getForce.)
            this.ActionInfo.Elements = this.MaxForce*[-1 1];
        end
        function Reward = getReward(this,~,~)
            % Reward for the most recent step, selected by the IsDone flag.
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end
    end
end
0 Comments
Accepted Answer
More Answers (0)
See Also
Categories
				Find more on 方程式の解法 in Help Center and File Exchange
			
	Products
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!