% Build a datastore over every image below the dataset root, labelling each
% image with the name of the folder that contains it.
imds = imageDatastore('D:\2023\thesis\waveletImages\SubsetImages', ...
    'LabelSource', 'foldernames', ...
    'IncludeSubfolders', true);

% Per-class image counts for the whole dataset.
DataSetInfo = countEachLabel(imds);

% Keep 80% of each class for training and 20% for validation. The split is
% randomized, so it differs between runs unless the RNG seed is fixed first.
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.80, 'randomized');
% NOTE(review): `net` (presumably a pretrained CNN) and `featureLayer` (name of
% the layer whose activations are used as features) are not defined in this
% script; they must already exist in the workspace before it runs.
inputSize = net.Layers(1).InputSize;

% Numeric group indices derived from the categorical folder labels; these
% are the values the RL agent is asked to predict.
trainLabels = grp2idx(imdsTrain.Labels);
validationLabels = grp2idx(imdsValidation.Labels);

% Resize images on the fly to the network input size, then read out the
% chosen layer's activations with one row per image.
augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain);
trainFeatures = activations(net, augimdsTrain, featureLayer, 'OutputAs', 'rows');

augimdsValidation = augmentedImageDatastore(inputSize, imdsValidation);
validationFeatures = activations(net, augimdsValidation, featureLayer, 'OutputAs', 'rows');
% Observation: one CNN feature vector (column) per environment step.
numFeatures = size(trainFeatures, 2);
% The original call ended in a `...` continuation that ran into the next
% assignment; the argument list is now closed properly.
ObservationInfo = rlNumericSpec([numFeatures, 1], ...
    'LowerLimit', -inf, 'UpperLimit', inf);
ObservationInfo.Description = 'Feature vector extracted from CNN';

% Action: choose one of the integer class indices 1..max(trainLabels).
ActionInfo = rlFiniteSetSpec(1:max(trainLabels));
ActionInfo.Name = 'Class Labels';

% The closures capture the training features/labels so the environment walks
% through the training set one sample per step.
stepFunction = @(Action, LoggedSignals) stepFunctionRL(Action, trainFeatures, trainLabels, LoggedSignals);
resetFunction = @() resetFunctionRL(trainFeatures);
env = rlFunctionEnv(ObservationInfo, ActionInfo, stepFunction, resetFunction);
% Critic network: maps a feature vector to one Q-value per class action.
% (The `statePath = [` opener was missing, leaving these layers as stray
% expressions and `statePath` undefined.)
statePath = [
    featureInputLayer(numFeatures, 'Normalization', 'none', 'Name', 'state')
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(64, 'Name', 'fc2')
    reluLayer('Name', 'relu2')
    fullyConnectedLayer(numel(ActionInfo.Elements), 'Name', 'fcOutput')];
criticNet = dlnetwork(layerGraph(statePath));

% The network takes only the observation and outputs a vector of Q-values,
% one per discrete action, so it is a vector Q-value function.
% rlQValueFunction would require an (observation, action) -> scalar network.
critic = rlVectorQValueFunction(criticNet, ObservationInfo, ActionInfo, ...
    'ObservationInputNames', 'state');
% DQN agent hyperparameters, set one property at a time on a default
% options object.
agentOpts = rlDQNAgentOptions;
agentOpts.ExperienceBufferLength = 1e5;   % replay buffer capacity
agentOpts.DiscountFactor = 0.99;
agentOpts.TargetSmoothFactor = 1e-3;      % soft target-network update rate
agentOpts.TargetUpdateFrequency = 4;

agent = rlDQNAgent(critic, agentOpts);
% Training configuration: one episode sweeps the whole training set (one
% sample per step); stop after 500 episodes.
trainOpts = rlTrainingOptions( ...
    'MaxStepsPerEpisode', size(trainFeatures, 1), ...
    'Plots', 'training-progress', ...
    'StopTrainingCriteria', 'EpisodeCount', ...
    'StopTrainingValue', 500);
% trainOpts must be passed explicitly; without it train() uses default
% options and every setting above is ignored.
trainingStats = train(agent, env, trainOpts);
function [InitialObservation, LoggedSignals] = resetFunctionRL(Features)
% resetFunctionRL  Start a new episode at the first row of Features.
%   [InitialObservation, LoggedSignals] = resetFunctionRL(Features)
%   Features           - matrix with one sample per row.
%   InitialObservation - first row of Features, as a column vector.
%   LoggedSignals      - struct with CurrentIndex = 1 and EpisodeStartTime.
if size(Features, 1) < 1
    error('Features must contain at least one sample (row).');
end
LoggedSignals = struct();
LoggedSignals.CurrentIndex = 1;
InitialObservation = Features(LoggedSignals.CurrentIndex, :)';
% Defensive check: the observation spec is [numFeatures, 1].
if size(InitialObservation, 2) ~= 1
    error('InitialObservation must be a column vector of size [numFeatures, 1]');
end
% Set after the guard so it is always recorded (it used to sit inside it).
LoggedSignals.EpisodeStartTime = datetime('now');
end
function [NextObservation, Reward, IsDone, LoggedSignals] = stepFunctionRL(Action, Features, Labels, LoggedSignals)
% stepFunctionRL  Score the agent's class choice for the current sample and
% advance to the next one.
%   Action        - chosen class index.
%   Features      - matrix with one sample per row.
%   Labels        - class index for each row of Features.
%   LoggedSignals - struct whose CurrentIndex is the row being classified.
%   Returns the next sample as a column vector, the reward, whether the
%   training set is exhausted, and the updated LoggedSignals.
idx = LoggedSignals.CurrentIndex;
correctLabel = Labels(idx);

% NOTE(review): the original reward assignment was lost; +1 for a correct
% class and -1 otherwise is assumed here — confirm the intended values.
if Action == correctLabel
    Reward = 1;
else
    Reward = -1;
end

LoggedSignals.CurrentIndex = idx + 1;
IsDone = LoggedSignals.CurrentIndex > size(Features, 1);
if IsDone
    % No next row exists; return the last sample so the observation
    % still matches the spec without indexing out of bounds.
    NextObservation = Features(idx, :)';
else
    NextObservation = Features(LoggedSignals.CurrentIndex, :)';
end
end