Poor performance of trainNetwork() function as compared to train()

Hello, I am having trouble training a neural network with trainNetwork() compared to train(), and I am stumped. I set up what should be an identical architecture twice: once with fitnet() for the train() function, and once as a layer array from the toolbox for trainNetwork(). The train() version converges rapidly to a good solution, while I am currently unable to get the trainNetwork() version to converge decently. What am I doing wrong?
Code 1: Working with train()
% Select whether to train a new network or fine-tune an existing one
if useExisting == false
    fprintf("Option A: Train New Model\n");
    net = fitnet(hidden_layer_size); % Define a new shallow network
else
    fprintf("Option B: Load and Fine-tune Existing Model\n");
    net = loadedData.net; % Load the existing neural network
end

% Set the common network training parameters
net.trainFcn = 'trainscg';
net.trainParam.epochs = 250E3;
net.trainParam.goal = 0;
net.trainParam.max_fail = 6; % Stop after 6 consecutive validation failures

% Divide data for training, validation, and testing (80:10:10 split)
net.divideParam.trainRatio = 0.8;
net.divideParam.testRatio = 0.1;
net.divideParam.valRatio = 0.1;

% Train the neural network
net = train(net, input, target, 'useGPU', 'yes'); % Training with the (pre-shuffled) input and target
Code 2: Not working with trainNetwork()
options = trainingOptions('rmsprop', ...
    'MaxEpochs', 250E3, ...                 % Match the epoch budget used with train()
    'MiniBatchSize', 32, ...
    'InitialLearnRate', 1E-3, ...           % Modest LR since RMSProp adapts per parameter
    'SquaredGradientDecayFactor', 0.85, ... % Below the rmsprop default of 0.9 for faster adaptation
    'Shuffle', 'every-epoch', ...           % Keep it since the data has overlap
    'ValidationData', {input(:, 1:round(0.2 * end)), target(:, 1:round(0.2 * end))}, ... % First 20% of the data (note: also seen during training)
    'ValidationFrequency', 50, ...          % Check validation every 50 iterations
    'Verbose', true, ...
    'Plots', 'training-progress', ...
    'ValidationPatience', 12, ...           % More patience for slow convergence
    'ExecutionEnvironment', 'gpu');         % Use the GPU for speed

% Layers adjusted for 1-D data
layers = [
    sequenceInputLayer(chunk_size, 'Name', 'input')
    fullyConnectedLayer(hidden_layer_size, 'Name', 'fc1')
    tanhLayer
    fullyConnectedLayer(chunk_size, 'Name', 'output') % Adjust output size if needed
    regressionLayer('Name', 'regression')];

% Select whether to train a new network or fine-tune an existing one
if useExisting
    fprintf("Option B: Load and Fine-tune Existing Model\n");
    net = trainNetwork(input, target, net.Layers, options); % Reuse layers from the loaded network
else
    fprintf("Option A: Train New Model\n");
    net = trainNetwork(input, target, layers, options);
end
I have tried tuning the training parameters for trainNetwork(), and this is the best configuration I was able to find. Unfortunately, its performance is dismal compared to train().
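To make the comparison concrete, one option is to score both trained models on the same held-out columns. This is a minimal sketch, not code from the thread: it assumes input and target are chunk_size-by-N matrices as above, and that the two trained networks have been kept under the hypothetical names netShallow (from Code 1) and netDeep (from Code 2).
% Hold out the last 10% of columns for both models (hypothetical split;
% note that train() used its own internal dividerand split during training)
nTest = round(0.1 * size(input, 2));
XTest = input(:,  end-nTest+1:end);
TTest = target(:, end-nTest+1:end);

YShallow = netShallow(XTest);        % shallow net from Code 1
YDeep    = predict(netDeep, XTest);  % trainNetwork model from Code 2

rmseShallow = sqrt(mean((YShallow - TTest).^2, 'all'));
rmseDeep    = sqrt(mean((YDeep    - TTest).^2, 'all'));
fprintf("RMSE  train(): %.4g   trainNetwork(): %.4g\n", rmseShallow, rmseDeep);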
3 Comments
Ivan Rodionov on 8 Feb 2025
Edited: Ivan Rodionov on 8 Feb 2025
@Matt J Hello Matt, and thank you for your reply. In some sense you are right, but if I understand the code correctly, that should not matter here, because both scripts are fundamentally doing the same thing: a single, slightly wide hidden layer with an identical tanh activation function (a quick parameter-count check is sketched below). Whether it is built via fitnet() or as a layer array, it should not train well with one and fail outright with the other, unless something is broken or I am misunderstanding what is going on behind the scenes.
EDIT:
If you are curious, I can gladly provide the datasets; it is a distorted signal.
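One quick way to check that the two architectures really do match is to compare learnable-parameter counts. Here is a minimal sketch (hypothetical helper code, not from the thread; it assumes input, target, hidden_layer_size, and the layers array from Code 2 are in the workspace):
% Count learnables in the shallow fitnet model
netShallow = fitnet(hidden_layer_size);
netShallow = configure(netShallow, input, target); % size the layers from the data
nShallow = numel(getwb(netShallow));               % all weights and biases

% Count learnables in the layer array (drop regressionLayer, which has none)
dln = dlnetwork(layers(1:end-1));
nDeep = sum(cellfun(@numel, dln.Learnables.Value));

fprintf("fitnet: %d learnables, layer array: %d learnables\n", nShallow, nDeep);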
Matt J on 9 Feb 2025
Edited: Matt J on 9 Feb 2025
But the algorithm used by trainscg() is different and has fewer tuning parameters than RMSProp. We don't know how performance might improve if you changed InitialLearnRate, MiniBatchSize, and the other rmsprop parameters. You might try Adam instead of RMSProp; I've heard it is more robust.
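For what switching solvers might look like, here is a hedged sketch that reuses the option values from Code 2 and changes only the solver; the two decay factors shown are the documented Adam defaults and are starting points, not tuned values.
optionsAdam = trainingOptions('adam', ...
    'MaxEpochs', 250E3, ...
    'MiniBatchSize', 32, ...
    'InitialLearnRate', 1E-3, ...
    'GradientDecayFactor', 0.9, ...          % Adam default
    'SquaredGradientDecayFactor', 0.999, ... % Adam default
    'Shuffle', 'every-epoch', ...
    'ValidationData', {input(:, 1:round(0.2 * end)), target(:, 1:round(0.2 * end))}, ...
    'ValidationFrequency', 50, ...
    'ValidationPatience', 12, ...
    'Verbose', true, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'gpu');
net = trainNetwork(input, target, layers, optionsAdam);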


Answers (1)

Matt J on 9 Feb 2025
Edited: Matt J on 9 Feb 2025
"I have tried tuning the training parameters for trainNetwork(), and this is the best configuration I was able to find."
You can try using the Experiment Manager to explore the hyperparameter space more systematically.
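For trainNetwork-based experiments, Experiment Manager calls a setup function that maps each trial's hyperparameters to trainNetwork arguments. A minimal sketch follows; the function name, data file, and hyperparameter fields (myHiddenSize, myLearnRate, myBatchSize) are hypothetical and would be defined in the app's hyperparameter table.
function [X, Y, layers, options] = RegressionExperiment_setup(params)
    % Load the data for each trial (hypothetical MAT-file holding input/target)
    data = load("trainingData.mat");
    X = data.input;
    Y = data.target;

    % Rebuild the architecture with this trial's hidden-layer width
    layers = [
        sequenceInputLayer(size(X, 1))
        fullyConnectedLayer(params.myHiddenSize)
        tanhLayer
        fullyConnectedLayer(size(Y, 1))
        regressionLayer];

    % Solver settings drawn from the hyperparameter table
    options = trainingOptions('adam', ...
        'InitialLearnRate', params.myLearnRate, ...
        'MiniBatchSize', params.myBatchSize, ...
        'MaxEpochs', 500, ...        % hypothetical per-trial budget
        'Shuffle', 'every-epoch', ...
        'Verbose', false);
end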

Release: R2023b
