Trying to perform image classification using RL but facing an error. Any suggestions?

clc;
%% Load Dataset
imds = imageDatastore('D:\2023\thesis\waveletImages\SubsetImages', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
% Ensure the dataset is balanced
DataSetInfo = countEachLabel(imds);
% Split Data into Training and Validation sets
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.80, 'randomized');
%% Load Pretrained Network
net = vgg16;
inputSize = net.Layers(1).InputSize;
% Feature extraction layer
featureLayer = 'fc7';
% Augment data to match input size of the network
augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize, imdsValidation);
% Extract Features using Pretrained CNN
trainFeatures = activations(net, augimdsTrain, featureLayer, 'OutputAs', 'rows');
validationFeatures = activations(net, augimdsValidation, featureLayer, 'OutputAs', 'rows');
% Extract labels
trainLabels = grp2idx(imdsTrain.Labels);
validationLabels = grp2idx(imdsValidation.Labels);
%% RL Environment Setup
% Define Observation and Action spaces
numFeatures = size(trainFeatures, 2); % Number of features (e.g., 4096 for 'fc7' layer)
ObservationInfo = rlNumericSpec([numFeatures, 1], ...
    'LowerLimit', -inf, 'UpperLimit', inf, ...
    'Name', 'Features');
ObservationInfo.Description = 'Feature vector extracted from CNN';
% Define Action space (each action corresponds to a class label)
ActionInfo = rlFiniteSetSpec(1:max(trainLabels)); % Class labels as actions
ActionInfo.Name = 'Class Labels';
% Define Step and Reset functions
stepFunction = @(Action, LoggedSignals) stepFunctionRL(Action, trainFeatures, trainLabels, LoggedSignals);
resetFunction = @() resetFunctionRL(trainFeatures);
% Create RL Environment
env = rlFunctionEnv(ObservationInfo, ActionInfo, stepFunction, resetFunction);
%% Define RL Agent
% Define critic network
statePath = [
    featureInputLayer(numFeatures, 'Normalization', 'none', 'Name', 'state')
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(64, 'Name', 'fc2')
    reluLayer('Name', 'relu2')
    fullyConnectedLayer(numel(ActionInfo.Elements), 'Name', 'fcOutput')]; % One Q-value per action
criticNet = dlnetwork(layerGraph(statePath));
% The network takes only the observation and outputs one Q-value per action,
% so it is wrapped as a vector Q-value critic rather than rlQValueFunction
critic = rlVectorQValueFunction(criticNet, ObservationInfo, ActionInfo, 'ObservationInputNames', 'state');
% Agent options
agentOpts = rlDQNAgentOptions( ...
    'ExperienceBufferLength', 1e5, ...
    'DiscountFactor', 0.99, ...
    'MiniBatchSize', 64, ...
    'TargetSmoothFactor', 1e-3, ...
    'TargetUpdateFrequency', 4);
agent = rlDQNAgent(critic, agentOpts);
%% Train RL Agent
trainOpts = rlTrainingOptions( ...
    'MaxEpisodes', 500, ...
    'MaxStepsPerEpisode', size(trainFeatures, 1), ...
    'Verbose', true, ...
    'Plots', 'training-progress', ...
    'StopTrainingCriteria', 'EpisodeCount', ...
    'StopTrainingValue', 500);
% Pass trainOpts so the options defined above are actually used
trainingStats = train(agent, env, trainOpts);
%% Reset Function
function [InitialObservation, LoggedSignals] = resetFunctionRL(Features)
    % Reset function for the RL environment
    % This function resets the environment to its initial state at the start
    % of each new episode.
    % Initialize the LoggedSignals structure (holds environment state)
    LoggedSignals = struct();
    % Set the initial observation index to the first data point
    LoggedSignals.CurrentIndex = 1;
    % Fetch the initial observation: the first feature vector
    % Ensure that the observation is formatted as a column vector [numFeatures, 1]
    InitialObservation = Features(LoggedSignals.CurrentIndex, :)'; % Transpose to column vector
    % Check if the initial observation matches the expected size
    if size(InitialObservation, 2) ~= 1
        error('InitialObservation must be a column vector of size [numFeatures, 1]');
    end
    % Optionally, initialize any other properties you want in LoggedSignals
    % (e.g., additional state variables, environment-specific flags)
    LoggedSignals.EpisodeStartTime = datetime('now'); % Track the start time for each episode
end
%% Step Function
function [NextObservation, Reward, IsDone, LoggedSignals] = stepFunctionRL(Action, Features, Labels, LoggedSignals)
    % Step function for the RL environment
    idx = LoggedSignals.CurrentIndex;
    % Assign reward based on action correctness
    correctLabel = Labels(idx); % Correct label for the current feature
    if Action == correctLabel
        Reward = 1; % Positive reward for correct classification
    else
        Reward = -1; % Negative reward for incorrect classification
    end
    % Update index and check if episode is done
    LoggedSignals.CurrentIndex = idx + 1;
    IsDone = LoggedSignals.CurrentIndex > size(Features, 1);
    % Ensure the next observation is a column vector of size [numFeatures, 1]
    if ~IsDone
        NextObservation = Features(LoggedSignals.CurrentIndex, :)'; % Transpose to ensure column vector [numFeatures x 1]
    else
        NextObservation = Features(idx, :)'; % Dummy to avoid dimension mismatch
    end
end
And this is the error:
Error using rl.env.MATLABEnvironment/validateEnvironment (line 58)
Environment 'ObservationInfo' does not match observation output from reset function. Check the data type, dimensions, and range.
Error in rl.env.rlFunctionEnv (line 82)
validateEnvironment(this);
Error in rlFunctionEnv (line 45)
env = rl.env.rlFunctionEnv(varargin{:});

Answers (1)

Drew on 8 Jan 2025 at 16:28
Edited: Drew on 8 Jan 2025 at 16:28
Sending your question to a Generative AI system resulted in the following output. The GenAI suggested some "Key Checks and Fixes" to look at in an attempt to debug this mismatch issue. The GenAI suggestions may or may not be correct. You could also try varying the questions that you pose to Generative AI, to iterate on the solution. If this answer helps you, please remember to accept the answer.
From generative AI:
The error message you're encountering indicates a mismatch between the dimensions or data types of the observation output from your reset function and what is expected by the ObservationInfo specification in your reinforcement learning environment. Let's go through the code to ensure that the observation from the resetFunctionRL matches the ObservationInfo.
Key Checks and Fixes:
1. ObservationInfo Specification:
numFeatures = size(trainFeatures, 2); % Number of features (e.g., 4096 for 'fc7' layer)
ObservationInfo = rlNumericSpec([numFeatures, 1], ...
'LowerLimit', -inf, 'UpperLimit', inf, ...
'Name', 'Features');
This specifies that the observation should be a column vector with numFeatures rows.
2. Reset Function:
function [InitialObservation, LoggedSignals] = resetFunctionRL(Features)
    LoggedSignals = struct();
    LoggedSignals.CurrentIndex = 1;
    InitialObservation = Features(LoggedSignals.CurrentIndex, :)'; % Transpose to column vector
end
Make sure that InitialObservation is indeed a column vector. You've already transposed it, but let's ensure that Features is correctly shaped before the transpose.
3. Potential Issues and Fixes:
  • Check the Shape of trainFeatures: Ensure that trainFeatures is a 2D matrix where each row corresponds to a feature vector. The transpose operation should result in a column vector of size [numFeatures, 1].
  • Debugging Step: Add some debugging statements to verify the dimensions:
disp(size(InitialObservation)); % Should output [numFeatures, 1]
4. Validation with ObservationInfo:
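  • Ensure that ObservationInfo and InitialObservation agree in both dimensions and data type. ObservationInfo expects [numFeatures, 1], so InitialObservation must have that size. The data type deserves the same attention: rlNumericSpec defaults to double, whereas activations on a pretrained network typically returns single, so a correctly shaped observation can still fail validation.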
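The comparison in item 4 can be sketched as follows. This is a minimal sketch under stated assumptions, not the asker's actual data: trainFeatures is a random stand-in for the output of activations, assumed here to be single precision, and obsInfo is a hypothetical name for a specification built like the one in the question with the data type left at its default. It checks the size and the class of the observation separately.

```matlab
% Stand-in for the features returned by activations(...).
% Assumption: the real features are single precision.
trainFeatures = rand(5, 4096, 'single');
numFeatures = size(trainFeatures, 2);

% Same specification as in the question; DataType is left at its default.
obsInfo = rlNumericSpec([numFeatures, 1], ...
    'LowerLimit', -inf, 'UpperLimit', inf, ...
    'Name', 'Features');

% Observation built the same way the reset function builds it.
obs = trainFeatures(1, :)';

% Compare size and class against the specification separately.
sizeOk = isequal(size(obs), obsInfo.Dimension);
typeOk = isa(obs, char(obsInfo.DataType));

fprintf('size  : got [%s], expected [%s]\n', ...
    num2str(size(obs)), num2str(obsInfo.Dimension));
fprintf('class : got %s, expected %s\n', class(obs), char(obsInfo.DataType));
```

If the size line agrees but the class line does not, the transpose is not the problem and the data type is the more likely cause.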
5. Check for Consistency Across Functions:
  • Ensure that the NextObservation in the stepFunctionRL also matches the expected dimensions:
if ~IsDone
    NextObservation = Features(LoggedSignals.CurrentIndex, :)';
else
    NextObservation = Features(idx, :)'; % Dummy to avoid dimension mismatch
end
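One way to confirm that both branches of the step function return a valid observation is to run the same indexing logic over every sample and test each result against the specification. The sketch below is illustrative only: features is a small random stand-in in double precision, and matchesSpec is a hypothetical helper, not a toolbox function.

```matlab
% Small double-precision stand-in: 3 samples with 8 features each.
features = rand(3, 8);
obsInfo = rlNumericSpec([size(features, 2), 1]);

% Hypothetical helper: true when x has the size and class the spec expects.
matchesSpec = @(x, spec) isequal(size(x), spec.Dimension) && ...
    isa(x, char(spec.DataType));

allOk = true;
for idx = 1:size(features, 1)
    % Same logic as the step function in the question.
    isDone = (idx + 1) > size(features, 1);
    if ~isDone
        nextObs = features(idx + 1, :)';
    else
        nextObs = features(idx, :)'; % Terminal step reuses the last sample
    end
    allOk = allOk && matchesSpec(nextObs, obsInfo);
end
disp(allOk);
```

The terminal branch matters because the environment still receives an observation on the final step of an episode.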
Final Verification:
  • After making sure the dimensions match, re-run your script. If the error persists, add additional debug statements to print out the dimensions of InitialObservation and NextObservation right before they are returned to verify that they match ObservationInfo.
By ensuring these checks, you should be able to resolve the mismatch error. If you continue to experience issues, consider simplifying the environment to isolate the problem and ensure dimensions are being handled consistently throughout your code.
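If the checks point to a data type mismatch rather than a shape mismatch, there are two straightforward ways to make the observation and the specification agree. The sketch below shows both, assuming the features are single precision; trainFeatures is again a random stand-in, and the variable names ending in Double and Single are hypothetical. It relies on DataType being settable as a name-value argument of rlNumericSpec, which is my understanding of the documented interface.

```matlab
% Stand-in for single-precision features from activations(...).
trainFeatures = rand(5, 4096, 'single');
numFeatures = size(trainFeatures, 2);

% Option 1: convert the features to double and keep the default specification.
featuresDouble = double(trainFeatures);
obsInfoDouble = rlNumericSpec([numFeatures, 1]);
obs1 = featuresDouble(1, :)';

% Option 2: keep the features as single and declare that in the specification.
obsInfoSingle = rlNumericSpec([numFeatures, 1], 'DataType', 'single');
obs2 = trainFeatures(1, :)';

% In both cases the observation now matches its specification.
ok1 = isequal(size(obs1), obsInfoDouble.Dimension) && ...
    isa(obs1, char(obsInfoDouble.DataType));
ok2 = isequal(size(obs2), obsInfoSingle.Dimension) && ...
    isa(obs2, char(obsInfoSingle.DataType));
disp([ok1, ok2]);
```

Option 1 changes only the data passed to the step and reset functions; option 2 changes only the specification. Either one should be applied consistently to the reset and step functions.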


Release

R2024b
