
Train PyTorch Channel Prediction Models

Since R2025a

This example shows how to train a PyTorch™-based channel prediction neural network using data that you generate in MATLAB.

This example uses PyTorch to train the channel prediction network. You can also implement similar models directly in MATLAB by using Deep Learning Toolbox™.

Introduction

Wireless channel prediction is a crucial aspect of modern communication systems, enabling more efficient and reliable data transmission. Recent advancements in machine learning, particularly neural networks, have introduced a data-driven approach to wireless channel prediction. This approach does not rely on predefined models but instead learns directly from historical channel data. As a result, neural networks can adapt to realistic data, making them less sensitive to disturbances and interference.

Channel prediction using neural networks is fundamentally a time series learning problem, because it forecasts future channel states from past estimates. This method is particularly advantageous in environments where spatial correlation is minimal or absent, such as crowded urban areas with numerous moving objects. By focusing on temporal correlations and historical data, neural networks provide a computationally efficient and scalable solution across various environments.
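The time series framing can be sketched as a sliding window over a sequence of per-slot channel estimates. Each input is a window of consecutive estimates, and the target is the estimate a fixed number of slots (the horizon) after the last input. The make_windows function below is hypothetical. It assumes that the target lies exactly horizon slots after the last input sample. It is not the preprocessing in prepareData, which also adds noise and applies min-max normalization.

```python
def make_windows(samples, seq_len, horizon):
    """Return (input_window, target) pairs from a time-ordered list.

    Hypothetical helper: assumes the target is the sample `horizon` steps
    after the last sample of the input window.
    """
    pairs = []
    # Last start index that still leaves room for the target sample.
    last_start = len(samples) - seq_len - horizon
    for start in range(last_start + 1):
        window = samples[start:start + seq_len]
        target = samples[start + seq_len - 1 + horizon]
        pairs.append((window, target))
    return pairs


# Toy "channel" sequence: one scalar estimate per slot.
slots = list(range(10))
pairs = make_windows(slots, seq_len=4, horizon=2)
print(len(pairs))   # 5
print(pairs[0])     # ([0, 1, 2, 3], 5)
```

Each window is an independent training sample once it is cut out. This is why, later in this example, samples can be shuffled freely before splitting.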

Unlike single‑tap or isotropic scattering models, clustered delay line (CDL) channels exhibit clustered multipath dynamics and non‑isotropic Doppler spectra. These properties lead to temporal correlations that vary across delay clusters. This motivates the use of gated recurrent unit (GRU) networks, whose nonlinear gates can selectively emphasize relevant temporal structures that linear predictors (e.g., LMMSE/Wiener filters) cannot exploit [1],[2].
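For reference, the standard single-layer GRU update at time step t is shown below in the form documented for torch.nn.GRU. Here σ is the logistic sigmoid and ⊙ is elementwise multiplication. The update gate z_t sets how much of the previous state h_{t-1} is kept, and the reset gate r_t sets how much of h_{t-1} enters the candidate state n_t. Both gates depend nonlinearly on the current input x_t, which is what separates a GRU from a fixed linear filter. In this example, x_t has 2Ntx = 16 real entries and h_t has 128 entries.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= \left(1 - z_t\right) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```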

PyTorch Code

In this example, you train a GRU network defined in PyTorch. The nr_channel_predictor.py file contains the neural network definition, the training code, and other functionality for the PyTorch network. First, create a Python wrapper for the functionality provided in the nr_channel_predictor.py file.

The nr_channel_predictor_wrapper.py file contains the interface functions, which minimize data transfer between the MATLAB and Python processes. This example uses the following functions in the nr_channel_predictor_wrapper.py file:

  • construct_model: Constructs and optionally loads a PyTorch model for channel prediction.

  • train: Trains a channel predictor PyTorch model using offline training.

  • predict: Generates predictions using a trained PyTorch model and input data.

  • save: Saves the state dictionary of a PyTorch model to a file.

  • load: Loads a state dictionary into a PyTorch model from a specified file.

  • info: Displays the model architecture and the total number of parameters.

The PyTorch Wrapper Template section shows how to create the interface functions using a template.

The first step of designing an AI-based system is to prepare training and testing data. This example follows the Preprocess Data for AI-Based CSI Feedback Compression example that shows how to preprocess the channel estimates.

Load the preprocessed channel estimates data. If you have run the previous step, then the example uses the data that you prepared in the previous step. Otherwise, the example prepares the data as shown in the Preprocess Data for AI-Based CSI Feedback Compression example.

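Generating the data requires 10 frames of channel realizations. In the run shown below, this produces 124,800 samples, of which 100,000 are used for training and validation. With Parallel Computing Toolbox® on a six-core Intel® Xeon® W-2133 CPU @ 3.60GHz, generation takes about 50 seconds. The code below sets useParallel to false, so generation runs on a single worker and takes longer. To use parallel workers, set useParallel to true.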

horizon = 10;     % Prediction horizon in ms (one slot is 1 ms at 15 kHz subcarrier spacing)
maxDoppler = 5;   % Maximum Doppler shift in Hz
if ~exist("inputData","var") || ~exist("targetData","var") || ~exist("dataOptions","var") || ~exist("channel","var") || ~exist("carrier","var")
  numFrames = 10;
  useParallel = false;
  [inputData,targetData,dataOptions,systemParams,channel,carrier] = prepareData(numFrames,useParallel,horizon,maxDoppler);
end
Starting channel realization generation
1 worker(s) running
00:00:09 - 10% Completed
00:00:20 - 20% Completed
00:00:30 - 30% Completed
00:00:45 - 40% Completed
00:00:58 - 50% Completed
00:01:09 - 60% Completed
00:01:21 - 70% Completed
00:01:32 - 80% Completed
00:01:43 - 90% Completed
00:01:56 - 100% Completed
Starting CSI data preprocessing
1 worker(s) running
00:00:05 - 20% Completed
00:00:11 - 30% Completed
00:00:16 - 40% Completed
00:00:20 - 50% Completed
00:00:24 - 60% Completed
00:00:29 - 70% Completed
00:00:35 - 80% Completed
00:00:39 - 90% Completed
00:00:43 - 100% Completed
00:00:48 - 110% Completed
00:00:48 - 100% Completed

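See the channel and carrier variables for the current channel and carrier configurations.

The inputData variable contains Nsamples samples. Each sample is a 2Ntx-by-Nseq array, where Ntx is the number of transmit antennas and Nseq is the number of consecutive slot-spaced time samples. The factor of 2 comes from storing the real and imaginary parts of each complex channel coefficient as separate real values. The targetData variable is a 2Ntx-by-Nsamples array that contains the channel to predict for each sample.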

[Ntxiq,Nseq,Nsamples] = size(inputData)
Ntxiq = 
16
Nseq = 
68
Nsamples = 
124800
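As a consistency check, you can derive these sizes from the parameters in the prepareData local function:

* Two polarizations on a single 2-by-2 transmit panel give Ntx = 8, so each time step has 2Ntx = 16 real values.
* The coherence-time rule Tc = 0.423/fD, with fD = 5 Hz and 1 ms slots, gives Nseq = ceil(0.8 × 84.6) = 68.

The following Python sketch reproduces this arithmetic. The numerology is computed as log2(SCS/15).

```python
import math

# Parameters copied from the prepareData local function.
max_doppler_hz = 5.0            # maxDoppler
subcarrier_spacing_khz = 15     # systemParams.SubcarrierSpacing
horizon_slots = 10              # horizon
tx_panel_dims = (2, 2)          # TransmitAntennaArray.PanelDimensions (N1, N2)
tx_num_panels = 1               # TransmitAntennaArray.NumPanels
tx_num_pols = 2                 # TransmitAntennaArray.NumPolarizations

# Transmit antennas, and real-valued features per time step (real + imaginary)
ntx = tx_num_pols * tx_num_panels * tx_panel_dims[0] * tx_panel_dims[1]
ntx_iq = 2 * ntx

# Coherence time rule from prepareData: Tc = 0.423 / fD
tc = 0.423 / max_doppler_hz                               # seconds
numerology = int(math.log2(subcarrier_spacing_khz / 15))  # 0 for 15 kHz
t_slot = 1e-3 / 2 ** numerology                           # slot duration in seconds
seq_len = math.ceil((tc / t_slot) * 0.8)                  # Nseq
num_slots_per_frame = seq_len + horizon_slots + 10

print(ntx, ntx_iq, seq_len, num_slots_per_frame)          # 8 16 68 88
```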

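Permute the data so that the batch (sample) dimension comes first, because the PyTorch network expects batch-first inputs. After permutation, inputDataPerm is Nsamples-by-Nseq-by-2Ntx and targetDataPerm is Nsamples-by-2Ntx.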

inputDataPerm = permute(inputData,[3,2,1]);
targetDataPerm = permute(targetData,[2,1]);
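The batch-sequence-feature layout matches the batch_first=True setting that appears later in the model printout. Under that setting, torch.nn.GRU reads inputs as (batch, sequence, features).

The index mapping of permute(inputData,[3,2,1]) can be illustrated with nested Python lists. The permute_321 function is a hypothetical illustration, not part of the example files, and it uses 0-based indices.

```python
def permute_321(arr):
    """Reorder a nested list indexed [f][t][s] into [s][t][f].

    Mirrors MATLAB permute(x,[3,2,1]) for a 3-D array stored as nested lists.
    """
    n_f, n_t, n_s = len(arr), len(arr[0]), len(arr[0][0])
    return [[[arr[f][t][s] for f in range(n_f)]
             for t in range(n_t)]
            for s in range(n_s)]


# Toy array: 2 features (2Ntx), 3 time steps (Nseq), 4 samples (Nsamples).
# Each entry records its own (f, t, s) index so the mapping is visible.
x = [[[(f, t, s) for s in range(4)] for t in range(3)] for f in range(2)]
y = permute_321(x)
print(len(y), len(y[0]), len(y[0][0]))   # 4 3 2
print(y[3][1][0])                        # (0, 1, 3)
```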

Separate the data into training and validation. Define the number of training and validation samples.

numTraining = 90000;
numValidation = 10000;

Randomly sample the input and target data along the sample dimension to select training and validation samples. Because each 2Ntx-by-Nseq sample is independent, the split does not need to preserve time continuity between samples. Samples beyond numTraining + numValidation are not used.

idxRand = randperm(size(targetDataPerm,1));

Select training and validation data.

xTraining = inputDataPerm(idxRand(1:numTraining),:,:);
xValidation = inputDataPerm(idxRand(1+numTraining:numValidation+numTraining),:,:);
yTraining = targetDataPerm(idxRand(1:numTraining),:);
yValidation = targetDataPerm(idxRand(1+numTraining:numValidation+numTraining),:);
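The split above draws one random permutation and slices it into two non-overlapping index ranges. The following Python sketch shows the same idea with the standard library. The split_indices name is hypothetical, and the indices are 0-based rather than 1-based as in MATLAB.

```python
import random


def split_indices(num_samples, num_training, num_validation, seed=None):
    """Randomly partition sample indices into disjoint training/validation sets.

    Python analogue of idxRand = randperm(N) followed by slicing.
    """
    if num_training + num_validation > num_samples:
        raise ValueError("not enough samples for the requested split")
    rng = random.Random(seed)
    idx = list(range(num_samples))
    rng.shuffle(idx)                                   # random permutation
    train = idx[:num_training]
    val = idx[num_training:num_training + num_validation]
    return train, val


train_idx, val_idx = split_indices(124800, 90000, 10000, seed=0)
print(len(train_idx), len(val_idx))   # 90000 10000
```

Because both sets come from one permutation, no sample can appear in both training and validation.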

Set Up Python Environment

Before running this example, set up the Python environment as explained in Call Python from MATLAB for Wireless. Specify the path of the Python executable in the pythonPath variable below. The helperSetupPyenv function sets the Python environment in MATLAB according to the selected options, and it checks that the libraries listed in the requirements_chanpre.txt file are installed. This example is tested with Python version 3.11.

if ispc
  pythonPath = ".\.venv\Scripts\pythonw.exe";
else
  pythonPath = "./venv_linux/bin/python3";
end
requirementsFile = "requirements_chanpre.txt";
executionMode = "OutOfProcess";
currentPyenv = helperSetupPyenv(pythonPath,executionMode,requirementsFile);
Setting up Python environment
Parsing requirements_chanpre.txt 
Checking required package 'numpy'
Checking required package 'torch'
Required Python libraries are installed.
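The implementation of helperSetupPyenv is not shown in this example. The following Python sketch illustrates the kind of check it describes: read package names from a requirements file and confirm that each one can be imported. The function names required_packages and missing_packages are hypothetical. The sketch also assumes that each import name equals the package name, which holds for numpy and torch but not for every package.

```python
import importlib.util
import re


def required_packages(requirements_text):
    """Extract package names from requirements-file text (hypothetical helper).

    Comments, blank lines, and version specifiers are stripped.
    """
    names = []
    for line in requirements_text.splitlines():
        line = line.split("#", 1)[0].strip()   # drop comments and whitespace
        if not line:
            continue
        match = re.match(r"[A-Za-z0-9._-]+", line)  # name ends at first specifier char
        if match:
            names.append(match.group(0))
    return names


def missing_packages(names):
    """Return the names whose module cannot be found on the current Python path."""
    return [n for n in names if importlib.util.find_spec(n) is None]


pkgs = required_packages("numpy>=1.24\ntorch  # deep learning\n\n")
print(pkgs)   # ['numpy', 'torch']
print(missing_packages(["json", "surely_not_a_real_module_xyz"]))
```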

You can use the following process ID and name to attach a debugger to the Python interface and debug the example code.

fprintf("Process ID for '%s' is %s.\n", ...
currentPyenv.ProcessName,currentPyenv.ProcessID)
Process ID for 'MATLABPyHost' is 47596.

Preload the Python module for faster start.

module = py.importlib.import_module('nr_channel_predictor_wrapper');

Initialize Neural Network

Initialize the channel predictor neural network. Set the GRU hidden size to 128 and the number of stacked GRU layers to 2. In the model, layer normalization stabilizes training across Doppler conditions. A dropout rate of 0.3 between the GRU layers mitigates overfitting to specific cluster realizations. The chanPredictor variable is the PyTorch model for the GRU-based channel predictor.

gruHiddenSize = 128;
gruNumLayers  = 2;
channelInfo = info(channel);
Ntx = channelInfo.NumTransmitAntennas
Ntx = 
8
chanPredictor = py.nr_channel_predictor_wrapper.construct_model(...
  Ntx, ...
  gruHiddenSize, ...
  gruNumLayers);

py.nr_channel_predictor_wrapper.info(chanPredictor)
Model architecture:
ChannelPredictorGRU(
  (gru): GRU(16, 128, num_layers=2, batch_first=True, dropout=0.3)
  (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc): Linear(in_features=128, out_features=16, bias=True)
)

Total number of parameters: 157456
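You can verify the reported parameter count from the printed architecture. Each GRU layer has three gates, and each gate has:

* an input-to-hidden weight matrix,
* a hidden-to-hidden weight matrix,
* two bias vectors.

The layer norm has a weight vector and a bias vector, and the linear layer has a weight matrix and a bias vector. The following Python sketch does this count without PyTorch.

```python
def gru_params(input_size, hidden_size, num_layers):
    """Parameter count of a unidirectional GRU with biases (PyTorch layout)."""
    total = 0
    for layer in range(num_layers):
        in_size = input_size if layer == 0 else hidden_size
        # 3 gates x (W_i: H x in, W_h: H x H, b_i: H, b_h: H)
        total += 3 * (hidden_size * in_size + hidden_size * hidden_size + 2 * hidden_size)
    return total


def layer_norm_params(size):
    return 2 * size                                   # elementwise weight and bias


def linear_params(in_features, out_features):
    return in_features * out_features + out_features  # weight matrix + bias


ntx_iq, hidden, layers = 16, 128, 2
total = (gru_params(ntx_iq, hidden, layers)
         + layer_norm_params(hidden)
         + linear_params(hidden, ntx_iq))
print(total)   # 157456
```

Almost all of the parameters are in the GRU layers (155,136), and the second layer alone accounts for 99,072. Reducing gruHiddenSize is therefore the most direct way to shrink the model.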

Train Neural Network

The nr_channel_predictor_wrapper.py file contains the MATLAB interface functions for training the channel predictor neural network. Set the hyperparameters: number of epochs, batch size, initial learning rate, validation frequency, and validation patience in epochs. Use early stopping with a patience of 6 to avoid overfitting. This is especially relevant because CDL realizations contain time‑varying cluster reappearance patterns. Call the train function with the required inputs to train and validate the chanPredictor model. Set verbose to true to print the training progress.

Training for up to 2000 epochs with early stopping can take more than one hour on a PC with an NVIDIA® TITAN V GPU (compute capability 7.0, 12 GB memory). To train the network, set trainNow to true. If your GPU runs out of memory during training, reduce the batch size.
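The following Python sketch shows one way early stopping with a validation frequency and a patience measured in epochs can work. It is not the code in nr_channel_predictor.py. It assumes two things:

* validation runs every validation_frequency epochs;
* training stops once the current epoch is at least patience_epochs past the epoch with the best validation loss.

The returned values correspond to the best validation loss, best epoch, and final epoch that the train function reports. The function name early_stopping is hypothetical.

```python
def early_stopping(val_losses, validation_frequency, patience_epochs):
    """Simulate early stopping on a precomputed list of validation losses.

    val_losses[k] is the loss at the k-th validation check.
    Returns (best_loss, best_epoch, final_epoch) with 1-based epochs.
    """
    best_loss = float("inf")
    best_epoch = 0
    final_epoch = 0
    for k, loss in enumerate(val_losses):
        epoch = (k + 1) * validation_frequency   # epoch at which check k runs
        final_epoch = epoch
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # new best model
        elif epoch - best_epoch >= patience_epochs:
            break                                # no improvement for too long
    return best_loss, best_epoch, final_epoch


losses = [1.0, 0.5, 0.4, 0.45, 0.42, 0.3]
print(early_stopping(losses, validation_frequency=5, patience_epochs=6))
# (0.4, 15, 25)
```

In this toy run, training stops at epoch 25, before the later improvement to 0.3. This is the usual trade-off of a short patience: it saves training time but can stop before a late improvement.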

trainNow = false;
if trainNow
  numEpochs            = 2000;
  batchSize            = 512;
  initialLearningRate  = 3e-3;
  validationFrequency  = 5;
  validationPatience   = 6;
  verbose              = true;
  tStart = tic;
  result = py.nr_channel_predictor_wrapper.train( ...
    chanPredictor, ...
    xTraining, ...
    yTraining, ...
    xValidation, ...
    yValidation, ...
    initialLearningRate, ...
    batchSize, ...
    numEpochs, ...
    validationFrequency, ...
    validationPatience, ...
    verbose);
  et = toc(tStart); et = seconds(et); et.Format = "hh:mm:ss.SSS";

The train Python function returns seven outputs, in this order:

  • Trained PyTorch model

  • Training loss array (per iteration)

  • Validation loss array (one value per validation check)

  • Time spent in Python

  • Best validation loss

  • Epoch with the best validation loss

  • Final epoch (the epoch at which training stopped)

Parse the function output and display the results.

  chanPredictor = result{1};
  trainingLoss = single(result{2});
  validationLoss = single(result{3});
  elapsedTimePy = result{4};
  bestValidationLoss = result{5}
  bestEpoch = result{6}
  finalEpoch = result{7}
  etInPy = seconds(elapsedTimePy);
  etInPy.Format="hh:mm:ss.SSS";

Save the network for future use together with the training information.

  modelFileName = sprintf("chanpre_gru_hor%d_epochs%d_ts%s",dataOptions.Horizon, ...
    numEpochs,string(datetime("now","Format","dd_MM_HH_mm")));
  fileName = py.nr_channel_predictor_wrapper.save( ...
    chanPredictor, ...
    modelFileName, ...
    Ntx, ...
    gruHiddenSize, ...
    gruNumLayers, ...
    initialLearningRate, ...
    batchSize, ...
    numEpochs, ...
    validationFrequency);
  infoFileName = modelFileName+"_info";
  save(infoFileName,"dataOptions","trainingLoss","validationLoss", ...
    "etInPy","et","initialLearningRate","batchSize","numEpochs","validationFrequency", ...
    "Ntx","gruHiddenSize","gruNumLayers");
  fprintf("Saved network in '%s' file and\nnetwork info in '%s.mat' file.\n", ...
    string(fileName),infoFileName)
else

When called with a filename as the last input, the construct_model function creates a neural network and loads the trained weights from the file. Run the network with xValidation input by calling the predict Python function.

  numEpochs = 2000;
  horizon = 10;
  modelFileName = sprintf("channel_predictor_gru_horizon%d_epochs%d.pth",horizon,numEpochs);
  infoFileName = sprintf("channel_predictor_gru_horizon%d_epochs%d_info.mat",horizon,numEpochs);
  chanPredictor = py.nr_channel_predictor_wrapper.construct_model( ...
    Ntx, ...
    gruHiddenSize, ...
    gruNumLayers, ...
    modelFileName);
  y_out = py.nr_channel_predictor_wrapper.predict( ...
    chanPredictor, ...
    xValidation);

Calculate the mean squared error (MSE) between the predicted and the expected channel estimates, normalized by the number of transmit antennas.

  y = single(y_out);
  bestValidationLoss = mean(sum(abs(y - yValidation).^2,2)/Ntx);

Load the training and validation loss logged during training.

  load(infoFileName,"validationLoss","trainingLoss","etInPy","et")
end
fprintf("Validation Loss: %f dB\n",10*log10(bestValidationLoss))
Validation Loss: -36.344093 dB
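The loss computed in the else branch above can be written as shown below. N_v is the number of validation samples (numValidation), and ŷ_n and y_n are the min-max normalized predicted and target vectors, each of length 2N_tx.

Each complex coefficient contributes its real and imaginary parts to the inner sum. That sum therefore equals the squared magnitude of the complex prediction error, summed over the N_tx transmit antennas. Dividing by N_tx gives the mean squared error per antenna. The printed value is L_dB; for example, -36.34 dB corresponds to L ≈ 2.3 × 10^-4.

```latex
L = \frac{1}{N_v}\sum_{n=1}^{N_v}\frac{1}{N_{tx}}\sum_{k=1}^{2N_{tx}}\left(\hat{y}_{n,k}-y_{n,k}\right)^2,
\qquad
L_{\mathrm{dB}} = 10\log_{10} L
```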

The overhead caused by the Python interface is small. In the run shown below, it is about 2 seconds out of about 16 minutes of training.

fprintf("Total training time: %s\nTraining time in Python: %s\nOverhead: %s\n",et,etInPy,et-etInPy)
Total training time: 00:16:14.959
Training time in Python: 00:16:13.025
Overhead: 00:00:01.933
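To put the overhead in proportion, convert the printed durations to seconds and compute the fraction of the total. The following Python sketch does this; to_seconds is a hypothetical helper.

The printed times are rounded to milliseconds, so the difference computed from them (1.934 s) can differ from the printed Overhead by 1 ms.

```python
def to_seconds(hms):
    """Convert an 'hh:mm:ss.SSS' string to seconds."""
    hours, minutes, seconds = hms.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


total = to_seconds("00:16:14.959")
in_python = to_seconds("00:16:13.025")
overhead = total - in_python
print(f"overhead: {overhead:.3f} s ({100 * overhead / total:.2f} % of total)")
# overhead: 1.934 s (0.20 % of total)
```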

Plot the training and validation loss. As the number of iterations increases, the loss value converges to about -37 dB.

figure
plot(10*log10(trainingLoss));
hold on
numIters = size(trainingLoss,2);
iterPerValidation = numIters/length(validationLoss);
plot(iterPerValidation:iterPerValidation:numIters,10*log10(validationLoss),"*-");
hold off
legend("Training", "Validation")
xlabel(sprintf("Iteration (%d iterations per validation)",iterPerValidation))
ylabel("Loss (dB)")
title("Training Performance (NMSE as Loss)")
grid on

Figure: Training Performance (NMSE as Loss). Training and validation loss (dB) versus iteration (880 iterations per validation).
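You can reproduce the 880-iteration spacing between validation points from the training settings. The calculation assumes two things:

* the last partial batch of each epoch is kept;
* validation loss is logged once every validationFrequency epochs.

Under these assumptions, the spacing equals the number of iterations in five epochs rather than in one.

```python
import math

num_training = 90000
batch_size = 512
validation_frequency = 5   # epochs between validation checks

iters_per_epoch = math.ceil(num_training / batch_size)   # 175.78 -> 176 batches
iters_per_validation = iters_per_epoch * validation_frequency
print(iters_per_epoch, iters_per_validation)             # 176 880
```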

Investigate Network Performance

Test the network for different horizon values. The helperChanPreCompareNetworks function trains and tests the GRU channel prediction network for each horizon value in the horizonVec variable. For robustness, train the GRU with five different random initializations, and keep the model that achieves the lowest validation loss.

trainForComparisonNow = false;
if trainForComparisonNow
  horizonVec = [1 2:4:90];
  gruHiddenSize        = 128;
  gruNumLayers         = 2;
  numEpochs            = 2000;
  batchSize            = 512;
  initialLearningRate  = 3e-3;
  validationFrequency  = 5;
  validationPatience   = 6;

  if validationFrequency > numEpochs
    error("numEpochs is less than validationFrequency. Increase " + ...
      "numEpochs or reduce validationFrequency to collect data.")
  end
  compTable = helperChanPreCompareNetworks(channel,carrier, ...
    dataOptions.InputSequenceLength,horizonVec,gruHiddenSize, ...
    gruNumLayers,numTraining,numValidation,numEpochs,batchSize, ...
    initialLearningRate,validationFrequency,validationPatience);
  save dChannelPredictionNetworkHorizonResults_trials compTable horizonVec numEpochs numTraining numValidation
else
  load dChannelPredictionNetworkHorizonResults compTable horizonVec numEpochs numTraining numValidation
end
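The comparison procedure described above can be sketched in Python: for each horizon, train five randomly initialized GRUs and keep the one with the lowest validation loss. The functions sweep_horizons and fake_train are hypothetical. fake_train stands in for actual training and returns a synthetic loss, so its numbers are not results. The horizon list reproduces MATLAB's [1 2:4:90], which has 24 entries ending at 90.

```python
import random


def sweep_horizons(horizons, train_and_validate, num_trials=5):
    """For each horizon, run num_trials trainings and keep the best one.

    train_and_validate(horizon, seed) is a hypothetical callable that
    trains one model and returns its best validation loss.
    Returns {horizon: (best_loss, best_seed)}.
    """
    results = {}
    for h in horizons:
        losses = [train_and_validate(h, seed) for seed in range(num_trials)]
        best_seed = min(range(num_trials), key=losses.__getitem__)
        results[h] = (losses[best_seed], best_seed)
    return results


# MATLAB horizonVec = [1 2:4:90] -> 1, 2, 6, 10, ..., 90
horizons = [1] + list(range(2, 91, 4))


def fake_train(h, seed):
    """Synthetic stand-in: loss grows with horizon plus seed-dependent noise."""
    return 1e-4 * h + random.Random(1000 * h + seed).uniform(0, 1e-5)


results = sweep_horizons(horizons, fake_train)
print(len(horizons), horizons[-1])   # 24 90
```

Keeping the best of several seeds reduces the effect of an unlucky initialization on each point of the horizon curve. The cost is that training time grows by the same factor.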

The plotValidationLoss function plots the best validation loss of the GRU network for each prediction horizon. As the prediction horizon increases, the validation loss (NMSE) also increases. This reflects the reduced temporal correlation of the fading process at longer look‑ahead times.

In CDL channels, each delay cluster is a superposition of Doppler components rather than a single Doppler shift, which produces multiple characteristic time scales [4]. The nonlinear gating of the GRU can emphasize or suppress these Doppler components depending on the prediction horizon. This causes the characteristic rise and oscillations in NMSE beyond about 30 ms, which are not visible in linear minimum mean square error (LMMSE) predictors [3]. Thus, the GRU not only achieves low NMSE at short horizons but also reveals physically meaningful temporal structure in CDL fading.

Minor non‑monotonic fluctuations in NMSE are expected because of small optimization-induced variations during training, especially at very low NMSE values.

plotValidationLoss(compTable);

Figure: GRU Channel Predictor. Validation loss (dB) versus prediction horizon (ms).

References

[1] W. Jiang and H. D. Schotten, "Recurrent Neural Network-Based Frequency-Domain Channel Prediction for Wideband Communications," 2019 IEEE 89th Vehicular Technology Conference (VTC2019-Spring), Kuala Lumpur, Malaysia, 2019, pp. 1-6, doi: 10.1109/VTCSpring.2019.8746352.

[2] O. Stenhammar, G. Fodor and C. Fischione, "A Comparison of Neural Networks for Wireless Channel Prediction," in IEEE Wireless Communications, vol. 31, no. 3, pp. 235-241, June 2024, doi: 10.1109/MWC.006.2300140.

[3] I. Goodfellow, Y. Bengio, and A. Courville, Deep Learning. Cambridge, MA: MIT Press, 2016.

[4] 3GPP, "Study on channel model for frequencies from 0.5 to 100 GHz (Release 19)," 3GPP TR 38.901, V19.1.0, Oct. 2025.

PyTorch Wrapper Template

You can use your own PyTorch models in MATLAB through the Python interface. The py_wrapper_template.py file provides a simple interface with a predefined API. This example uses the following API set:

  • construct_model: returns the PyTorch neural network model

  • train: trains the PyTorch model

  • predict: runs the PyTorch model with the provided input(s)

  • save: saves the PyTorch model weights and metadata

  • load: loads the PyTorch model weights

  • info: prints or returns information on the PyTorch model

The Online Training and Testing of PyTorch Model for CSI Feedback Compression example shows an online training workflow. In addition to the API set used in this example, it uses the following functions:

  • setup_trainer: sets up a trainer object for online training

  • train_one_iteration: trains the PyTorch model for one iteration of online training

  • validate: validates the PyTorch model for online training

You can modify the py_wrapper_template.py file. Follow the instructions in the template file to implement the recommended entry points, and delete the entry points that are not relevant to your project. Then call the entry point functions as shown in this example to use your own PyTorch models in MATLAB.

Local Functions

function [inputData,targetData,dataOptions,systemParams,channel,carrier,sdsChan] = prepareData(numFrames,useParallel,horizon,maxDoppler)
rng(123)
carrier = nrCarrierConfig;
nSizeGrid = 52;                       % Number of resource blocks (RB)
systemParams.SubcarrierSpacing = 15;  % 15, 30, 60, 120 kHz
carrier.NSizeGrid = nSizeGrid;
carrier.SubcarrierSpacing = systemParams.SubcarrierSpacing;
systemParams.DelayProfile = 'CDL-C';   % 'CDL-A',...,'CDL-E','TDL-A',...,'TDL-E'
systemParams.DelaySpread = 300e-9;     % s
%% Antenna Panel Configuration
systemParams.TransmitAntennaArray.NumPanels        = 1; % Number of transmit panels in horizontal dimension (Ng)
systemParams.TransmitAntennaArray.PanelDimensions  = [2 2]; % Number of columns and rows in the transmit panel (N1, N2)
systemParams.TransmitAntennaArray.NumPolarizations = 2; % Number of transmit polarizations
systemParams.ReceiveAntennaArray.NumPanels         = 1; % Number of receive panels in horizontal dimension (Ng)
systemParams.ReceiveAntennaArray.PanelDimensions   = [2 1]; % Number of columns and rows in the receive panel (N1, N2)
systemParams.ReceiveAntennaArray.NumPolarizations  = 2; % Number of receive polarizations
systemParams.TransmitAntennaArray.NTxAnts = systemParams.TransmitAntennaArray.NumPolarizations*systemParams.TransmitAntennaArray.NumPanels*prod(systemParams.TransmitAntennaArray.PanelDimensions);
systemParams.ReceiveAntennaArray.NRxAnts = systemParams.ReceiveAntennaArray.NumPolarizations*systemParams.ReceiveAntennaArray.NumPanels*prod(systemParams.ReceiveAntennaArray.PanelDimensions);
systemParams.MaximumDopplerShift = maxDoppler;      % Hz
systemParams.Carrier = carrier;
channel = createChannel(systemParams);

Tc = 0.423/systemParams.MaximumDopplerShift;
numerology = log2(systemParams.SubcarrierSpacing/15);   % 0, 1, 2, 3 for 15, 30, 60, 120 kHz
Tslot = 1e-3 / 2^numerology;
coherenceTimeInSlots = Tc / Tslot;
sequenceLength = ceil(coherenceTimeInSlots*0.8);
numSlotsPerFrame = sequenceLength + horizon + 10;

saveData = true;
dataDir = fullfile(pwd,"Data");
dataFilePrefix = "nr_channel_est";
resetChannel = true;
sdsChan = helper3GPPChannelRealizations(...
  numFrames, ...
  channel, ...
  carrier, ...
  UseParallel=useParallel, ...
  SaveData=saveData, ...
  DataDir=dataDir, ...
  dataFilePrefix=dataFilePrefix, ...
  NumSlotsPerFrame=numSlotsPerFrame, ...
  ResetChannelPerFrame=resetChannel);

SNR = 20;
[sdsPreprocessed,dataOptions] = helperPreprocess3GPPChannelData( ...
  sdsChan, ...
  TrainingObjective="prediction", ...
  AverageOverSlots=false, ...
  TruncateChannel=false, ...
  InputSequenceLength=sequenceLength, ...
  PredictionHorizon=horizon, ...
  AddNoise=true, ...
  SNR=SNR, ...
  DataComplexity="real (interleaved)", ...
  DataDomain="Frequency-Spatial (FS)", ...
  UseParallel=useParallel, ...
  SaveData=saveData);

data = readall(sdsPreprocessed);
inputCells = cellfun(@(C) C{1}, data, 'UniformOutput', false);
targetCells = cellfun(@(C) C{2}, data, 'UniformOutput', false);
inputData = cat(3, inputCells{:});
targetData = cat(2, targetCells{:});
featuresMax = max(inputData,[],[2 3]);
featuresMin = min(inputData,[],[2 3]);
dataMax = max(featuresMax);
dataMin = min(featuresMin);
inputData = (inputData-dataMin) / (dataMax-dataMin);
targetData = (targetData-dataMin) / (dataMax-dataMin);
dataOptions.Normalization = "min-max";
dataOptions.MinValue = dataMin;
dataOptions.MaxValue = dataMax;
dataOptions.Horizon = horizon;
end

function channel = createChannel(simParameters)
% Create and configure the propagation channel. If the number of antennas
% is 1, configure only 1 polarization, otherwise configure 2 polarizations.

numTxPol = 1 + (simParameters.TransmitAntennaArray.NTxAnts>1);
numRxPol = 1 + (simParameters.ReceiveAntennaArray.NRxAnts>1);

if contains(simParameters.DelayProfile,'CDL')

  % Create CDL channel
  channel = nrCDLChannel;

  % Tx antenna array configuration in CDL channel. The number of antenna
  % elements depends on the panel dimensions. The size of the antenna
  % array is [M,N,P,Mg,Ng]. M and N are the number of rows and columns in
  % the antenna array. P is the number of polarizations (1 or 2). Mg and
  % Ng are the number of row and column array panels respectively. Note
  % that N1 and N2 in the panel dimensions follow a different convention
  % and denote the number of columns and rows, respectively.
  txArray = simParameters.TransmitAntennaArray;
  M = txArray.PanelDimensions(2);
  N = txArray.PanelDimensions(1);
  Ng = txArray.NumPanels;

  channel.TransmitAntennaArray.Size = [M N numTxPol 1 Ng];
  channel.TransmitAntennaArray.ElementSpacing = [0.5 0.5 1 1]; % Element spacing in wavelengths
  channel.TransmitAntennaArray.PolarizationAngles = [-45 45];  % Polarization angles in degrees

  % Rx antenna array configuration in CDL channel
  rxArray = simParameters.ReceiveAntennaArray;
  M = rxArray.PanelDimensions(2);
  N = rxArray.PanelDimensions(1);
  Ng = rxArray.NumPanels;

  channel.ReceiveAntennaArray.Size = [M N numRxPol 1 Ng];
  channel.ReceiveAntennaArray.ElementSpacing = [0.5 0.5 1 1];  % Element spacing in wavelengths
  channel.ReceiveAntennaArray.PolarizationAngles = [0 90];     % Polarization angles in degrees

elseif contains(simParameters.DelayProfile,'TDL')
  channel = nrTDLChannel;
  channel.NumTransmitAntennas = simParameters.TransmitAntennaArray.NTxAnts;
  channel.NumReceiveAntennas = simParameters.ReceiveAntennaArray.NRxAnts;
else
  error('Channel not supported.')
end

% Configure common channel parameters: delay profile, delay spread, and
% maximum Doppler shift
channel.DelayProfile = simParameters.DelayProfile;
channel.DelaySpread = simParameters.DelaySpread;
channel.MaximumDopplerShift = simParameters.MaximumDopplerShift;

% Configure the channel to return the OFDM response
channel.ChannelResponseOutput = 'ofdm-response';

% Get information about the baseband waveform after OFDM modulation step
waveInfo = nrOFDMInfo(simParameters.Carrier);

% Update channel sample rate based on carrier information
channel.SampleRate = waveInfo.SampleRate;
end

function plotValidationLoss(compTable)
horizonVec = compTable.Horizon;
metric = "BestValidationLoss";
val = compTable{compTable.Model=="GRU",metric};
if ~isempty(val)
  plot(horizonVec,10*log10(val),"-*")
  grid on
  xlabel("Horizon (ms)")
  ylabel("Validation Loss")
  title("GRU Channel Predictor")
end
end
