Gaussian Naive Bayes classification

Sepp on 12 Jun 2015
Answered: Abhipsa on 1 Sep 2025 at 6:31
I have found the following MATLAB implementation of a Naive Bayes classifier:
https://github.com/jjedele/Naive-Bayes-Classifier-Octave-Matlab
How can I extend the above implementation to become Gaussian Naive Bayes?
How can I extend the implementation to use it with 4 classes? By just doing one-vs-all?
Thank you very much for the help.

Answers (1)

Abhipsa on 1 Sep 2025 at 6:31
Hello @Sepp,
There is no need to switch to one-vs-all: Naive Bayes is naturally a multiclass classifier.
In Gaussian Naive Bayes, each feature is assumed to follow a normal distribution within each class, so the model only needs a per-class, per-feature mean and variance.
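Concretely, for a feature vector x = (x_1, ..., x_d), the classifier scores each class k by its log-posterior (up to an additive constant), which is exactly what the prediction code below computes:

\log P(y = k \mid x) \;\propto\; \log P(y = k) + \sum_{j=1}^{d} \left[ -\tfrac{1}{2}\log\!\left(2\pi\sigma_{kj}^2\right) - \frac{(x_j - \mu_{kj})^2}{2\sigma_{kj}^2} \right]

where \mu_{kj} and \sigma_{kj}^2 are the mean and variance of feature j within class k. Prediction takes the argmax over k, so 4 classes (or any number) are handled directly, with no one-vs-all wrapper needed.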
You can adapt the files from the repository you linked (https://github.com/jjedele/Naive-Bayes-Classifier-Octave-Matlab) as shown below:
function model = gnb_train(X, y)
% X: n-by-d matrix (rows = samples, cols = features)
% y: n-by-1 vector of class labels (numeric/char/string)
if iscell(y), y = string(y); end
classes = unique(y);
[n, d] = size(X);
K = numel(classes);
mu = zeros(K, d); varc = zeros(K, d); prior = zeros(1, K);
for k = 1:K
    idx = (y == classes(k));
    Xk = X(idx, :);
    prior(k) = sum(idx)/n;    % P(y = k)
    mu(k, :) = mean(Xk, 1);   % per-feature means
    varc(k, :) = var(Xk, 1);  % MLE variances (normalized by N)
end
% Variance floor (avoids divide-by-zero if a feature is constant)
varc = max(varc, 1e-9);
model.classes = classes;
model.prior = prior;
model.mu = mu;
model.var = varc;
end
function [yhat, logpost] = gnb_predict(model, X)
% X: m-by-d test matrix
% yhat: m-by-1 predicted labels
% logpost: m-by-K unnormalized log-posteriors (diagnostics)
[m, ~] = size(X);
K = numel(model.classes);
logpost = zeros(m, K);
% Per-class constant term of the diagonal Gaussian log-density
const = -0.5 * sum(log(2*pi*model.var), 2);  % K-by-1
logprior = log(model.prior);                 % 1-by-K
for k = 1:K
    dx = X - model.mu(k, :);             % deviations from class-k means ("dx" avoids shadowing diff)
    quad = (dx.^2) ./ model.var(k, :);   % m-by-d
    ll = const(k) - 0.5 * sum(quad, 2);  % m-by-1 log-likelihood
    logpost(:, k) = ll + logprior(k);
end
[~, idx] = max(logpost, [], 2);
yhat = model.classes(idx);
end
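If you also want normalized class probabilities rather than just labels, you can apply a row-wise softmax to the log-posteriors. Here is a minimal sketch (the helper name gnb_posterior is illustrative, not from the linked repository):

function P = gnb_posterior(logpost)
% Row-wise softmax over unnormalized log-posteriors.
% Subtracting the row maximum first avoids overflow in exp.
mx = max(logpost, [], 2);
P = exp(logpost - mx);   % implicit expansion (R2016b+)
P = P ./ sum(P, 2);      % each row now sums to 1
end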
The two functions above can be run on synthetic data, as in the script below:
% Seed for reproducibility (optional)
rng(7);
% ---- 1) Make a simple 4-class 2D dataset (Gaussian blobs)
nPerClass = 150;
C1 = mvnrnd([0, 0], [0.5 0; 0 0.3], nPerClass);
C2 = mvnrnd([3, 1], [0.6 0; 0 0.6], nPerClass);
C3 = mvnrnd([-2, 3], [0.4 0; 0 0.7], nPerClass);
C4 = mvnrnd([2, -2], [0.7 0; 0 0.4], nPerClass);
X = [C1; C2; C3; C4];
y = [ones(nPerClass,1);
     2*ones(nPerClass,1);
     3*ones(nPerClass,1);
     4*ones(nPerClass,1)];
% ---- 2) Standardize features (recommended for GNB)
muX = mean(X,1); sigX = std(X,0,1); sigX(sigX==0) = 1;
Xz = (X - muX) ./ sigX;
% ---- 3) Train/test split (70/30)
n = size(Xz,1);
idx = randperm(n);
nTr = round(0.7*n);
tr = idx(1:nTr);
te = idx(nTr+1:end);
Xtr = Xz(tr,:); ytr = y(tr);
Xte = Xz(te,:); yte = y(te);
% ---- 4) Train Gaussian Naive Bayes
model = gnb_train(Xtr, ytr);
% ---- 5) Predict
[yhat, scores] = gnb_predict(model, Xte);
% ---- 6) Accuracy and confusion matrix
acc = mean(yhat == yte);
fprintf('Test accuracy: %.2f%%\n', 100*acc);
Test accuracy: 98.89%
K = numel(unique(y));
CM = zeros(K);
for i = 1:numel(yte)
    CM(yte(i), yhat(i)) = CM(yte(i), yhat(i)) + 1;
end
disp('Confusion matrix (rows = true, cols = predicted):');
Confusion matrix (rows = true, cols = predicted):
disp(CM);
    41     1     0     0
     0    45     0     0
     0     0    52     0
     0     1     0    40
% ---- 7) (Optional) quick scatter to visualize test-set predictions
figure; hold on;
clsColors = lines(4); % 4 distinct colors
for c = 1:4
    pts = (yte == c);
    scatter(Xte(pts,1), Xte(pts,2), 25, clsColors(c,:), 'filled', ...
        'MarkerFaceAlpha', 0.7, 'DisplayName', sprintf('True C%d', c));
end
for c = 1:4
    pts = (yhat == c);
    % small black rings mark the predicted class; keep them out of the legend
    scatter(Xte(pts,1), Xte(pts,2), 10, 'k', 'o', 'HandleVisibility', 'off');
end
legend('Location', 'best'); % uses the DisplayName labels set above
title(sprintf('Gaussian NB on 4 classes (Test acc = %.1f%%)', 100*acc));
xlabel('z-score feature 1'); ylabel('z-score feature 2'); grid on; box on; hold off;
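If the Statistics and Machine Learning Toolbox is available, you can also sanity-check the hand-rolled functions against the built-in fitcnb, which defaults to a Gaussian ('normal') distribution per feature:

% Cross-check against MATLAB's built-in Naive Bayes classifier
Mdl = fitcnb(Xtr, ytr);        % per-feature Gaussian by default
yhatNB = predict(Mdl, Xte);
fprintf('fitcnb accuracy: %.2f%%\n', 100*mean(yhatNB == yte));

The two models should agree closely; small differences may come from the variance floor and variance-normalization choices above.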
You can refer to the following MATLAB documentation for more details:
  1. classification: https://www.mathworks.com/help/stats/classification.html
  2. confusionchart: https://in.mathworks.com/help/stats/confusionchart.html
