clear all;
clc;
% ---------------------------------------------------------------------
% Run configuration
% ---------------------------------------------------------------------
% true: train the network from scratch; false: load a pretrained network
% from the trainedModulationClassificationNetwork MAT-file.
trainNow = false;
% Frames generated per modulation type. The former train/load branches
% both assigned 100, so one assignment replaces the redundant if/else.
% NOTE(review): a from-scratch training run likely wants many more frames.
numFramesPerModType = 100;
% Percentages passed to splitData for the train/validation/test split.
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;
% Frame and sampling parameters.
sps = 8;                        % samples per symbol
spf = 1024;                     % samples per frame
symbolsPerFrame = spf / sps;
fs = 200e3;                     % sample rate (Hz)
fc = [902e6 100e6];             % center frequencies (Hz): [digital analog]
SNR = 30;                       % dB
% Noise standard deviation for unit signal power at this SNR. Named
% noiseStd (not std) so the built-in STD function is not shadowed.
noiseStd = sqrt(10.^(-SNR/10))
% ---------------------------------------------------------------------
% Channel impairment parameters (shared by all channel objects below)
% ---------------------------------------------------------------------
pathDelays = [0 1.8 3.4] / fs;  % seconds
averagePathGains = [0 -2 -10];  % dB
kFactor = 4;
maxDopplerShift = 4;            % Hz
maxDeltaOff = 5;                % maximum clock offset (ppm)
% Standalone impairment objects.
% NOTE(review): these are not referenced anywhere else in this script;
% the frame-generation loop uses the composite `channel` instead.
awgnChannel = comm.AWGNChannel(...
    'NoiseMethod', 'Signal to noise ratio (SNR)', ...
    'SignalPower', 1, ...
    'SNR', SNR);
multipathChannel = comm.RicianChannel(...
    'SampleRate', fs, ...
    'PathDelays', pathDelays, ...
    'AveragePathGains', averagePathGains, ...
    'KFactor', kFactor, ...
    'MaximumDopplerShift', maxDopplerShift);
% Random clock offset in [-maxDeltaOff, maxDeltaOff] ppm, converted to a
% frequency offset at the digital center frequency. This draw happens
% before rng(1235) below, so it is not covered by that seed.
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);
offset = -(C-1)*fc(1);
frequencyShifter = comm.PhaseFrequencyOffset(...
    'SampleRate', fs, ...
    'FrequencyOffset', offset);
% Composite test channel applied to every generated frame.
channel = helperModClassTestChannel(...
    'SampleRate', fs, ...
    'SNR', SNR, ...
    'PathDelays', pathDelays, ...
    'AveragePathGains', averagePathGains, ...
    'KFactor', kFactor, ...
    'MaximumDopplerShift', maxDopplerShift, ...
    'MaximumClockOffset', maxDeltaOff, ...
    'CenterFrequency', fc(1))
chInfo = info(channel);
% Fix the random seed so frame generation is reproducible.
rng(1235)
% ---------------------------------------------------------------------
% Frame generation: for every modulation type, generate
% numFramesPerModType frames, pass them through `channel`, and store them.
% ---------------------------------------------------------------------
tic
modulationTypes = categorical(["BPSK", "QPSK", "8PSK", ...
    "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
    "B-FM", "DSB-AM", "SSB-AM"]);
numModulationTypes = length(modulationTypes);
% Analog types are transmitted at fc(2); all other types at fc(1).
analogTypes = ["B-FM", "DSB-AM", "SSB-AM"];
frameStore = helperModClassFrameStore(...
    numFramesPerModType*numModulationTypes,spf,modulationTypes);
% Passed to helperModClassFrameGenerator as its transient-delay argument.
% NOTE(review): presumably the number of leading samples discarded to skip
% filter/channel transients -- confirm against the helper.
transDelay = 50;
for modType = 1:numModulationTypes
    label = modulationTypes(modType);
    fprintf('%s - Generating %s frames\n', ...
        datestr(toc/86400,'HH:MM:SS'), label)
    % Each source call yields 2*spf samples (before modulation) so that a
    % full spf-sample frame can still be cut after transient removal.
    dataSrc = getSource(label, sps, 2*spf, fs);
    modulator = getModulator(label, sps, fs);
    if ismember(string(label), analogTypes)
        channel.CenterFrequency = fc(2);
    else
        channel.CenterFrequency = fc(1);
    end
    for p = 1:numFramesPerModType
        x = dataSrc();
        y = modulator(x);
        rxSamples = channel(y);
        frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
        add(frameStore, frame, label);
    end
end
% ---------------------------------------------------------------------
% Split the stored frames, extract arrays/labels, and visualize them.
% ---------------------------------------------------------------------
splitPercents = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[mcfsTraining,mcfsValidation,mcfsTest] = splitData(frameStore, splitPercents);
[rxTraining,rxTrainingLabel] = get(mcfsTraining);
[rxValidation,rxValidationLabel] = get(mcfsValidation);
[rxTest,rxTestLabel] = get(mcfsTest);
% Echo the training array dimensions to the command window.
size(rxTraining)
plotTimeDomain(rxTest,rxTestLabel,modulationTypes,fs)
plotSpectrogram(rxTest,rxTestLabel,modulationTypes,fs,sps)
% One label histogram per split, stacked vertically.
figure
splitLabelSets = {rxTrainingLabel, rxValidationLabel, rxTestLabel};
splitTitleText = ["Training Label Distribution", ...
    "Validation Label Distribution", ...
    "Test Label Distribution"];
for splitIdx = 1:3
    subplot(3,1,splitIdx)
    histogram(splitLabelSets{splitIdx})
    title(splitTitleText(splitIdx))
end
% Drop the temporaries so the workspace matches a straight-line version.
clear splitPercents splitLabelSets splitTitleText splitIdx
% ---------------------------------------------------------------------
% CNN definition and training options
% ---------------------------------------------------------------------
% NOTE(review): dropoutRate is not used by any layer below.
dropoutRate = 0.5;
numModTypes = numel(modulationTypes);
netWidth = 1;          % multiplier applied to every conv layer's filter count
filterSize = [1 sps];
poolSize = [1 2];
% Six conv/BN/ReLU stages; the first five are followed by a [1 2]
% stride-[1 2] max pool, so the frame width shrinks by 2^5 = 32 before
% the average pool, which spans the remaining width.
modClassNet = [
    imageInputLayer([2 spf 1], 'Normalization', 'none', 'Name', 'Input Layer')
    convolution2dLayer(filterSize, 16*netWidth, 'Padding', 'same', 'Name', 'CNN1')
    batchNormalizationLayer('Name', 'BN1')
    reluLayer('Name', 'ReLU1')
    maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool1')
    convolution2dLayer(filterSize, 24*netWidth, 'Padding', 'same', 'Name', 'CNN2')
    batchNormalizationLayer('Name', 'BN2')
    reluLayer('Name', 'ReLU2')
    maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool2')
    convolution2dLayer(filterSize, 32*netWidth, 'Padding', 'same', 'Name', 'CNN3')
    batchNormalizationLayer('Name', 'BN3')
    reluLayer('Name', 'ReLU3')
    maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool3')
    convolution2dLayer(filterSize, 48*netWidth, 'Padding', 'same', 'Name', 'CNN4')
    batchNormalizationLayer('Name', 'BN4')
    reluLayer('Name', 'ReLU4')
    maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool4')
    convolution2dLayer(filterSize, 64*netWidth, 'Padding', 'same', 'Name', 'CNN5')
    batchNormalizationLayer('Name', 'BN5')
    reluLayer('Name', 'ReLU5')
    maxPooling2dLayer(poolSize, 'Stride', [1 2], 'Name', 'MaxPool5')
    convolution2dLayer(filterSize, 96*netWidth, 'Padding', 'same', 'Name', 'CNN6')
    batchNormalizationLayer('Name', 'BN6')
    reluLayer('Name', 'ReLU6')
    averagePooling2dLayer([1 ceil(spf/32)], 'Name', 'AP1')
    fullyConnectedLayer(numModTypes, 'Name', 'FC1')
    softmaxLayer('Name', 'SoftMax')
    classificationLayer('Name', 'Output') ]
maxEpochs = 12;
miniBatchSize = 256;
% Validate roughly once per epoch (iterations per epoch). Clamp to 1 so a
% training set smaller than one mini-batch still yields a valid
% (positive) ValidationFrequency instead of 0.
validationFrequency = max(1, floor(numel(rxTrainingLabel)/miniBatchSize));
% 'auto' trains on a GPU when one is available and falls back to the CPU
% otherwise; a hard-coded 'gpu' fails on machines without a GPU.
options = trainingOptions('sgdm', ...
    'InitialLearnRate',2e-2, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{rxValidation,rxValidationLabel}, ...
    'ValidationFrequency',validationFrequency, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropPeriod', 9, ...
    'LearnRateDropFactor', 0.1, ...
    'ExecutionEnvironment', 'auto');
% ---------------------------------------------------------------------
% Train (or load) the network, then evaluate it on the test split.
% ---------------------------------------------------------------------
if trainNow
    fprintf('%s - Training the network\n', datestr(toc/86400,'HH:MM:SS'))
    trainedNet = trainNetwork(rxTraining,rxTrainingLabel,modClassNet,options);
else
    % Functional form: load only trainedNet, so other variables in the
    % MAT-file cannot silently overwrite workspace variables.
    pretrained = load('trainedModulationClassificationNetwork', 'trainedNet');
    trainedNet = pretrained.trainedNet;
end
fprintf('%s - Classifying test frames\n', datestr(toc/86400,'HH:MM:SS'))
rxTestPred = classify(trainedNet,rxTest);
testAccuracy = mean(rxTestPred == rxTestLabel);
disp("Test accuracy: " + testAccuracy*100 + "%")
figure
cm = confusionchart(rxTestLabel, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 740 424];
function modulator = getModulator(modType, sps, fs)
%GETMODULATOR Return a one-argument modulator handle for a modulation type.
%   MODULATOR = GETMODULATOR(MODTYPE,SPS,FS) returns a function handle
%   @(x) ... that modulates one block of source data. Digital types are
%   parameterized by the samples-per-symbol SPS; the analog types
%   ("B-FM", "DSB-AM", "SSB-AM") by the sample rate FS.
%
%   MODTYPE is a string or categorical scalar (the main script passes a
%   categorical). An unrecognized MODTYPE raises the error
%   'modclass:unknownModulation' instead of leaving MODULATOR unassigned.
switch modType
    case "BPSK"
        modulator = @(x)bpskModulator(x,sps);
    case "QPSK"
        modulator = @(x)qpskModulator(x,sps);
    case "8PSK"
        modulator = @(x)psk8Modulator(x,sps);
    case "16QAM"
        modulator = @(x)qam16Modulator(x,sps);
    case "64QAM"
        modulator = @(x)qam64Modulator(x,sps);
    case "GFSK"
        modulator = @(x)gfskModulator(x,sps);
    case "CPFSK"
        modulator = @(x)cpfskModulator(x,sps);
    case "PAM4"
        modulator = @(x)pam4Modulator(x,sps);
    case "B-FM"
        modulator = @(x)bfmModulator(x, fs);
    case "DSB-AM"
        modulator = @(x)dsbamModulator(x, fs);
    case "SSB-AM"
        modulator = @(x)ssbamModulator(x, fs);
    otherwise
        error('modclass:unknownModulation', ...
            'Unknown modulation type "%s".', char(string(modType)));
end
end
function src = getSource(modType, sps, spf, fs)
%GETSOURCE Return a zero-argument data-source handle for a modulation type.
%   SRC = GETSOURCE(MODTYPE,SPS,SPF,FS) returns a function handle @() ...
%   that produces one block of source data per call:
%     - digital types: a column of SPF/SPS random integers in [0, M-1],
%       where M is the modulation order for MODTYPE;
%     - analog types ("B-FM", "DSB-AM", "SSB-AM"): SPF audio samples at
%       sample rate FS, from getAudio.
%
%   MODTYPE is a string or categorical scalar. An unrecognized MODTYPE
%   raises the error 'modclass:unknownModulation'.
switch modType
    case {"BPSK","GFSK","CPFSK"}
        M = 2;
    case {"QPSK","PAM4"}
        M = 4;
    case "8PSK"
        M = 8;
    case "16QAM"
        M = 16;
    case "64QAM"
        M = 64;
    case {"B-FM","DSB-AM","SSB-AM"}
        src = @()getAudio(spf,fs);
        return
    otherwise
        error('modclass:unknownModulation', ...
            'Unknown modulation type "%s".', char(string(modType)));
end
% All digital types share the same uniform-integer generator; only M differs.
src = @()randi([0 M-1],spf/sps,1);
end
function x = getAudio(spf,fs)
%GETAUDIO Return SPF audio samples resampled to sample rate FS.
%   X = GETAUDIO(SPF,FS) reads the next block from 'audio_mix_441.wav'
%   (looped indefinitely via PlayCount = inf), resamples it to FS with a
%   dsp.SampleRateConverter (Bandwidth 30e3), and returns the first SPF
%   samples of channel 1 as a column.
%
%   The reader and converter are kept in persistent variables, so
%   successive calls continue through the file. They are rebuilt (and the
%   old reader released) whenever SPF or FS differs from the values they
%   were built for; a rebuild restarts reading at the start of the file.
persistent audioSrc audioRC cachedSpf cachedFs
if isempty(audioSrc) || cachedSpf ~= spf || cachedFs ~= fs
    if ~isempty(audioSrc)
        release(audioSrc);
    end
    audioSrc = dsp.AudioFileReader('audio_mix_441.wav',...
        'SamplesPerFrame',spf,'PlayCount',inf);
    audioRC = dsp.SampleRateConverter('Bandwidth',30e3,...
        'InputSampleRate',audioSrc.SampleRate,...
        'OutputSampleRate',fs);
    % Read enough input samples to yield at least SPF output samples,
    % rounded up to a multiple of the converter's decimation factor.
    [~,decimFactor] = getRateChangeFactors(audioRC);
    audioSrc.SamplesPerFrame = ceil(spf / fs * audioSrc.SampleRate / decimFactor) * decimFactor;
    cachedSpf = spf;
    cachedFs = fs;
end
x = audioRC(audioSrc());
x = x(1:spf,1);
end
function frames = getNNFrames(rx,modType)
%GETNNFRAMES Cut received samples into frames and return them as stored data.
%   FRAMES = GETNNFRAMES(RX,MODTYPE) runs helperModClassFrameGenerator on
%   RX with arguments (1024, 1024, 32, 8), which fill the same positions
%   as (spf, spf, transDelay, sps) in the main script's call. It adds the
%   result under MODTYPE to a helperModClassFrameStore (capacity 10,
%   frame length 1024) and returns the store's GET output.
%
%   NOTE(review): this function is not called elsewhere in this file.
frameLen = 1024;
generated = helperModClassFrameGenerator(rx, frameLen, frameLen, 32, 8);
store = helperModClassFrameStore(10, frameLen, categorical({modType}));
add(store, generated, modType);
frames = get(store);
end
function plotScores(score,labels)
%PLOTSCORES Stacked bar chart of per-frame classification scores.
%   PLOTSCORES(SCORE,LABELS) opens a new figure and draws one stacked bar
%   per row of SCORE (one column per class). LABELS must be categorical;
%   its categories become the legend entries. Works for any number of
%   classes: the fixed 11-color palette below is cycled by ColorOrder
%   when there are more columns than colors.
co = [0.08 0.9 0.49;
      0.52 0.95 0.70;
      0.36 0.53 0.96;
      0.09 0.54 0.67;
      0.48 0.99 0.26;
      0.95 0.31 0.17;
      0.52 0.85 0.95;
      0.08 0.72 0.88;
      0.12 0.45 0.69;
      0.22 0.11 0.49;
      0.65 0.54 0.71];
numClasses = size(score, 2);
figure; ax = axes('ColorOrder',co,'NextPlot','replacechildren');
% Two all-NaN rows add two empty bar positions after the data
% (presumably to leave room for the legend). Their width must match
% SCORE's column count; the former fixed nan(2,11) only worked for 11.
bar(ax,[score; nan(2,numClasses)],'stacked'); legend(categories(labels),'Location','best');
xlabel('Frame Number'); ylabel('Score'); title('Classification Scores')
end
function plotTimeDomain(rxTest,rxTestLabel,modulationTypes,fs)
%PLOTTIMEDOMAIN Plot I and Q of one test frame per modulation type.
%   PLOTTIMEDOMAIN(RXTEST,RXTESTLABEL,MODULATIONTYPES,FS) fills the
%   current figure with a four-column subplot grid. For each entry of
%   MODULATIONTYPES it takes the first frame (4th index of RXTEST) whose
%   label matches, then overlays the I and Q traces against time in ms,
%   using FS as the sample rate.
%
%   When RXTEST has exactly two rows, I and Q are rows 1 and 2;
%   otherwise they are pages 1 and 2 of the third dimension.
iqInRows = size(rxTest,1) == 2;
frameLen = size(rxTest,2);
tMs = 1000*(0:frameLen-1)/fs;
nTypes = length(modulationTypes);
gridRows = ceil(nTypes / 4);
for k = 1:nTypes
    subplot(gridRows, 4, k);
    thisType = modulationTypes(k);
    frameIdx = find(rxTestLabel == thisType, 1);
    % I always comes from row 1, page 1; only Q's location depends on layout.
    inPhase = rxTest(1,:,1,frameIdx);
    if iqInRows
        quadrature = rxTest(2,:,1,frameIdx);
    else
        quadrature = rxTest(1,:,2,frameIdx);
    end
    for trace = {inPhase, quadrature}
        plot(tMs, squeeze(trace{1}), '-'); grid on; axis equal; axis square
        hold on
    end
    hold off
    title(string(thisType));
    xlabel('Time (ms)'); ylabel('Amplitude')
end
end
function plotSpectrogram(rxTest,rxTestLabel,modulationTypes,fs,sps)
%PLOTSPECTROGRAM Show a spectrogram of one test frame per modulation type.
%   PLOTSPECTROGRAM(RXTEST,RXTESTLABEL,MODULATIONTYPES,FS,SPS) fills the
%   current figure with a four-column subplot grid. For each entry of
%   MODULATIONTYPES, the first frame whose label matches is combined into
%   I + 1i*Q and passed to SPECTROGRAM (Kaiser window of length SPS, no
%   overlap, 1024 FFT points, sample rate FS, centered frequency axis).
%   All colorbars in the figure are deleted afterwards.
%
%   When RXTEST has exactly two rows, I and Q are rows 1 and 2;
%   otherwise they are pages 1 and 2 of the third dimension.
iqInRows = size(rxTest,1) == 2;
nTypes = length(modulationTypes);
gridRows = ceil(nTypes / 4);
for k = 1:nTypes
    subplot(gridRows, 4, k);
    thisType = modulationTypes(k);
    frameIdx = find(rxTestLabel == thisType, 1);
    inPhase = rxTest(1,:,1,frameIdx);
    if iqInRows
        quadrature = rxTest(2,:,1,frameIdx);
    else
        quadrature = rxTest(1,:,2,frameIdx);
    end
    complexFrame = squeeze(inPhase) + 1i*squeeze(quadrature);
    spectrogram(complexFrame, kaiser(sps), 0, 1024, fs, 'centered');
    title(string(thisType));
end
fig = gcf;
delete(findall(fig.Children, 'Type', 'ColorBar'))
end
function flag = isPlutoSDRInstalled
%ISPLUTOSDRINSTALLED True if an installed support package mentions ADALM-PLUTO.
%   FLAG is logical false when no support packages are installed;
%   otherwise it is true if any installed package Name contains
%   'ADALM-PLUTO' (case-insensitive).
installed = matlabshared.supportpkg.getInstalled;
if isempty(installed)
    flag = false;
    return
end
flag = any(contains({installed.Name}, 'ADALM-PLUTO', 'IgnoreCase', true));
end
function flag = isUSRPInstalled
%ISUSRPINSTALLED True if an installed support package mentions USRP.
%   FLAG is logical false when no support packages are installed;
%   otherwise it is true if any installed package Name contains 'USRP'
%   (case-insensitive).
installed = matlabshared.supportpkg.getInstalled;
flag = false;
if ~isempty(installed)
    flag = any(contains({installed.Name}, 'USRP', 'IgnoreCase', true));
end
end
function y = bpskModulator(x,sps)
%BPSKMODULATOR BPSK-map integer symbols and apply root-raised-cosine shaping.
%   Y = BPSKMODULATOR(X,SPS) maps X (integers 0..1) with PSKMOD(X,2),
%   upsamples by SPS, and filters with RCOSDESIGN(0.35,4,SPS) (default
%   'sqrt' shape, span 4 symbols). numel(Y) == numel(X)*SPS.
%
%   The filter taps are cached in a persistent variable and recomputed
%   whenever SPS differs from the value they were built for (previously
%   the first SPS was silently reused for every later call).
persistent filterCoeffs cachedSps
if isempty(filterCoeffs) || cachedSps ~= sps
    filterCoeffs = rcosdesign(0.35, 4, sps);
    cachedSps = sps;
end
symbols = pskmod(x,2);
y = filter(filterCoeffs, 1, upsample(symbols,sps));
end
function y = qpskModulator(x,sps)
%QPSKMODULATOR QPSK-map integer symbols and apply root-raised-cosine shaping.
%   Y = QPSKMODULATOR(X,SPS) maps X (integers 0..3) with
%   PSKMOD(X,4,pi/4), upsamples by SPS, and filters with
%   RCOSDESIGN(0.35,4,SPS). numel(Y) == numel(X)*SPS.
%
%   The filter taps are cached in a persistent variable and recomputed
%   whenever SPS differs from the value they were built for.
persistent filterCoeffs cachedSps
if isempty(filterCoeffs) || cachedSps ~= sps
    filterCoeffs = rcosdesign(0.35, 4, sps);
    cachedSps = sps;
end
symbols = pskmod(x,4,pi/4);
y = filter(filterCoeffs, 1, upsample(symbols,sps));
end
function y = psk8Modulator(x,sps)
%PSK8MODULATOR 8-PSK-map integer symbols and apply root-raised-cosine shaping.
%   Y = PSK8MODULATOR(X,SPS) maps X (integers 0..7) with PSKMOD(X,8),
%   upsamples by SPS, and filters with RCOSDESIGN(0.35,4,SPS).
%   numel(Y) == numel(X)*SPS.
%
%   The filter taps are cached in a persistent variable and recomputed
%   whenever SPS differs from the value they were built for.
persistent filterCoeffs cachedSps
if isempty(filterCoeffs) || cachedSps ~= sps
    filterCoeffs = rcosdesign(0.35, 4, sps);
    cachedSps = sps;
end
symbols = pskmod(x,8);
y = filter(filterCoeffs, 1, upsample(symbols,sps));
end
function y = qam16Modulator(x,sps)
%QAM16MODULATOR 16-QAM-map integer symbols and apply root-raised-cosine shaping.
%   Y = QAM16MODULATOR(X,SPS) maps X (integers 0..15) with QAMMOD at unit
%   average power, upsamples by SPS, and filters with
%   RCOSDESIGN(0.35,4,SPS). numel(Y) == numel(X)*SPS.
%
%   The filter taps are cached in a persistent variable and recomputed
%   whenever SPS differs from the value they were built for.
persistent filterCoeffs cachedSps
if isempty(filterCoeffs) || cachedSps ~= sps
    filterCoeffs = rcosdesign(0.35, 4, sps);
    cachedSps = sps;
end
symbols = qammod(x,16,'UnitAveragePower',true);
y = filter(filterCoeffs, 1, upsample(symbols,sps));
end
function y = qam64Modulator(x,sps)
%QAM64MODULATOR 64-QAM-map integer symbols and apply root-raised-cosine shaping.
%   Y = QAM64MODULATOR(X,SPS) maps X (integers 0..63) with QAMMOD at unit
%   average power, upsamples by SPS, and filters with
%   RCOSDESIGN(0.35,4,SPS). numel(Y) == numel(X)*SPS.
%
%   The filter taps are cached in a persistent variable and recomputed
%   whenever SPS differs from the value they were built for.
persistent filterCoeffs cachedSps
if isempty(filterCoeffs) || cachedSps ~= sps
    filterCoeffs = rcosdesign(0.35, 4, sps);
    cachedSps = sps;
end
symbols = qammod(x,64,'UnitAveragePower',true);
y = filter(filterCoeffs, 1, upsample(symbols,sps));
end
function y = pam4Modulator(x,sps)
%PAM4MODULATOR 4-PAM-map integer symbols and apply root-raised-cosine shaping.
%   Y = PAM4MODULATOR(X,SPS) maps X (integers 0..3) with PAMMOD(X,4),
%   scales the symbols to unit average power, upsamples by SPS, and
%   filters with RCOSDESIGN(0.35,4,SPS). numel(Y) == numel(X)*SPS.
%
%   The amplitude scale does not depend on SPS and is computed once. The
%   filter taps are recomputed whenever SPS differs from the cached value.
persistent filterCoeffs cachedSps amp
if isempty(amp)
    % 1/RMS of the full PAM4 constellation -> unit average symbol power.
    amp = 1 / sqrt(mean(abs(pammod(0:3, 4)).^2));
end
if isempty(filterCoeffs) || cachedSps ~= sps
    filterCoeffs = rcosdesign(0.35, 4, sps);
    cachedSps = sps;
end
symbols = amp * pammod(x,4);
y = filter(filterCoeffs, 1, upsample(symbols,sps));
end
function y = gfskModulator(x,sps)
%GFSKMODULATOR Binary GFSK modulation with comm.CPMModulator.
%   Y = GFSKMODULATOR(X,SPS) maps X (integers 0..1) to -1/+1 and modulates
%   with a comm.CPMModulator (ModulationOrder 2, Gaussian frequency pulse,
%   BT = 0.35, ModulationIndex 1, SamplesPerSymbol SPS).
%
%   The modulator object is persistent, so consecutive calls with the same
%   SPS continue from the previous call's state. It is rebuilt (state
%   reset) when SPS changes; previously the first SPS was locked in.
%   The object is named cpmMod rather than mod so that the built-in MOD
%   function is not shadowed.
persistent cpmMod cachedSps meanM
if isempty(cpmMod) || cachedSps ~= sps
    M = 2;
    cpmMod = comm.CPMModulator(...
        'ModulationOrder', M, ...
        'FrequencyPulse', 'Gaussian', ...
        'BandwidthTimeProduct', 0.35, ...
        'ModulationIndex', 1, ...
        'SamplesPerSymbol', sps);
    meanM = mean(0:M-1);
    cachedSps = sps;
end
% With M = 2, meanM = 0.5, so 2*(x - meanM) maps {0,1} to {-1,+1}.
y = cpmMod(2*(x-meanM));
end
function y = cpfskModulator(x,sps)
%CPFSKMODULATOR Binary CPFSK modulation with comm.CPFSKModulator.
%   Y = CPFSKMODULATOR(X,SPS) maps X (integers 0..1) to -1/+1 and
%   modulates with a comm.CPFSKModulator (ModulationOrder 2,
%   ModulationIndex 0.5, SamplesPerSymbol SPS).
%
%   The modulator object is persistent, so consecutive calls with the same
%   SPS continue from the previous call's state. It is rebuilt (state
%   reset) when SPS changes; previously the first SPS was locked in.
%   The object is named cpfskMod rather than mod so that the built-in MOD
%   function is not shadowed.
persistent cpfskMod cachedSps meanM
if isempty(cpfskMod) || cachedSps ~= sps
    M = 2;
    cpfskMod = comm.CPFSKModulator(...
        'ModulationOrder', M, ...
        'ModulationIndex', 0.5, ...
        'SamplesPerSymbol', sps);
    meanM = mean(0:M-1);
    cachedSps = sps;
end
% With M = 2, meanM = 0.5, so 2*(x - meanM) maps {0,1} to {-1,+1}.
y = cpfskMod(2*(x-meanM));
end
function y = bfmModulator(x,fs)
%BFMMODULATOR Broadcast FM modulation with comm.FMBroadcastModulator.
%   Y = BFMMODULATOR(X,FS) modulates audio X with a
%   comm.FMBroadcastModulator whose AudioSampleRate and SampleRate are
%   both FS (all other properties at their defaults).
%
%   The modulator object is persistent, so consecutive calls with the same
%   FS continue from the previous call's state. It is rebuilt (state
%   reset) when FS changes; previously the first FS was locked in. The
%   object is named fmMod rather than mod so that the built-in MOD
%   function is not shadowed.
persistent fmMod cachedFs
if isempty(fmMod) || cachedFs ~= fs
    fmMod = comm.FMBroadcastModulator(...
        'AudioSampleRate', fs, ...
        'SampleRate', fs);
    cachedFs = fs;
end
y = fmMod(x);
end
function y = dsbamModulator(x,fs)
%DSBAMMODULATOR Amplitude-modulate X onto a 50 kHz carrier using AMMOD.
%   Y = DSBAMMODULATOR(X,FS) calls AMMOD(X, 50e3, FS) with AMMOD's
%   remaining arguments left at their defaults. FS is the sample rate of X.
carrierHz = 50e3;
y = ammod(x, carrierHz, fs);
end
function y = ssbamModulator(x,fs)
%SSBAMMODULATOR Single-sideband AM of X onto a 50 kHz carrier using SSBMOD.
%   Y = SSBAMMODULATOR(X,FS) calls SSBMOD(X, 50e3, FS) with SSBMOD's
%   remaining arguments left at their defaults. FS is the sample rate of X.
carrierHz = 50e3;
y = ssbmod(x, carrierHz, fs);
end