Apply Transfer Learning on PyTorch Model to Identify 5G and LTE Signals

Since R2025a

This example shows how to use a pretrained PyTorch® semantic segmentation neural network to identify 5G NR and LTE signals in MATLAB®. MATLAB generates and preprocesses standards-based wireless signals for training and testing.

Introduction

Computer vision uses the semantic segmentation technique to identify objects and their locations in an image or a video. In wireless signal processing, the objects of interest are wireless signals, and the locations of the objects are the frequency and time occupied by the signals. In this example, you apply the semantic segmentation technique to wireless signals to identify spectral content in a wideband spectrogram.

In this example, you:

  1. Set up your Python® environment to call PyTorch functions from MATLAB.

  2. Load a pretrained PyTorch network into MATLAB.

  3. Generate, augment, and preprocess 5G NR and LTE signals online to fine-tune the PyTorch model weights.

  4. Evaluate and visualize the network outputs in MATLAB.

Associated AI for Wireless Examples

You can use this example as part of a complete deep learning workflow:

Capture and label bandwidths

Train Network with labeled bandwidths

The following examples show how to train a semantic segmentation network to identify 5G NR and LTE signals in a wideband spectrogram.

Identify 5G and LTE signals using SDR

For more information on implementing AI workflows for wireless communications, see AI for Wireless.

Set Up Python Environment

Before running this example, set up the environment as explained in PyTorch Coexecution. Specify the full path of the Python executable to use in the executablePath field. The helperSetupPyenv function sets the Python environment in MATLAB according to the selected executionMode and checks that the libraries listed in the requirements_specsense.txt file are installed.

executionMode = "OutOfProcess";
reqFileName = "requirements_specsense.txt";

if ispc
    executablePath = ".venv\Scripts\pythonw.exe";
else
    executablePath = "venv/bin/python";
end

currentPyenv = helperSetupPyenv(executablePath,executionMode,reqFileName);
Setting up Python environment
Parsing requirements_specsense.txt 
Checking required package 'torch'
Checking required package 'torchvision'
Required Python libraries are installed.
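
On the Python side, a library check of this kind amounts to parsing the requirements file and confirming that each listed package can be located. The following is a minimal sketch of that idea, not the implementation of helperSetupPyenv or its Python helpers. The function names parse_requirements, is_installed, and find_missing are hypothetical, the file contents are illustrative, and the sketch assumes that each package's import name matches its requirement name.

```python
import importlib.util
import re


def parse_requirements(text):
    """Return package names from requirements-style text, ignoring comments and version pins."""
    names = []
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        # Keep only the leading name, dropping specifiers such as ==2.1 or >=0.15.
        match = re.match(r"[A-Za-z0-9_.\-]+", line)
        if match:
            names.append(match.group(0))
    return names


def is_installed(name):
    """Report whether a module with this name can be located without importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def find_missing(names):
    """Return the names that cannot be located as importable modules."""
    return [name for name in names if not is_installed(name)]


# Illustrative file contents (assumption: the real file may pin versions differently).
example = """
# hypothetical requirements file contents
torch>=2.0
torchvision
"""
required = parse_requirements(example)
print("Required:", required)
print("Missing:", find_missing(required))
```

An empty list from find_missing corresponds to the "Required Python libraries are installed." message shown above.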

Deeplabv3 Semantic Segmentation Network

This example uses the pretrained deeplabv3 semantic segmentation network and modifies it to set the number of output classes to three: "Noise", "NR", and "LTE". Set the network backbone to resnet50 or resnet101.

classNames = ["Noise" "NR" "LTE"];
networkBackbone = "resnet50";

The deeplabv3nn.py module is a helper Python module that contains all the functions required to train and test the PyTorch neural network for this example. The PyTorch Wrapper Template section shows how to create the interface functions using a template. Use the info function to display summary information about the model.

numClasses = length(classNames);
model_size = py.deeplabv3nnwrapper.info(networkBackbone,numClasses);
numLayers = single(model_size{1});
numLearnables = single(model_size{2})/1e6;

disp("Deeplabv3 model with "+networkBackbone+...
    " backbone has "+num2str(numLayers)+" layers and "+...
    num2str(numLearnables)+" M learnables.");
Deeplabv3 model with resnet50 backbone has 189 layers and 41.9948 M learnables.

Specify Data Set Parameters

This example follows the steps in the Generate Training Data section of the Spectrum Sensing with Deep Learning to Identify 5G, LTE, and WLAN Signals example to generate a data set of 5G NR and LTE frames. Each sample of the data set is a frame that contains a 5G NR signal, an LTE signal, or a combination of both, generated with values selected at random from the standards parameter set specified below. Each frame passes through a fading channel model with randomly selected SNR and Doppler shift.

params.Fs = 61.44e6;          % Hz
params.NumSubFrames = 40;     % corresponds to 40 ms

% 5G NR Parameters
params.SCSVec = [15 30];
params.BandwidthVec = [10:5:30 40 50];     % [5:5:25 30:10:100]
params.MaxTimeShift = params.NumSubFrames; % Time shift in milliseconds
params.SSBPeriod = 20;                     % [5 10 20 40 80 160] 20 is most frequently found OTA

% LTE Parameters
params.RCVec = ["R.2","R.6","R.8","R.9"];
params.TrBlkOffVec = [1,2,3,4,5,6,7,8];

% Channel Parameters
params.SNRMin = 0;        % dB
params.SNRMax = 40;       % dB
params.DopplerMin = 0;    % Hz
params.DopplerMax = 500;  % Hz
params.CarrierFreq = 4e9; % Hz

To train a semantic segmentation model for spectrum sensing, generate the power spectral density (PSD) spectrogram of the impaired signal and convert it into a three-channel image. Each pixel in the spectrogram image belongs to one of three classes, that is, Noise, NR, or LTE. An advantageous loss function for handling pixel class imbalance when training semantic segmentation networks is the Dice loss function.

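The BCE-Dice loss between the network output, Y, and a pixel label target, T, is defined as

$$\text{loss} = \sum_{i=1}^{N} \left[ \text{BCE}(Y_i, T_i) + \text{diceloss}(Y_i, T_i) \right],$$

where

$$\text{BCE}(Y_i, T_i) = -\left( T_i \log(Y_i) + (1 - T_i) \log(1 - Y_i) \right)$$

$$\text{diceloss}(Y_i, T_i) = 1 - \frac{2 \sum (Y_i \odot T_i) + 1}{\sum Y_i + \sum T_i + 1}$$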
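
As an illustration of how the two terms behave, the following sketch evaluates a BCE-Dice loss on flattened masks. It is not the loss code from the example's Python module. It uses the standard form of binary cross-entropy, with the logarithm applied to the network output, and it makes two assumptions that the text does not state: BCE is averaged over the pixels of each sample, and predictions are clamped by a small eps to keep the logarithm finite.

```python
import math


def bce(y, t, eps=1e-7):
    """Mean binary cross-entropy between predicted probabilities y and binary targets t."""
    total = 0.0
    for yi, ti in zip(y, t):
        yi = min(max(yi, eps), 1.0 - eps)  # assumption: clamp to keep log() finite
        total += -(ti * math.log(yi) + (1.0 - ti) * math.log(1.0 - yi))
    return total / len(y)


def dice_loss(y, t, smooth=1.0):
    """Dice loss with a +1 smoothing term in the numerator and denominator."""
    intersection = sum(yi * ti for yi, ti in zip(y, t))
    return 1.0 - (2.0 * intersection + smooth) / (sum(y) + sum(t) + smooth)


def bce_dice_loss(batch_y, batch_t):
    """Sum of BCE + Dice over the N (prediction, target) pairs in the batch."""
    return sum(bce(y, t) + dice_loss(y, t) for y, t in zip(batch_y, batch_t))


target = [1.0, 0.0, 1.0, 0.0]  # flattened binary mask for one class
good = [0.9, 0.1, 0.8, 0.2]    # prediction close to the target
bad = [0.2, 0.8, 0.3, 0.7]     # prediction far from the target
print(bce_dice_loss([good], [target]), bce_dice_loss([bad], [target]))
```

The Dice term depends on the overlap between prediction and target rather than on the total pixel count, which is why it helps when one class, such as Noise, dominates the image.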

The helperSpecSenseGenerateImagesMasks function generates a spectrogram image and the corresponding pixel mask as a binary 3-D array, where the first two dimensions are equal to the height and width of the spectrogram image and the third dimension is the number of classes in the data set. Specify the spectrogram image parameters.

% Spectrogram Parameters
params.NOverlap = 10;
params.Window = hann(256);
params.Nfft = 4069;

% Image and mask parameters
params.ImageSize = [256 256];    % pixels
params.ClassNames = classNames;
params.NumStandards  = 2;
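
A binary height-by-width-by-classes mask of the kind described above is a one-hot encoding of a per-pixel label map. The following sketch shows that relationship on a toy 2-by-3 image. The function name one_hot_mask is hypothetical, and the sketch assumes that every pixel belongs to exactly one class, in the order given by classNames.

```python
CLASS_NAMES = ["Noise", "NR", "LTE"]  # same order as classNames in this example


def one_hot_mask(label_map, num_classes):
    """Convert an H-by-W map of class indices into an H-by-W-by-C binary mask."""
    return [
        [[1 if label == c else 0 for c in range(num_classes)] for label in row]
        for row in label_map
    ]


# Toy 2-by-3 "spectrogram" labels: 0 = Noise, 1 = NR, 2 = LTE
labels = [
    [0, 1, 1],
    [0, 2, 2],
]
mask = one_hot_mask(labels, len(CLASS_NAMES))
print(mask[0][1])  # one-hot vector for the NR pixel at row 0, column 1
```

For the NR pixel, only the second of the three class channels is set.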

Train Deep Neural Network

In this section, you train the PyTorch Deeplabv3 model using the Python interface with run-time data generation and preprocessing in the training loop. During each training iteration, you:

  • Generate a set of 5G and LTE signals.

  • Augment each 5G/LTE signal with multiple fading channel realizations to create a minibatch of training samples.

  • Preprocess the augmented time-domain signals.

  • Train the model for one iteration.

Data augmentation is a technique that increases the size of a training data set by adding random impairments to existing signals, without generating or collecting new data signals. Augmenting the training data often improves network generalization. In this example, the impairmentsPerSignal parameter specifies how many channel-impaired signals the example creates from each generated standards-compliant frame. The default value of impairmentsPerSignal is 2, that is, each 5G NR/LTE frame is impaired with two randomly generated fading channel realizations. This figure shows the augmented data samples resulting from one 5G NR signal and one LTE signal.

[Figure: augmented data samples resulting from one 5G NR signal and one LTE signal (DataAugmentation4.png)]

impairmentsPerSignal = 2;
miniBatch = 30;
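
The following sketch shows how these two settings could translate into the composition of one minibatch. It is an illustration only, not the logic of helperSpecSenseGenerateSignals. It assumes that the minibatch size equals the number of generated signals times impairmentsPerSignal, and that SNR and Doppler shift are drawn uniformly from the ranges in the channel parameters. The function name augmentation_plan is hypothetical.

```python
import random

MINI_BATCH = 30
IMPAIRMENTS_PER_SIGNAL = 2
SNR_RANGE_DB = (0.0, 40.0)       # params.SNRMin, params.SNRMax
DOPPLER_RANGE_HZ = (0.0, 500.0)  # params.DopplerMin, params.DopplerMax


def augmentation_plan(mini_batch, impairments_per_signal, rng):
    """List one (signal index, SNR, Doppler) entry per training sample in a minibatch."""
    num_signals = mini_batch // impairments_per_signal
    plan = []
    for signal_idx in range(num_signals):
        for _ in range(impairments_per_signal):
            plan.append({
                "signal": signal_idx,
                # Assumption: uniform draws; the example's helper may sample differently.
                "snr_db": rng.uniform(*SNR_RANGE_DB),
                "doppler_hz": rng.uniform(*DOPPLER_RANGE_HZ),
            })
    return plan


plan = augmentation_plan(MINI_BATCH, IMPAIRMENTS_PER_SIGNAL, random.Random(0))
print(len(plan), "samples from", len({p["signal"] for p in plan}), "generated signals")
```

Under these assumptions, a minibatch of 30 samples requires only 15 generated frames, so the expensive waveform generation runs half as often as it would without augmentation.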

Set trainNow to true to train the model. Training the model for 500 iterations with background data generation takes approximately 1 hour and 20 minutes using an Intel® Xeon® W-2133 CPU and an NVIDIA® RTX 3080 GPU. If trainNow is set to false, the example loads a pretrained model to run the example quickly.

trainNow = false;
plotLoss = false;
verbose = true;
valFreq = 10;
saveCheckPointFreq = 50;
maxTrainIter = 500;

params.MiniBatchSize = miniBatch;
params.ImpPerSig = impairmentsPerSignal;

Train Network with Background Data Generation

To train the PyTorch model with run-time data generation, in each training iteration you generate a minibatch of augmented time-domain frames using the helperSpecSenseGenerateSignals function. Then, you generate the spectrogram and mask of each time-domain signal in the minibatch using the helperSpecSenseGenerateImagesMasks function. Use the deeplabv3Obj object to compute the loss function value and update the PyTorch model weights.

if trainNow

    weightsFileName = "trained_specsenselabv3.pth"; %#ok
    deeplabv3Obj = py.deeplabv3nnwrapper.setup_trainer( ...
        networkBackbone, ...
        weightsFileName, ...
        saveCheckPointFreq, ...
        maxTrainIter);

backgroundDataGen is set to true by default to speed up the training. You use Parallel Computing Toolbox™ to generate and augment each minibatch of time domain signals using a background parallel pool. If the Parallel Computing Toolbox is not available, this section runs on a single CPU worker.

Use the helperGetDataBatch function to generate a batch of augmented signals using a background parallel pool. Use the CPU or the GPU, if one is available, to generate the spectrogram images and masks for the batch of augmented signals and train the model for one iteration.

[Figure: background data generation alongside training (DatagenAcc2.png)]
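
The background pattern described here keeps a fixed number of generation jobs in flight, takes whichever job finishes first, and immediately queues a replacement. The following sketch reproduces that pattern with the Python standard library. It is an analogy, not the MATLAB implementation: it uses threads rather than a process-based parallel pool, and generate_signals is a hypothetical stand-in for the signal generator.

```python
import random
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def generate_signals(seed):
    """Stand-in for a slow signal generator; returns a small list of samples."""
    rng = random.Random(seed)
    time.sleep(0.01)  # simulate generation cost
    return [rng.random() for _ in range(4)]


def get_data_batch(pool, pending, next_seed):
    """Take one finished batch and immediately queue a replacement job."""
    done, _ = wait(pending, return_when=FIRST_COMPLETED)
    finished = next(iter(done))
    pending.remove(finished)
    pending.add(pool.submit(generate_signals, next_seed))
    return finished.result()


NUM_IN_FLIGHT = 4
with ThreadPoolExecutor(max_workers=2) as pool:
    pending = {pool.submit(generate_signals, s) for s in range(NUM_IN_FLIGHT)}
    batches = []
    for iteration in range(6):
        batch = get_data_batch(pool, pending, NUM_IN_FLIGHT + iteration)
        batches.append(batch)  # a training step would consume the batch here
    queue_depth = len(pending)
print(len(batches), "batches consumed;", queue_depth, "jobs still queued")
```

Because a replacement is submitted each time a batch is consumed, the queue depth stays constant and the consumer rarely waits, as long as generation keeps pace with training.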

    backgroundDataGen = true;

    tstart = tic;
    if backgroundDataGen && canUseParallelPool

        [trainFuture,valFuture] = helperSetBackgroundDataGen(params,"train");

        for iteration = 1:maxTrainIter

            [trainFuture,images,masks] = helperGetDataBatch(params,trainFuture);
            py.deeplabv3nnwrapper.train_one_iteration(deeplabv3Obj,images,masks);

            if mod(iteration,valFreq)==0 || iteration == 1
                [valFuture,images,masks] = helperGetDataBatch(params,valFuture);
                py.deeplabv3nnwrapper.validate(deeplabv3Obj,images,masks);

                t = seconds(toc(tstart));
                t.Format = "hh:mm:ss";
                helperTrainingProgress(iteration,maxTrainIter,deeplabv3Obj, ...
                    verbose,plotLoss,valFreq,t);
            end

        end
    else

Train Network with In-Process Data Generation

Set backgroundDataGen to false to generate the data, preprocess the data, and train the network using the same process. Generating the data in-process increases the training time to approximately 4 hours and 30 minutes using an Intel Xeon W-2133 CPU and an NVIDIA RTX 3080 GPU.

        for iteration = 1:maxTrainIter
            
            [signals,freqLim,labels] = helperSpecSenseGenerateSignals(params);
            [images,masks] = helperSpecSenseGenerateImagesMasks(params,signals,freqLim,labels);

            py.deeplabv3nnwrapper.train_one_iteration(deeplabv3Obj,images,masks);

            if mod(iteration,valFreq)==0
                [signals,freqLim,labels] = helperSpecSenseGenerateSignals(params);
                [images,masks] = helperSpecSenseGenerateImagesMasks(params,signals,freqLim,labels);
                py.deeplabv3nnwrapper.validate(deeplabv3Obj,images, masks);

                t = seconds(toc(tstart));
                t.Format = "hh:mm:ss";
                helperTrainingProgress(iteration,maxTrainIter,deeplabv3Obj, ...
                    verbose,plotLoss,valFreq,t);
            end

        end
    end
else
    weightsFileName = "specsenselabv3.pth";
    helperDownloadFiles();
    deeplabv3Obj = py.deeplabv3nnwrapper.construct_model( ...
        networkBackbone, ...
        weightsFileName);
end
Starting download of data files from:
	https://www.mathworks.com/supportfiles/spc/SpectrumSenseTorchModel/SpectrumSenseLabV3.zip
Extracting files.
Extract complete.

Test Deep Neural Network

Test the trained PyTorch model using a new set of generated data with 300 impaired 5G NR and LTE signals. The size of the test data set is the product of the miniBatch and numTestIter variables.

numTestIter = 10;
backgroundDataGen = true;
numTestSamples  = miniBatch*numTestIter
numTestSamples = 
300
disp("Testing model ...")
Testing model ...
if backgroundDataGen

    testFuture = helperSetBackgroundDataGen(params,"test");

    for iteration = 1:numTestIter
        [testFuture,images,masks] = helperGetDataBatch(params,testFuture);
        py.deeplabv3nnwrapper.test(deeplabv3Obj,images,masks);
        if mod(iteration,2)==0
            disp("Tested "+num2str(iteration*miniBatch)+" signals ...")
        end
    end

else

    for iteration = 1:numTestIter  %#ok
        [signals,freqLim,labels] = helperSpecSenseGenerateSignals(params);
        [images,masks] = helperSpecSenseGenerateImagesMasks(params,signals,freqLim,labels);
        py.deeplabv3nnwrapper.test(deeplabv3Obj,images,masks);
        if mod(iteration,5)==0
            disp("Tested "+num2str(iteration*miniBatch)+" signals ...")
        end
    end

end
Tested 60 signals ...
Tested 120 signals ...
Tested 180 signals ...
Tested 240 signals ...
Tested 300 signals ...
disp("Testing model with "+num2str(numTestSamples)+" images complete.")
Testing model with 300 images complete.
IOU = deeplabv3Obj.test_iou;
Accuracy = deeplabv3Obj.test_accuracy;

numTrainSamples  = miniBatch*maxTrainIter;

resultsTable = table(categorical({'Deeplabv3_resnet50'}), ...
    numTrainSamples,numTestSamples,IOU,Accuracy, ...
    VariableNames=["Network" ...
    "Training Dataset Size" ...
    "Test Dataset Size" ...
    "Mean IOU" ...
    "Mean Accuracy"]);

disp(resultsTable)
         Network          Training Dataset Size    Test Dataset Size    Mean IOU    Mean Accuracy
    __________________    _____________________    _________________    ________    _____________

    Deeplabv3_resnet50            15000                   300           0.80555        99.472    
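
Pixel accuracy and intersection over union (IOU) measure different things, which is why the two columns can differ widely. The following sketch computes both on a toy label map using the common textbook definitions. How the example's Python module averages these metrics is not shown here, so treat the averaging over classes as an assumption.

```python
def pixel_accuracy(pred, true):
    """Percentage of pixels whose predicted class index matches the true one."""
    pairs = [(p, t) for prow, trow in zip(pred, true) for p, t in zip(prow, trow)]
    return 100.0 * sum(p == t for p, t in pairs) / len(pairs)


def mean_iou(pred, true, num_classes):
    """Average intersection over union across classes that appear in pred or true."""
    pairs = [(p, t) for prow, trow in zip(pred, true) for p, t in zip(prow, trow)]
    ious = []
    for c in range(num_classes):
        intersection = sum(p == c and t == c for p, t in pairs)
        union = sum(p == c or t == c for p, t in pairs)
        if union:  # skip classes absent from both maps
            ious.append(intersection / union)
    return sum(ious) / len(ious)


# 0 = Noise, 1 = NR, 2 = LTE
true = [[0, 1, 1, 0],
        [0, 2, 2, 0]]
pred = [[0, 1, 0, 0],
        [0, 2, 2, 2]]
print(pixel_accuracy(pred, true), mean_iou(pred, true, 3))
```

In this toy case two wrong pixels out of eight give 75% accuracy but a mean IOU below 0.6, because each error counts against the union of two classes at once.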

Visually examine the classification accuracy of the trained PyTorch network using randomly generated test data samples.

rng(10)
frameDuration = params.NumSubFrames*1e-3;    % seconds
params.MiniBatchSize = 3;
params.ImpPerSig = 1;

[signals,freqLim,labels] = helperSpecSenseGenerateSignals(params);
[images,masks] = helperSpecSenseGenerateImagesMasks(params,signals,freqLim,labels);

predictions = single(py.deeplabv3nnwrapper.predict(deeplabv3Obj,images));

for idx = 1:3
    rcvdSpectrogram = squeeze(images(:,:,:,idx));
    trueLabels = categorical(ones(256),1:length(params.ClassNames),params.ClassNames);
    predictedLabels = categorical(ones(256),1:length(params.ClassNames),params.ClassNames);
    for i = 2:3
        trueLabels(squeeze(masks(:,:,i,idx))==1) = categorical(params.ClassNames(i));
        predictedLabels(squeeze(predictions(:,:,i,idx))==1) = categorical(params.ClassNames(i));
    end

    figure
    helperSpecSenseDisplayResults(im2uint8(rescale(rcvdSpectrogram)),trueLabels,predictedLabels, ...
        params.ClassNames,params.Fs,params.CarrierFreq,frameDuration);
end

[Figures: one figure per test sample (three in total), each with three image axes titled Received Spectrogram, True signal labels, and Estimated signal labels, with Frequency (MHz) on the x-axis and Time (ms) on the y-axis.]
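
To relate pixel positions in these plots to physical units, the following sketch converts the image dimensions into frequency and time resolution and measures the extent of a labeled region. It assumes that the image columns span the full sample rate and that the rows span the full frame duration, which matches the axis labels but is not confirmed by the display helper. The function names are hypothetical.

```python
FS_HZ = 61.44e6           # params.Fs
FRAME_DURATION_S = 40e-3  # params.NumSubFrames * 1e-3
IMAGE_SIZE = (256, 256)   # params.ImageSize as (rows, columns)


def pixel_resolution(fs_hz, frame_duration_s, image_size):
    """Return (MHz per column, ms per row), assuming columns span Fs and rows span the frame."""
    rows, cols = image_size
    return fs_hz / cols / 1e6, frame_duration_s / rows * 1e3


def occupied_extent(label_map, class_index, mhz_per_col, ms_per_row):
    """Bandwidth (MHz) and duration (ms) of the bounding box around one class, or None."""
    hits = [(r, c) for r, row in enumerate(label_map)
            for c, label in enumerate(row) if label == class_index]
    if not hits:
        return None
    r_idx = [r for r, _ in hits]
    c_idx = [c for _, c in hits]
    return ((max(c_idx) - min(c_idx) + 1) * mhz_per_col,
            (max(r_idx) - min(r_idx) + 1) * ms_per_row)


mhz_per_col, ms_per_row = pixel_resolution(FS_HZ, FRAME_DURATION_S, IMAGE_SIZE)
print(mhz_per_col, "MHz per column,", ms_per_row, "ms per row")
```

Under these assumptions, each pixel covers about 0.24 MHz by 0.156 ms, so a 10 MHz carrier occupies roughly 42 columns of the 256-pixel-wide image.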

Helper Files

  • helperdeeplabv3nn.py

  • helperdeeplabv3nnwrapper.py

  • helperinstalledlibs.py

  • helperLibraryChecker.m

  • helperSetupPyenv.m

  • helperSpecSenseDisplayResults.m

  • helperSpecSenseGenerateImagesMasks.m

  • helperSpecSenseGenerateSignals.m

  • helperSpecSenseLTESignal.m

  • helperSpecSenseNRSignal.m

PyTorch Wrapper Template

You can use your own PyTorch models in MATLAB using the Python interface. The py_wrapper_template.py file provides a simple interface with a predefined API. This example uses the following API set:

  • construct_model: returns the PyTorch neural network model

  • setup_trainer: sets up a trainer object for online training

  • train_one_iteration: trains the PyTorch model for one iteration for online training

  • validate: validates the PyTorch model for online training

  • predict: runs the PyTorch model with the provided input(s)

  • test: performs a test step with the given images and masks

  • info: prints or returns information on the PyTorch model

  • process_data_from_matlab: converts image and mask data from MATLAB format to torch tensors

  • process_data_to_matlab: converts network output from torch tensors to a format suitable for postprocessing in MATLAB

The Train PyTorch Channel Prediction Models example shows an offline training workflow and uses the following API set in addition to the ones used in this example.

  • train: trains the PyTorch model

  • save_model_weights: saves the PyTorch model weights

  • load_model_weights: loads the PyTorch model weights

You can modify the py_wrapper_template.py file. Follow the instructions in the template file to implement the recommended entry points. Delete the entry points that are not relevant to your project. Use the entry point functions as shown in this example to use your own PyTorch models in MATLAB.
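
The entry-point pattern can be illustrated without PyTorch. The following sketch is a hypothetical toy module, not the contents of py_wrapper_template.py: module-level functions operate on a trainer object that stores its own loss history, so the caller only passes the object and a batch of data. The ToyTrainer class, its one-parameter model, and the val_loss attribute are illustrative assumptions. The entry-point names and the loss_vector attribute mirror the ones used in this example.

```python
class ToyTrainer:
    """Hypothetical stand-in for a trainer object; fits y = w * x by gradient descent."""

    def __init__(self, learning_rate=0.1):
        self.w = 0.0
        self.learning_rate = learning_rate
        self.loss_vector = []  # one training loss per iteration
        self.val_loss = []     # one entry per validation call


def setup_trainer(learning_rate=0.1):
    """Create the trainer object that the other entry points operate on."""
    return ToyTrainer(learning_rate)


def _mse(trainer, inputs, targets):
    return sum((trainer.w * x - t) ** 2 for x, t in zip(inputs, targets)) / len(inputs)


def train_one_iteration(trainer, inputs, targets):
    """Take one gradient step on the batch and record the loss after the step."""
    grad = sum(2 * (trainer.w * x - t) * x for x, t in zip(inputs, targets)) / len(inputs)
    trainer.w -= trainer.learning_rate * grad
    trainer.loss_vector.append(_mse(trainer, inputs, targets))


def validate(trainer, inputs, targets):
    """Record the loss on held-out data without updating the model."""
    trainer.val_loss.append(_mse(trainer, inputs, targets))


def predict(trainer, inputs):
    """Run the model on the provided inputs."""
    return [trainer.w * x for x in inputs]


trainer = setup_trainer()
for _ in range(50):
    train_one_iteration(trainer, [1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
validate(trainer, [4.0], [8.0])
print(round(trainer.w, 3), trainer.val_loss[-1])
```

Keeping the state inside the trainer object means that the calling side needs no knowledge of the training internals and can read progress from attributes such as loss_vector after each call.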

Local Functions

function varargout = helperSetBackgroundDataGen(params,status)
%helperSetBackgroundDataGen Sets a background pool and 
% starts the data generation in the background.

if ~isempty(gcp("nocreate"))
    p = gcp("nocreate");
else
    currentCluster = parcluster;
    p = parpool(min(currentCluster.NumWorkers,8));
end

if strcmp(status,"train")
    params.numWorkersTrain = floor(p.NumWorkers*0.8);
    numTrainFutures = max(params.numWorkersTrain*3,12);
    varargout{1}(1:numTrainFutures) = parallel.FevalFuture;
    for idx = 1:numTrainFutures
        varargout{1}(idx) = parfeval(@helperSpecSenseGenerateSignals,3,params);
    end

    params.numWorkersVal = max(floor(p.NumWorkers*0.1),1);
    numValFutures = params.numWorkersVal;
    varargout{2}(1:numValFutures) = parallel.FevalFuture;
    for idx = 1:numValFutures
        varargout{2}(idx) = parfeval(@helperSpecSenseGenerateSignals,3,params);
    end
else
    numTestFutures = p.NumWorkers;
    varargout{1}(1:numTestFutures) = parallel.FevalFuture;
    for idx = 1:numTestFutures
        varargout{1}(idx) = parfeval(@helperSpecSenseGenerateSignals,3,params);
    end
end

end

function [dataFuture,images,masks] = helperGetDataBatch(params,dataFuture)
%helperGetDataBatch Reads one batch of data from the background data queue.

[completedIdx,signals,freqLimits,labels] = fetchNext(dataFuture);
dataFuture(completedIdx) = parfeval(@helperSpecSenseGenerateSignals,3,params);
[images,masks] = helperSpecSenseGenerateImagesMasks(params,signals,freqLimits,labels);

end

function helperTrainingProgress(iter,maxTrainIter,deeplabv3_object, ...
    verbose,plotLoss,valFreq,t) %#ok
%helperTrainingProgress Prints and/or plots the training and validation loss.

if verbose
    trainLoss = deeplabv3_object.loss_vector{iter};
    outFcnLoss = deeplabv3_object.output_loss{iter};    

    disp(string(t)+ " - Iteration "+iter+"/"+maxTrainIter+":")
    fprintf("      Train diceBCE loss: %2.2f, IOU: %2.2f\n", ...
        trainLoss, outFcnLoss)

    valIter = max(iter/valFreq,1);

    fprintf("      Val dice loss: %.2f, val pixel accuracy: %.2f %%\n", ...
        deeplabv3_object.val_dice_loss{valIter}/valIter, ...
        deeplabv3_object.val_accuracy{valIter})

end

persistent h1 h2
if plotLoss
    if isempty(h1)
        figure;
        ax = gca;
        h1 = animatedline(ax,Color="red",Marker="o");
        h2 = animatedline(ax,Color="blue",Marker="*");
        title("Training Progress")
        xlabel("Iteration")
        ylabel("Training loss and IOU")
        legend("DiceBCE loss", "IOU")
        grid on
    end
    addpoints(h1,iter,deeplabv3_object.loss_vector{iter});
    addpoints(h2,iter,deeplabv3_object.output_loss{iter});
    drawnow
end
end

function helperDownloadFiles()
%helperDownloadFiles Download pretrained model weights.

targetDir = "SpectrumSenseTorchModel";
downloadFileName = "SpectrumSenseLabV3.zip";
expFileNames = ["license.txt","specsenselabv3.pth"];
downloadData{1} = struct("DownloadFile",downloadFileName, ...
    "ExpectedFiles",expFileNames);
dstFolder = pwd;

helperDownloadExtractFile(targetDir, downloadData, dstFolder)

end
