
How to predict using trained neural network obtained from Classification Learner app

My trained neural network model is named trainedModel. I want to perform prediction using its weights and biases, which I extracted with the following lines of code:
wi=trainedModel.ClassificationNeuralNetwork.LayerWeights{1,1};
wo=trainedModel.ClassificationNeuralNetwork.LayerWeights{1,2};
bi=trainedModel.ClassificationNeuralNetwork.LayerBiases{1,1};
bo=trainedModel.ClassificationNeuralNetwork.LayerBiases{1,2};
Then I perform the prediction on the input features using the model's predictFcn:
indice_net=100;
a1=trainedModel.predictFcn(trainedModel.ClassificationNeuralNetwork.X{indice_net,:});
But when I perform the prediction using the following code, the result is sometimes different from the output of the trained model's predictFcn. Can you please guide me on how to replace the functionality of predictFcn with simple MATLAB code like the code below? Or am I making a mistake?
x=trainedModel.ClassificationNeuralNetwork.X{indice_net,:};
lay1=wi*x';
lay1_b=lay1+bi;
y = sigmoid(lay1_b);
lay2=wo*y;
lay2_b=lay2+bo;
predict_class=softmax(lay2_b);

Answers (1)

Drew on 7 Feb 2023
It looks like you are trying to re-implement the "predict" method for a neural network model obtained from Classification Learner.
Some pointers toward a solution:
(1) As you point out, a prediction function (trainedModel.predictFcn) is already provided. It might help to know why you want to re-implement the predict method; this may suggest another solution.
(2) You say that the output from your code and the predictFcn are "sometimes different". Can you elaborate on when the results are the same and when they are different? When there are differences, what is their numerical size? Are they just small differences? In general, small differences could arise from differences in floating-point computation between your code and the built-in predict method. If you can provide your exported model (with the training data included in the X property) and the inputs for which you see differences, more specific feedback can be provided.
(3) Did you use feature selection or PCA in Classification Learner? If so, you will need to account for those steps when testing the model. "trainedModel.predictFcn" is a wrapper around the model-specific predict function "trainedModel.ClassificationNeuralNetwork.predict". The predictFcn wrapper is there to handle feature selection and PCA prior to calling the "trainedModel.ClassificationNeuralNetwork.predict" method. For documentation of the ClassificationNeuralNetwork model-specific predict function, see: https://www.mathworks.com/help/stats/classificationneuralnetwork.predict.html
(4) Remember to examine the model, including its methods and properties, to ensure that you are correctly taking all aspects of the model into account. Did you check which type of "Activations" your model is using? By default, Classification Learner uses the "ReLU" activation, but you can change the activation type to "Sigmoid" or another option. For example, for a bilayered neural network model exported from Classification Learner based on the fisheriris data, where the "Sigmoid" activation was chosen in Classification Learner before the model was trained, we see the following when examining the model at the command line:
>> trainedModel.ClassificationNeuralNetwork
ans =
ClassificationNeuralNetwork
PredictorNames: {'SepalLength' 'SepalWidth' 'PetalLength' 'PetalWidth'}
ResponseName: 'Y'
CategoricalPredictors: []
ClassNames: {'setosa' 'versicolor' 'virginica'}
ScoreTransform: 'none'
NumObservations: 135
LayerSizes: [10 10]
Activations: 'sigmoid'
OutputLayerActivation: 'softmax'
Solver: 'LBFGS'
ConvergenceInfo: [1×1 struct]
TrainingHistory: [64×7 table]
>> properties(trainedModel.ClassificationNeuralNetwork)
Properties for class ClassificationNeuralNetwork:
ConvergenceInfo
Solver
TrainingHistory
Y
X
RowsUsed
W
ModelParameters
NumObservations
BinEdges
HyperparameterOptimizationResults
PredictorNames
CategoricalPredictors
ResponseName
ExpandedPredictorNames
ClassNames
Prior
Cost
ScoreTransform
OutputLayerActivation
LayerSizes
Activations
LayerWeights
LayerBiases
>> methods(trainedModel.ClassificationNeuralNetwork)
Methods for class ClassificationNeuralNetwork:
compact edge partialDependence resubEdge resubPredict
compareHoldout loss plotPartialDependence resubLoss
crossval margin predict resubMargin
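Building on point (4), here is a minimal sketch of a forward pass that follows the model's own settings instead of hard-coding sigmoid. It loops over every cell of LayerWeights/LayerBiases (so it also covers bilayered or trilayered networks), applies the activation named in the Activations property to each hidden layer, and applies softmax at the output. The function name nnForwardPass is hypothetical, and the sketch assumes numeric predictors only, no PCA or feature selection, and an OutputLayerActivation of 'softmax'. If "Standardize data" was enabled during training (I believe this is the Classification Learner default for neural networks), x must first be centered and scaled with the training means and standard deviations; the sketch does not do that step.

```matlab
function scores = nnForwardPass(x, W, b, acts)
%NNFORWARDPASS Sketch of a fully connected forward pass with a softmax output.
%   x    - 1-by-p numeric row of predictors, preprocessed exactly as the model
%          saw them during training (assumption: no categorical predictors;
%          standardization or PCA, if used, already applied by the caller).
%   W, b - cell arrays such as mdl.LayerWeights and mdl.LayerBiases, one cell
%          per fully connected layer, the last cell being the output layer.
%   acts - hidden-layer activation(s), e.g. mdl.Activations: 'relu',
%          'sigmoid', 'tanh', or 'none', or one name per hidden layer.
nLayers = numel(W);
acts = cellstr(acts);                      % accept char, string, or cellstr
if isscalar(acts)
    acts = repmat(acts, 1, nLayers - 1);   % same activation for every hidden layer
end
a = x(:);                                  % work with a column vector
for k = 1:nLayers - 1
    z = W{k} * a + b{k}(:);                % affine step of hidden layer k
    switch lower(acts{k})
        case 'relu'
            a = max(z, 0);
        case 'sigmoid'
            a = 1 ./ (1 + exp(-z));
        case 'tanh'
            a = tanh(z);
        case 'none'
            a = z;
        otherwise
            error('nnForwardPass:activation', ...
                'Unsupported activation "%s".', acts{k});
    end
end
z = W{end} * a + b{end}(:);                % output layer pre-activation
e = exp(z - max(z));                       % softmax, shifted for numerical stability
scores = (e / sum(e)).';                   % 1-by-K row of class scores
end
```

For example, with mdl = trainedModel.ClassificationNeuralNetwork, call s = nnForwardPass(mdl.X{indice_net,:}, mdl.LayerWeights, mdl.LayerBiases, mdl.Activations), then take [~, k] = max(s) and read the class from mdl.ClassNames(k). Softmax returns class scores rather than a class, so this arg-max step is needed before comparing with predictFcn, which returns labels. The scores can also be checked against the second output of [label, score] = predict(mdl, mdl.X(indice_net,:)).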
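For point (2), a small helper like the following can put a number on the differences. It takes two N-by-K score matrices, for example the second output of [label, score] = predict(trainedModel.ClassificationNeuralNetwork, X) and the scores produced by your own code, and reports the largest absolute difference and the rows where the arg-max class disagrees. The function name compareScores and its tol argument are illustrative, not part of any toolbox.

```matlab
function report = compareScores(scoresA, scoresB, tol)
%COMPARESCORES Summarize how far two N-by-K score matrices are apart.
%   report.maxAbsDiff      - largest absolute elementwise difference
%   report.labelMismatches - row indices where the arg-max class differs
%   report.withinTol       - true if maxAbsDiff <= tol
diffs = abs(scoresA - scoresB);
report.maxAbsDiff = max(diffs(:));
[~, ia] = max(scoresA, [], 2);             % predicted class index per row
[~, ib] = max(scoresB, [], 2);
report.labelMismatches = find(ia ~= ib);
report.withinTol = report.maxAbsDiff <= tol;
end
```

Differences on the order of machine precision with no label mismatches suggest floating-point effects; large differences or mismatched labels usually indicate a modelling difference, such as a different activation or a skipped preprocessing step.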
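For point (3), if PCA was enabled in Classification Learner, the inputs must be transformed before they reach the network weights. A minimal sketch of that step is below. It assumes that all predictors are numeric and in training column order, and that the exported struct stores the PCA parameters in fields named PCACenters and PCACoefficients; please verify this with fieldnames(trainedModel). The helper name applyExportedPCA is hypothetical.

```matlab
function Z = applyExportedPCA(X, trainedModel)
%APPLYEXPORTEDPCA Hypothetical helper reproducing a PCA preprocessing step.
%   X is N-by-p numeric data in training column order (assumption).
%   Assumption: PCA models carry PCACenters (1-by-p) and PCACoefficients
%   (p-by-m) fields in the exported struct.
if isfield(trainedModel, 'PCACoefficients')
    % Center with the training means, then project onto the kept components
    Z = (X - trainedModel.PCACenters) * trainedModel.PCACoefficients;
else
    Z = X;                                 % no PCA: predictors pass through unchanged
end
end
```

Feature selection would additionally require keeping only the predictors that were used in training, in the same order, before the data is passed to the network.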
