Fitnet: Why do some training functions (trainbfg,traincgf) with default settings fail on the simplefit dataset, while others (trainlm,trainbr) work perfectly well?
I have written a script that compares various training functions with their default parameters, using the data returned by simplefit_dataset. I train the networks on half of the points and evaluate the performance on all points. trainlm works well, trainbr works very well, but trainbfg, traincgf and trainrp do not work at all. What is the reason for this? I would have thought that the default parameters were good enough for this almost trivial problem.
Here is the code:
%% Function fitting with a neural network
clear
close all
clc
rng('default') % choose default random number generator
rng(0) % set seed of RNG
% try several architectures and training/optimization functions
% see fitnet documentation
hidsize=8; % number of hidden neurons
mynets{1}=fitnet(hidsize,'trainlm'); % Levenberg-Marquardt (default)
mynets{2}=fitnet(hidsize,'trainbr'); % Bayesian regularization
mynets{3}=fitnet(hidsize,'trainbfg'); % BFGS quasi-Newton
mynets{4}=fitnet(hidsize,'traincgf'); % Fletcher-Reeves conjugate gradient
mynets{5}=fitnet(hidsize,'trainrp'); % Resilient backpropagation
% load data and targets
[x,t] = simplefit_dataset;
xx=x(1:2:end); % training inputs
tt=t(1:2:end); % training targets
numnets=length(mynets);
for inet=1:numnets
    net=mynets{inet};
    net.trainParam.showCommandLine=false;
    net.trainParam.showWindow=false;
    net=train(net,xx,tt);
    %view(net)
    y=sim(net,x);
    mu=mean(y-t);
    sig=std(y-t);
    makeplots=false; % set true to see figures
    if makeplots
        figure(inet);
        ha=plot(x,t,'.-g','displayname','all values'); % test values
        hold on
        ylim([0 12])
        ht=plot(xx,tt,'.-k','displayname','training values');
        hy=plot(x,y,'o-r','displayname','network output');
        xlabel('$x$')
        ylabel('$y$')
        titstr=['fitnet with ',num2str(hidsize),...
            ' hidden neurons and training function ',net.trainfcn];
        title(titstr)
        legend([ha,ht,hy]);
        hx=text(1,11,['$\mu=',num2str(mu,2),',\ \sigma=',num2str(sig,2),'$']);
    end
    Mu(inet)=mu;
    Sig(inet)=sig;
end
disp(['Mean of errors: ',num2str(Mu)])
disp(['Stdev of errors: ',num2str(Sig)])
1 Comment
S0852306 on 26 Jul 2023 (edited 26 Jul 2023)
Hi Rudolf,
First of all, there is a bug in your code:
y=sim(net,x);
mu=mean(y-t); %.^2
The default cost function is MSE (mean squared error), so the correct code should be:
y=sim(net,x); e=y-t;
mu=mean(e.^2); % or use mu=perform(net,t,y)
With that fix, the five solvers give:
Mean of errors: 0.00053693 8.3758e-10 0.13737 0.33484 0.14557
Stdev of errors: 0.0018166 3.8184e-09 0.18903 0.44031 0.29932
Secondly, if your goal is to compare the performance of the optimization solvers, use only the training data to calculate the MSE, not the entire data set.
y=sim(net,xx); e=y-tt; % use training set to compute MSE
mu=mean(e.^2);
"trainlm works well, trainbr works very well, but trainbfg, traincgf and trainrp do not work at all. What is the reason for this? "
About this question, there are two main reasons:
Cost function
Training a neural net for a surface-fitting task amounts to solving a nonlinear least-squares problem. However, the cost functions posed by neural nets are extremely complex: non-convex, with extremely large condition numbers. Without curvature information (the Hessian), it is hard for gradient-descent (GD) based methods to converge to the minimum, so in theory trainlm and trainbfg should outperform the other solvers. (The other solvers are GD-based methods, except for trainbr, whose algorithm I don't know.)
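To make the conditioning point concrete, here is a tiny sketch of mine (not toolbox code): on a quadratic cost with condition number 1000, plain gradient descent barely moves in 100 iterations, while a single Newton step that uses the Hessian lands exactly on the minimum.
H = diag([1 1000]);          % quadratic cost with condition number 1000
g = @(w) H*w;                % gradient of f(w) = 0.5*w'*H*w
w = [1; 1];                  % starting point
lr = 1/max(eig(H));          % standard safe step size (1/L) for GD
for k = 1:100
    w = w - lr*g(w);         % gradient descent: 100 iterations
end
disp(norm(w))                % ~0.9, still far from the minimum [0;0]
wN = [1; 1] - H \ g([1; 1]); % one Newton step using the Hessian
disp(norm(wN))               % 0: lands exactly on the minimum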
Optimization algorithm
trainlm is a "hybrid" method that combines the Gauss-Newton (G-N) method with GD. GD-based methods are usually more robust than Newton-type methods, but extremely slow in terms of iterations compared with them. Although G-N and quasi-Newton methods (such as BFGS) are much faster (quadratic / superlinear convergence), they are sensitive to the initial point: if your initial point is not "good enough", you may get stuck at a local minimum or a saddle point, and I guess that is the main reason why BFGS fails here. In its early stages LM usually behaves more like GD, so it in effect uses GD to reach a better "initial point" before switching to G-N-like steps for fast convergence.
So if your network is not too large, trainlm will be the best choice in most cases.
(However, LM is not feasible for larger neural nets, because it requires solving a linear system at each iteration.)
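For intuition, here is a minimal hand-rolled LM iteration on a one-parameter toy model (my sketch, not the toolbox implementation). The damped solve inside the loop is exactly the linear system that grows with the number of weights:
x = linspace(0,1,20); t = exp(0.7*x);  % toy data from y = exp(0.7*x)
a = 0; mu = 1e-2;                      % initial guess and fixed damping
% Large mu makes the step GD-like; small mu makes it Gauss-Newton-like.
for k = 1:50
    e  = exp(a*x) - t;                 % residuals (1x20)
    J  = (x .* exp(a*x))';             % Jacobian de/da (20x1)
    da = -(J'*J + mu) \ (J'*e');       % the linear solve at every step
    a  = a + da;
end
disp(a)                                % converges to about 0.7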
So, does a robust and fast method exist for training medium-size neural nets?
(Medium-size networks are very common in scientific ML applications.)
Some papers provide a useful suggestion: train the net with a stochastic-gradient-descent (SGD) based method first, then switch to a BFGS or L-BFGS solver; this is faster than pure LM. In my experience this two-stage optimization is usually very robust and works quite well. Check out the following function-approximation example, which is much more complex than simplefit_dataset.
SSE: 0.14
MAE: 0.003
clear; clc; close all;
% To run this script, download the pack
% at file exchange: https://tinyurl.com/wre9r5uk
%% Generate data
n=50;
x=linspace(-2,2,n);
y=linspace(-2,2,n);
n1=numel(x); n2=n1;
data=zeros(2,n1*n2); label=zeros(2,n1*n2); % preallocate
count=0;
for j=1:n2
    for k=1:n1
        count=count+1;
        data(1,count)=x(j);
        data(2,count)=y(k);
        u=x(j)^2+y(k)^2;
        label(1,count)=log(1+(x(k)-4/3)^2+3*(x(k)+y(j)-x(k)^3)^2);
        label(2,count)=exp(-u/2).*cos(2*u);
    end
end
%% Network Structure Set Up
InputDimension=2; OutputDimension=2;
LayerStruct=[InputDimension,10,10,10,OutputDimension];
NN.Cost='SSE';
NN.labelAutoScaling='on';
NN=Initialization(LayerStruct,NN);
%% First Order Solver Set Up
option.Solver='ADAM'; % ADAM is the state-of-the-art SGD-based solver.
option.s0=1e-3; % step size
option.MaxIteration=250;
option.BatchSize=1000;
NN=OptimizationSolver(data,label,NN,option);
%% Quasi-Newton Solver Set Up
option.Solver='BFGS';
option.MaxIteration=500;
NN=OptimizationSolver(data,label,NN,option);
%% Validation
Prediction=NN.Evaluate(data);
Error=label-Prediction;
Report=FittingReport(data,label,NN);
%% Visualization
figure
slice=2;
scatter3(data(1,:),data(2,:),label(slice,:),'black')
hold on
[X,Y]=meshgrid(x,y);
n1=numel(x); n2=numel(y);
surf(X,Y,reshape(Prediction(slice,:),n1,n2))
title('Neural Network Fit')
legend('data','Fitting')
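If you would rather stay inside the toolbox, the same two-stage idea can be approximated with the built-in solvers (a sketch of mine; the shallow-network API has no ADAM, so trainrp stands in for the SGD stage):
[x,t] = simplefit_dataset;
net = fitnet(8,'trainrp');
net.trainParam.showWindow = false;
net.trainParam.epochs = 200;       % stage 1: cheap gradient-based steps
net = train(net,x,t);
net.trainFcn = 'trainbfg';         % stage 2: quasi-Newton refinement
net.trainParam.showWindow = false; % trainParam resets with the new solver
net.trainParam.epochs = 500;
net = train(net,x,t);              % continues from the stage-1 weights
disp("final MSE = " + perform(net,t,net(x)))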
Answers (1)
Aditya on 24 Jan 2023 (edited 3 Feb 2023)
Hi,
I understand that you want to know why some network training functions are performing poorly as compared to others.
The evaluation strategy might not give you a correct picture of how the network performs. Because you train the network on half of the points, you cannot reuse those points for evaluation: the network has already "seen" them. A network that overfits the training dataset would score very well under such an evaluation. There should be no overlap between the training and testing datasets.
Here is the same example with different training and testing splits.
[x,y] = simplefit_dataset;
train_x = x(1:80);
test_x = x(81:end);
train_y = y(1:80);
test_y = y(81:end);
nets = {'trainlm','trainbr','trainbfg','traincgf','trainrp'};
hidsize = 8;
for i = 1:5
    net = fitnet(hidsize, nets{i});
    net.trainParam.showWindow = false;
    net = train(net, train_x, train_y);
    Y = sim(net, test_x);
    perf = perform(net, test_y, Y); % perform(net, targets, outputs)
    disp(nets{i} + " Mean Squared Error = " + perf);
end
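As an aside, train can manage the hold-out itself: fitnet defaults to dividerand with a 70/15/15 split. If you prefer an explicit split without slicing the arrays by hand, a sketch along these lines uses the documented divideind option:
[x,y] = simplefit_dataset;
net = fitnet(8,'trainlm');
net.divideFcn = 'divideind';          % index-based data division
net.divideParam.trainInd = 1:80;      % training points
net.divideParam.valInd   = [];        % no validation set in this sketch
net.divideParam.testInd  = 81:numel(x);
net.trainParam.showWindow = false;
[net,tr] = train(net,x,y);
testMSE = perform(net, y(tr.testInd), net(x(:,tr.testInd)));
disp("test MSE = " + testMSE)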
Your observation regarding the performance of the network with 8 hidden neurons is correct. However, if you change the number of hidden neurons to 15, for example, you will notice different results.
With deep learning, the goal is generalization. So, to compare training functions, you should perform the comparison across several network sizes before coming to a conclusion, for example with a sweep like the one below.
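Here is a sketch of such a sweep (mine, reusing train_x, train_y, test_x and test_y from the example above):
nets = {'trainlm','trainbr','trainbfg','traincgf','trainrp'};
for hidsize = [4 8 15 30]              % several hidden-layer sizes
    for i = 1:numel(nets)
        net = fitnet(hidsize, nets{i});
        net.trainParam.showWindow = false;
        net = train(net, train_x, train_y);
        perf = perform(net, test_y, sim(net, test_x));
        disp(nets{i} + " (" + hidsize + " neurons) MSE = " + perf);
    end
end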
2 Comments
Aditya on 3 Feb 2023 (edited 3 Feb 2023)
It might be okay to check whether the training loss is going down in order to see whether the network is learning at all.
However, using the training dataset to compare the performance of networks does not give a complete picture, for the simple reason that overfitting networks would score better. Generalization matters.
Yes, you are right in observing that different weight initializations lead to faster convergence. The initial point can determine whether the algorithm converges at all; some initial points are so unstable that the algorithm encounters numerical difficulties and fails altogether.
You can read about it here: Weight Initialization for Deep Learning Neural Networks - MachineLearningMastery.com
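You can probe this directly in the toolbox by reinitializing the weights and retraining (a small sketch assuming the simplefit data; for a sensitive solver such as trainbfg the final MSE can vary a lot between restarts):
[x,t] = simplefit_dataset;
for seed = 1:5
    rng(seed)                          % different random initial point
    net = fitnet(8,'trainbfg');
    net.trainParam.showWindow = false;
    net = init(configure(net,x,t));    % fresh random weights
    net = train(net,x,t);
    disp("seed " + seed + ": MSE = " + perform(net,t,net(x)))
end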