Gaussian Naive Bayes classification
I have found the following MATLAB implementation of a Naive Bayes classifier:
https://github.com/jjedele/Naive-Bayes-Classifier-Octave-Matlab
How can I extend the above implementation to become Gaussian Naive Bayes?
How can I extend the implementation to work with 4 classes? By just doing one-vs-all?
Thank you very much for the help.
Answers (1)
Abhipsa
on 1 Sep 2025 at 6:31
There is no need to switch to one-vs-all for Naive Bayes as it’s naturally a multiclass classifier.
In Gaussian Naive Bayes, each feature is assumed to follow a normal distribution within each class.
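Concretely, with class priors pi_k and per-class feature means and variances, prediction just picks the class with the largest log-posterior, which handles any number of classes in one pass:

$$\log p(y=k \mid x) \;\propto\; \log \pi_k + \sum_{j=1}^{d}\left[-\tfrac{1}{2}\log\left(2\pi\sigma_{kj}^{2}\right) - \frac{(x_j-\mu_{kj})^{2}}{2\sigma_{kj}^{2}}\right]$$

This is exactly the quantity the gnb_predict function below accumulates (the const, quad, and logprior terms).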
You can adapt the files from the repository you linked (https://github.com/jjedele/Naive-Bayes-Classifier-Octave-Matlab) as shown below:
function model = gnb_train(X, y)
    % X: n-by-d matrix (rows = samples, cols = features)
    % y: n-by-1 vector of class labels (numeric/char/string)
    if iscell(y), y = string(y); end
    classes = unique(y);
    [n, d] = size(X);
    K = numel(classes);
    mu = zeros(K, d); varc = zeros(K, d); prior = zeros(1, K);
    for k = 1:K
        idx = (y == classes(k));
        Xk = X(idx, :);
        prior(k) = sum(idx)/n;    % P(y=k)
        mu(k,:) = mean(Xk, 1);    % feature means
        varc(k,:) = var(Xk, 1);   % MLE variances (normalize by N)
    end
    % Variance floor (avoid divide-by-zero if a feature is constant)
    varc = max(varc, 1e-9);
    model.classes = classes;
    model.prior = prior;
    model.mu = mu;
    model.var = varc;
end
function [yhat, logpost] = gnb_predict(model, X)
    % X: m-by-d test matrix
    % yhat: m-by-1 predicted labels
    % logpost: m-by-K unnormalized log-posteriors (diagnostics)
    [m, ~] = size(X);
    K = numel(model.classes);
    logpost = zeros(m, K);
    % Per-class constant term for diagonal Gaussian
    const = -0.5 * sum(log(2*pi*model.var), 2);  % K-by-1
    logprior = log(model.prior);                 % 1-by-K
    for k = 1:K
        diff = X - model.mu(k,:);
        quad = (diff.^2) ./ model.var(k,:);   % m-by-d
        ll = const(k) - 0.5 * sum(quad, 2);   % m-by-1
        logpost(:,k) = ll + logprior(k);
    end
    [~, idx] = max(logpost, [], 2);
    yhat = model.classes(idx);
end
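The logpost output is unnormalized. If you want actual posterior probabilities, normalize with the log-sum-exp trick so the exponentials cannot overflow; a minimal sketch using gnb_predict's second output:

% Convert unnormalized log-posteriors to class probabilities.
% Subtracting the row max before exp() keeps the values finite.
mx = max(logpost, [], 2);
post = exp(logpost - mx);
post = post ./ sum(post, 2);   % each row now sums to 1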
The two functions above can be run on synthetic data, as in the script below:
% Seed for reproducibility (optional)
rng(7);
% ---- 1) Make a simple 4-class 2D dataset (Gaussian blobs)
nPerClass = 150;
C1 = mvnrnd([0, 0], [0.5 0; 0 0.3], nPerClass);
C2 = mvnrnd([3, 1], [0.6 0; 0 0.6], nPerClass);
C3 = mvnrnd([-2, 3], [0.4 0; 0 0.7], nPerClass);
C4 = mvnrnd([2, -2], [0.7 0; 0 0.4], nPerClass);
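% (Note: mvnrnd needs the Statistics and Machine Learning Toolbox; in base
% MATLAB, randn(nPerClass,2)*chol(Sigma) + mu generates the same blobs.)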
X = [C1; C2; C3; C4];
y = [ones(nPerClass,1);
     2*ones(nPerClass,1);
     3*ones(nPerClass,1);
     4*ones(nPerClass,1)];
% ---- 2) Standardize features (recommended for GNB)
muX = mean(X,1); sigX = std(X,0,1); sigX(sigX==0) = 1;
Xz = (X - muX) ./ sigX;
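% (Note: muX/sigX above are computed on the full dataset before the split,
% which leaks test statistics into training. For a stricter evaluation,
% compute them from the training rows only and apply them to the test rows.)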
% ---- 3) Train/test split (70/30)
n = size(Xz,1);
idx = randperm(n);
nTr = round(0.7*n);
tr = idx(1:nTr);
te = idx(nTr+1:end);
Xtr = Xz(tr,:); ytr = y(tr);
Xte = Xz(te,:); yte = y(te);
% ---- 4) Train Gaussian Naive Bayes
model = gnb_train(Xtr, ytr);
% ---- 5) Predict
[yhat, scores] = gnb_predict(model, Xte);
% ---- 6) Accuracy and confusion matrix
acc = mean(yhat == yte);
fprintf('Test accuracy: %.2f%%\n', 100*acc);
K = numel(unique(y));
CM = zeros(K);
for i = 1:numel(yte)
    CM(yte(i), yhat(i)) = CM(yte(i), yhat(i)) + 1;
end
disp('Confusion matrix (rows = true, cols = predicted):');
disp(CM);
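% Alternatively, confusionchart (Statistics and Machine Learning Toolbox,
% R2018b+) draws this matrix directly -- see the doc link below:
% figure; confusionchart(yte, yhat);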
% ---- 7) (Optional) quick scatter to visualize test set predictions
figure; hold on;
clsColors = lines(4); % 4 distinct colors
for c = 1:4
    pts = (yte == c);
    scatter(Xte(pts,1), Xte(pts,2), 25, clsColors(c,:), 'filled', ...
        'MarkerFaceAlpha', 0.7, 'DisplayName', sprintf('True C%d', c));
end
for c = 1:4
    pts = (yhat == c);
    % open rings colored by *predicted* class, drawn over the filled dots
    scatter(Xte(pts,1), Xte(pts,2), 60, clsColors(c,:), 'o', 'HandleVisibility', 'off');
end
legend('Location', 'best');
title(sprintf('Gaussian NB on 4 classes (Test acc = %.1f%%)', 100*acc));
xlabel('z-score feature 1'); ylabel('z-score feature 2'); grid on; box on; hold off
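As a cross-check, if you have the Statistics and Machine Learning Toolbox you can compare against the built-in fitcnb, which models each numeric predictor as a per-class normal distribution by default and handles all four classes directly. A minimal sketch reusing Xtr/ytr/Xte/yte from the script above (small numerical differences from the hand-rolled version are possible, e.g. in how variances are normalized):

Mdl = fitcnb(Xtr, ytr);       % Gaussian NB per predictor by default
yhatBuiltin = predict(Mdl, Xte);
fprintf('fitcnb test accuracy: %.2f%%\n', 100*mean(yhatBuiltin == yte));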
You can refer to the following MATLAB documentation for more details:
- classification: https://www.mathworks.com/help/stats/classification.html
- confusionchart: https://in.mathworks.com/help/stats/confusionchart.html