
Signal Source Separation Using W-Net Architecture

This example shows how to separate two mixed signal sources using a deep learning network. Source separation is a common and complex signal processing problem that finds use in audio, vibration analysis, and biomedical applications. It consists of separating the signal components of a signal mixture when only the mixture is available.

One important source separation problem is distinguishing the fetal and maternal electrocardiogram (ECG) signals present in noninvasive measurements taken on the abdomen of a pregnant patient. Solving this problem reliably allows physicians to monitor the fetal ECG noninvasively, with minimal risk. Fetal cardiac monitoring and assessment during pregnancy are used for the early detection of fetal cardiac conditions.

This example uses simulated noninvasive abdominal ECG measurements on pregnant patients to illustrate how to solve the difficult problem of separating the fetal ECG and maternal ECG signals using a deep network. The source separation deep learning architecture used in this example is not limited to ECG signals and can be used in many other applications.

FECGSYN Data Set

This example uses the FECGSYN PhysioNet data set [1], [2], which contains simulated adult and noninvasive fetal ECG signals. The data was generated using the FECGSYN simulator [3], which models the maternal and fetal hearts as point dipoles with different magnitudes and spatial positions. The simulator obtains fetal–maternal mixtures by treating each abdominal signal and noise component as an individual source whose signal is propagated onto the observation points (electrodes). Because the database provides a separate waveform for each signal source, it is well suited for testing a source separation deep learning model.

The FECGSYN data set consists of simulated ECG signals corresponding to ten different subjects. For each subject, simulations produced a fetal ECG (fECG), a maternal ECG (mECG), and two noise sources, all sampled at a rate of 250 Hz for five minutes. The original data set repeats simulations five times for five different SNR levels, for 34 ECG channels or "electrodes", and for five different measurement scenarios or cases. This example uses a subset of the data set that includes all ten subjects, a single channel (channel 19 of the original data set), four SNR levels (3, 6, 9, and 12 dB), and three measurement cases labeled C0, C1, and C3. Because the simulation was repeated over five iterations for each combination of subject, SNR value, and measurement case, the subset contains a total of 10 subjects × 4 SNRs × 3 cases × 5 iterations = 600 files. The three measurement cases are:

  • Case 0 (C0) — Baseline ECG signals

  • Case 1 (C1) — Fetal movement + C0

  • Case 3 (C3) — Signals with varying maternal and fetal heart rates + Noise from uterine contractions
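The file count above can be checked by enumerating the combinations for a single subject. This is a minimal sketch, not part of the original example: it assumes the per-SNR folder names (such as snr03dB) that the code later in this example searches for, and it builds the Ij_Ck.mat filenames described below. Forward slashes are used only so that the generated names are the same on every platform.

```matlab
% Enumerate the measurement combinations used for one subject.
% Assumption: files live in per-SNR folders named like "snr03dB".
snrLevels   = [3 6 9 12];   % SNR levels in dB
caseIDs     = [0 1 3];      % measurement cases C0, C1, C3
iterations  = 1:5;          % simulation iterations I1 to I5
numSubjects = 10;

names = {};
for s = snrLevels
    for k = caseIDs
        for j = iterations
            % Relative path such as snr03dB/I1_C0.mat
            names{end+1} = sprintf('snr%02ddB/I%d_C%d.mat',s,j,k); %#ok<AGROW>
        end
    end
end

filesPerSubject = numel(names);              % 4 x 3 x 5
totalFiles = numSubjects*filesPerSubject;    % 10 x 4 x 3 x 5
fprintf('%d files per subject, %d files in total\n',filesPerSubject,totalFiles)
```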

The data set contains one MAT-file for each combination of subject, SNR level, iteration, and measurement case. The filenames use the format Ij_Ck.mat, where j is the iteration number (1 to 5) and k is the measurement case identifier (0, 1, 3). Each MAT-file contains these variables:

  • mECG — Maternal ECG signal

  • fECG — Fetal ECG signal

  • mECG_QRS — QRS peak locations for the maternal ECG signal as annotated by an expert system

  • fECG_QRS — QRS peak locations for the fetal ECG signal as annotated by an expert system

  • noise1 — First noise source

  • noise2 — Second noise source

All signals have been bandpass filtered into the frequency range from 5 Hz to 90 Hz.

The abdominal ECG signal (aECG) for each file is computed as the following mixture:

aECG=mECG+fECG+noise1+noise2

The mECG_QRS and fECG_QRS variables contain the QRS peak locations of the maternal and fetal ECG signals. You can use them to check whether a source separation algorithm identifies the correct heartbeat locations in time.
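One common way to use annotated QRS locations for validation is to match detected peaks to reference peaks within a tolerance window and then compute sensitivity, positive predictive value (PPV), and F1 score. The sketch below is a hypothetical illustration, not code from this example: the peak values are made up, the 50 ms tolerance is an assumption, and the greedy nearest-peak matching is a simplification of what dedicated evaluation tools do.

```matlab
% Hypothetical reference and detected QRS locations (sample indices, fs = 250 Hz)
fs = 250;
refQRS = [100 205 310 415 520];   % annotated peaks (made-up values)
detQRS = [102 300 418 523 600];   % detected peaks (made-up values)
tol = round(0.05*fs);             % assumed 50 ms matching tolerance, in samples

matchedRef = false(size(refQRS));
TP = 0;
for d = detQRS
    % Distance from this detection to every reference peak not yet matched
    dist = abs(refQRS - d);
    dist(matchedRef) = Inf;
    [dmin,iMin] = min(dist);
    if dmin <= tol
        matchedRef(iMin) = true;  % each reference peak is matched at most once
        TP = TP + 1;
    end
end
FP = numel(detQRS) - TP;          % detections with no matching reference peak
FN = numel(refQRS) - TP;          % reference peaks that were missed

sensitivity = TP/(TP + FN);
ppv = TP/(TP + FP);
F1 = 2*TP/(2*TP + FP + FN);
fprintf('Se = %.2f, PPV = %.2f, F1 = %.2f\n',sensitivity,ppv,F1)
```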

This example uses the data from the first nine subjects to train a deep network and the data from the tenth subject to test the network performance. The training data size is about 1.15 GB, and the training of the deep learning network takes a few hours even when run on a GPU. If you want to skip downloading the training data and the training process, set the trainNetworkFlag flag to false. If the flag is set to false, the example downloads a pretrained network that can be used to perform source separation on the test data. The example always downloads the test data corresponding to subject 10.

trainNetworkFlag = true;

Download the training and test data sets using the downloadSupportFile function. The data is unzipped to the tempdir directory. If you want the data at a different location, change trainingDatasetFolder and testDatasetFolder to the desired locations.

if trainNetworkFlag
    % Download training data set
    trainingDatasetZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/fetal-ecg-source-separation-trainingData.zip');
    trainingDatasetFolder = fullfile(tempdir,'fetal-ecg-source-separation-trainingData');
    if ~exist(trainingDatasetFolder,'dir')
        unzip(trainingDatasetZipFile,trainingDatasetFolder);
    end    
end

% Download test data set
testDatasetZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/fetal-ecg-source-separation-testData.zip');
testDatasetFolder = fullfile(tempdir,'fetal-ecg-source-separation-testData');
if ~exist(testDatasetFolder,'dir')
    unzip(testDatasetZipFile,testDatasetFolder);
end

Create a signal datastore to access the files in the test data set. Specify the names of the variables that you want the datastore to read from each file.

testDS = signalDatastore(testDatasetFolder,IncludeSubfolders=true, ...
    SignalVariableNames=["mECG" "fECG" "noise1" "noise2" "mECG_QRS" "fECG_QRS"]);

Plot the first 2048 samples of the ECG signals for iteration 1 of case C1 at an SNR of 3 dB. Overlay the annotated QRS peaks on each signal.

idx = contains(testDS.Files,fullfile("snr03dB","I1_C1.mat"));
sds3dBC1 = subset(testDS,idx);

data = preview(sds3dBC1);
[mECG,fECG,noise1,noise2,mECG_QRS,fECG_QRS] = data{:};

% Abdominal ECG mixture
aECG = mECG(1:2048) + fECG(1:2048) + noise1(1:2048) + noise2(1:2048);

figure
subplot(3,1,1)
plot(aECG)
xline(fECG_QRS(1:50),":",Color="#77AC30")
xline(mECG_QRS(1:50),":",Color="#D95319")
axis([0 2048 -0.6 1])
title("aECG (red = mECG QRS peaks, green = fECG QRS peaks)")

subplot(3,1,2)
plot(fECG(1:2048))
xline(fECG_QRS(1:50),":",Color="#77AC30")
axis([0 2048 -0.6 1])
title("fECG")

subplot(3,1,3)
plot(mECG(1:2048))
xline(mECG_QRS(1:50),":",Color="#D95319")
axis([0 2048 -0.6 1])
title("mECG")

[Figure: three panels titled "aECG (red = mECG QRS peaks, green = fECG QRS peaks)", "fECG", and "mECG", with the annotated QRS peaks overlaid as dotted lines.]

Notice the large difference in scale between the mECG and fECG signals.

Prepare Training Data

This example uses the data from the first nine subjects to train the network and the data from the tenth subject to test it. To train the network, each signal is broken into segments of 1024 samples, for a total of 73 segments per signal. Set up a signal datastore that reads the ECG signals and noise realizations for subjects 1 to 9, and transform it to obtain aECG, mECG, and fECG signal segments of 1024 samples each. Each read of trainDS returns a 73-by-3 cell array whose columns contain the aECG, mECG, and fECG segments, each formatted as a 1-by-1-by-1024 array with CBT (channel-batch-time) dimensions. In addition to segmenting the signals, the transform function, getECGSegments, normalizes each segment with the rescale function so that its values lie between –1 and 1, and then centers the rescaled segment by subtracting its median value.
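The segmentation and normalization steps can be illustrated without toolbox functions. This is a minimal sketch on a synthetic signal whose length is an exact multiple of the segment length (an assumption; getECGSegments uses buffer instead of reshape). The rescaling line computes the same linear map onto [–1, 1] that rescale(s,-1,1) performs.

```matlab
% Synthetic stand-in for one signal: 3 segments of 1024 samples
segmentLength = 1024;
numSegments = 3;
n = (1:numSegments*segmentLength)';
x = sin(2*pi*(n-1)/250) .* n;

% reshape fills columns first, so reshape to segmentLength-by-numSegments
% and transpose to get one segment per row
segs = reshape(x,segmentLength,numSegments)';

for k = 1:numSegments
    s = segs(k,:);
    % Map the segment linearly onto [-1, 1] (same map as rescale(s,-1,1))
    s = 2*(s - min(s))/(max(s) - min(s)) - 1;
    % Center the rescaled segment by subtracting its median
    segs(k,:) = s - median(s);
end
```

After this loop, each row spans a range of exactly 2 and has a median of zero, which matches what the transform function does to every aECG, mECG, and fECG segment.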

segmentLength = 1024;
if trainNetworkFlag    
    trainDS = signalDatastore(trainingDatasetFolder,IncludeSubfolders=true,SignalVariableNames=["mECG" "fECG" "noise1" "noise2"]);      
    trainDS = transform(trainDS,@(d,f)getECGSegments(d,segmentLength));
end

To speed up training, read all the training data into memory so that the signal segmentation and normalization happen only once. If you have a Parallel Computing Toolbox™ license, use the UseParallel parameter so that the read operations are done in parallel. Create an array datastore to iterate through the training signal segments.

if trainNetworkFlag
    trainData = readall(trainDS,UseParallel=true);
    trainDS = arrayDatastore(trainData,OutputType="same");
end

W-Net Architecture for Source Separation

This example uses a so-called W-Net architecture to perform source separation [4]. W-Net consists of two U-Net autoencoders [5] that have been modified to operate on 1-D signal inputs. A U-Net autoencoder is a deep network whose encoder extracts features from the input signal, reducing the signal length at each step, and whose decoder maps those features back to a signal of the original length. You can think of the encoder branch of the autoencoder as a feature extraction branch. The main idea of the W-Net architecture is that, given an aECG mixture as input, one autoencoder reproduces the fECG signal (fECG autoencoder) and the other reproduces the mECG signal (mECG autoencoder). The two autoencoders are connected in their encoding branches: the features obtained by the mECG autoencoder are subtracted from the features obtained by the fECG autoencoder, effectively separating the mECG component from the aECG input and yielding the desired fECG signal. This figure shows the architecture in detail.
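The feature subtraction between the two encoders can be illustrated with plain arrays. This is a conceptual sketch only: the feature maps are made-up channels-by-time matrices of equal size, the pooling mirrors the pool size 2, stride 2 average pooling layers used in createUNet, and it assumes that each subtraction layer receives the pooled fECG-branch features as its first input and the pooled mECG-branch features as its second input (createWNet connects the mECG branch to the second input).

```matlab
% Hypothetical feature maps (channels-by-time) from the same encoder row of
% the fECG (left) and mECG (right) U-Nets; the values are made up.
featFetal    = [1 3 5 7 9 11 13 15; 2 2 2 2 4 4 4 4];
featMaternal = [1 1 1 1 1 1 1 1;    0 0 2 2 0 0 2 2];

% 1-D average pooling with pool size 2 and stride 2 along time:
% averages each pair of adjacent samples and halves the sequence length
avgPool2 = @(F) (F(:,1:2:end) + F(:,2:2:end))/2;

pooledFetal    = avgPool2(featFetal);      % 2-by-4
pooledMaternal = avgPool2(featMaternal);   % 2-by-4

% Subtraction connection: the next fECG encoder row receives the
% fECG-branch features minus the mECG-branch features
nextRowInput = pooledFetal - pooledMaternal;
disp(nextRowInput)
```

Repeating this pooling four times, as the four subtraction connections in createWNet suggest, would reduce a 1024-sample input to 64 time steps in the deepest row.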

Following reference [4], for the ECG source separation problem at hand, set the filter size of the 1-D convolutional layers to 4 for the fECG side and 35 for the mECG side. Set the number of filters used in the input 1-D convolutional layers, N in the figure above, to 16. The input size, P in the figure above, is the segment length of 1024 samples. Create the W-Net network architecture using the createWNet function.

if trainNetworkFlag    
    filterSize_fECG = 4;
    filterSize_mECG = 35;
    numFilters_fECG = 16;
    numFilters_mECG = 16;
    lgraph = createWNet(segmentLength,filterSize_fECG,numFilters_fECG,filterSize_mECG,numFilters_mECG);
    wNet = dlnetwork(lgraph);
end

Training Loop

Training the W-Net model requires a custom training loop because the loss combines the losses of the fECG and mECG branches of the network. The modelLoss function computes the training loss as the weighted sum of the mean absolute deviations between the actual and predicted ECG signals:

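$$\text{loss} = \text{fECGWeight} \times \operatorname{mean}\left(\left|\text{fECG}_{\text{actual}} - \text{fECG}_{\text{predicted}}\right|\right) + \text{mECGWeight} \times \operatorname{mean}\left(\left|\text{mECG}_{\text{actual}} - \text{mECG}_{\text{predicted}}\right|\right)$$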

Set fECGWeight to a value greater than mECGWeight to reflect the fact that the primary signals of interest are the fetal ECGs.
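To make the weighting concrete, the sketch below evaluates the loss formula on tiny made-up arrays, using the same weights this example uses (fECGWeight = 0.75, mECGWeight = 0.25). The signal values are hypothetical; modelLoss applies the same computation to dlarray mini-batches with mean(...,"all").

```matlab
fECGWeight = 0.75;
mECGWeight = 0.25;

% Made-up targets and predictions (one segment of 4 samples each)
fECGactual    = [0.0  0.5 -0.5  1.0];
fECGpredicted = [0.1  0.4 -0.5  0.8];
mECGactual    = [1.0 -1.0  0.0  0.0];
mECGpredicted = [0.6 -1.0  0.2  0.0];

% Weighted sum of the mean absolute deviations of the two branches
lossF = mean(abs(fECGactual - fECGpredicted));
lossM = mean(abs(mECGactual - mECGpredicted));
loss = fECGWeight*lossF + mECGWeight*lossM;
fprintf('fECG MAD = %.3f, mECG MAD = %.3f, loss = %.4f\n',lossF,lossM,loss)
```

Here the mECG error (0.15) is larger than the fECG error (0.1), yet the fECG term contributes twice as much to the loss (0.075 versus 0.0375) because of the weights.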

Use an Adam optimizer to update the network learnable parameters, and specify the learning rate, the gradient decay factor, the number of epochs, and the mini-batch size. The minibatchqueue object returns mini-batches of miniBatchSize aECG, mECG, and fECG signal segments.
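In the training loop, the decay variable is passed to adamupdate in the gradient decay factor position, so it controls the moving average of the gradient rather than the learning rate. The sketch below writes out one step of the standard Adam update on a made-up parameter vector to show where that factor enters. It assumes the adamupdate defaults for the squared-gradient decay factor (0.999) and epsilon (1e-8); the internal details of adamupdate may differ slightly.

```matlab
% One step of the standard Adam update (conceptual sketch)
learnRate   = 0.0005;
gradDecay   = 0.25;    % the value this example passes to adamupdate as decay
sqGradDecay = 0.999;   % assumed adamupdate default
epsilon     = 1e-8;    % assumed adamupdate default

theta = [1; -2; 0.5];      % made-up parameters
g     = [0.2; -0.1; 0.0];  % made-up gradient
m = zeros(size(theta));    % moving average of the gradient
v = zeros(size(theta));    % moving average of the squared gradient
t = 1;                     % iteration number (starts at 1)

m = gradDecay*m + (1 - gradDecay)*g;
v = sqGradDecay*v + (1 - sqGradDecay)*g.^2;
mHat = m/(1 - gradDecay^t);       % bias-corrected first moment
vHat = v/(1 - sqGradDecay^t);     % bias-corrected second moment
theta = theta - learnRate*mHat./(sqrt(vHat) + epsilon);
disp(theta)
```

On the first iteration, the bias correction makes each nonzero parameter move by almost exactly learnRate in the direction opposite to its gradient.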

Due to the large size of the data set, the training process may take several hours. If your machine has a GPU and Parallel Computing Toolbox™, set the useGPUflag flag to true to speed up the training process.

useGPUflag = true;

if trainNetworkFlag    
    NumEpochs = 100;
    miniBatchSize = 512;
    learnRate = 0.0005;
    decay = 0.25;
    mECGWeight = 0.25;
    fECGWeight = 0.75;

    mbqTrain = minibatchqueue(trainDS, 3, ...
        MiniBatchSize=miniBatchSize,...
        MiniBatchFormat={'CBT','CBT','CBT'}, ...
        MiniBatchFcn=@processMB, ...
        DispatchInBackground=true);

    if useGPUflag
        mbqTrain.OutputEnvironment = "gpu";
    end

    % Initialize some training loop variables
    trailingAvg = [];
    trailingAvgSq = [];
    iteration = 0;
    lossByIteration = 0;
    minLoss = Inf;

    % Loop over epochs, reshuffling the mini-batch queue at each epoch. At
    % the end of each epoch, keep the network if its last mini-batch loss
    % is the lowest seen so far
    for epoch = 1:NumEpochs
        reset(mbqTrain)
        shuffle(mbqTrain)

        % Loop over mini-batches
        while hasdata(mbqTrain)
            iteration = iteration + 1;

            % Get the next mini-batch
            [aECGbatch,mECGbatch,fECGbatch] = next(mbqTrain);

            % Evaluate the model gradients and loss
            [loss,gradients,state] = dlfeval(@modelLoss,wNet,aECGbatch, ...
                mECGbatch,fECGbatch,mECGWeight,fECGWeight);
            lossByIteration(iteration) = loss;

            % Update the network state
            wNet.State = state;

            % Update the network parameters using an Adam optimizer
            [wNet,trailingAvg,trailingAvgSq] = adamupdate(wNet,gradients, ...
                trailingAvg,trailingAvgSq,iteration,learnRate,decay); 
        end
        if loss < minLoss
            minLoss = loss;
            bestModel = wNet;
            % Uncomment the line below to save the best model so far
            %save Model.mat wNet
        end
    end
    wNet = bestModel;

    % Plot the loss by iteration
    figure
    plot(1:iteration,mag2db(lossByIteration))
    grid on
    title("Training Loss by Iteration")
    xlabel("Iteration")
    ylabel("Loss (dB)")
    axis tight
end

[Figure: "Training Loss by Iteration", loss in dB versus iteration.]

Load a pretrained model if trainNetworkFlag is false. The model file will be unzipped to the tempdir directory. If you want the model at a different location, change modelFolder to the desired value.

if ~trainNetworkFlag
    % Download the pre-trained network
    modelZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/fetal-ecg-source-separation-model.zip');
    modelFolder = fullfile(tempdir,'fetal-ecg-source-separation-model');
    if ~exist(modelFolder,'dir')
        unzip(modelZipFile,modelFolder);
    end
    modelFile = fullfile(modelFolder,'fetal-ecg-source-separation-model','Model.mat');
    load(modelFile)
end

Test Model

To test the trained network, use the previously created test datastore, testDS, that points to data from subject 10. This datastore reads the ECG data and the QRS peak location annotations so they can be used to validate the predicted mECG and fECG signals. As was done for the training datastore, transform the test datastore to get segmented and normalized aECG, mECG, and fECG signals.

testDS = transform(testDS,@(d,f)getECGSegments(d,segmentLength));

Call the predict method of the trained network to get separated mECG and fECG signals from an aECG input. For example, estimate the fetal and maternal waveforms for iteration 3 of case C1 with an SNR of 9 dB.

idx = contains(string(testDS.UnderlyingDatastores{1}.Files),fullfile("snr09dB","I3_C1.mat"));
ds = subset(testDS,idx);
data = read(ds);
mECG_QRS = data(:,4);
fECG_QRS = data(:,5);
[aECGbatch,mECGbatch,fECGbatch] = processMB(data(:,1),data(:,2),data(:,3));

% Move the aECGbatch into a dlarray and call the predict method of the
% trained network to estimate the source signals
dlaECG = dlarray(aECGbatch,"CBT");
[dlpred_fECG,dlpred_mECG] = predict(wNet,dlaECG);
pred_fECG = squeeze(extractdata(dlpred_fECG))';
pred_mECG = squeeze(extractdata(dlpred_mECG))';
pred_fECG = pred_fECG(:);
pred_mECG = pred_mECG(:);
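The squeeze, transpose, and (:) sequence above turns the CBT network output back into one continuous signal. The sketch below shows why that order is correct, using a small made-up array (3 segments of 4 samples) in place of the 73 segments of 1024 samples.

```matlab
% Made-up network output in CBT layout: C = 1 channel, B = 3 segments,
% T = 4 samples per segment. Segment b holds samples (b-1)*T+1 to b*T.
T = 4;
B = 3;
pred = zeros(1,B,T);
for b = 1:B
    pred(1,b,:) = reshape((b-1)*T + (1:T),1,1,T);
end

% squeeze drops the singleton channel dimension: B-by-T, one row per segment
segRows = squeeze(pred);
% Transpose to T-by-B so that each column is one segment, then stack the
% columns with (:) to get a single column vector in time order
signal = segRows';
signal = signal(:);
disp(signal')
```

Skipping the transpose would interleave samples from different segments, because (:) reads a matrix column by column.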

Plot a few samples of the predicted waveforms. Overlay the annotated true QRS peaks using dotted lines.

figure
subplot(2,1,1)
plot(pred_fECG(1:2048))
xline([fECG_QRS{1}; fECG_QRS{2}],":k")
title("Predicted fECG")
axis([1 2048 -1.5 1])

subplot(2,1,2)
plot(pred_mECG(1:2048))
xline([mECG_QRS{1}; mECG_QRS{2}],":k")
title("Predicted mECG")
axis([1 2048 -1 2])

[Figure: two panels titled "Predicted fECG" and "Predicted mECG", with the annotated QRS peaks overlaid as dotted lines.]

Use the plotPredictedECGs function to plot the predicted ECG signals for a high SNR (12 dB) and a low SNR (3 dB), for iteration 4, and for all three measurement cases. Set N and M to plot segments N through N+M of the selected measurement.

% Plot segments 4 and 5 (N through N+M) for each case
N = 4;
M = 1;
plotPredictedECGs(wNet,testDS,"12","C0","I4",N,M)

[Figure: five panels titled "aECG mixture, SNR = 12 dB, Case = C0, Iteration I4", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

plotPredictedECGs(wNet,testDS,"12","C1","I4",N,M)

[Figure: five panels titled "aECG mixture, SNR = 12 dB, Case = C1, Iteration I4", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

plotPredictedECGs(wNet,testDS,"12","C3","I4",N,M)

[Figure: five panels titled "aECG mixture, SNR = 12 dB, Case = C3, Iteration I4", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

plotPredictedECGs(wNet,testDS,"03","C0","I4",N,M)

[Figure: five panels titled "aECG mixture, SNR = 03 dB, Case = C0, Iteration I4", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

plotPredictedECGs(wNet,testDS,"03","C1","I4",N,M)

[Figure: five panels titled "aECG mixture, SNR = 03 dB, Case = C1, Iteration I4", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

plotPredictedECGs(wNet,testDS,"03","C3","I4",N,M)

[Figure: five panels titled "aECG mixture, SNR = 03 dB, Case = C3, Iteration I4", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

Because the fetal heart rate is faster than the maternal heart rate, the fECG panels show only the first half of the plotted samples to make the waveforms easier to see. The dotted lines on the plots correspond to annotated ground-truth QRS peak locations. Locating the QRS peaks correctly is as important as estimating the overall signal shape, because QRS peak locations allow estimation of the heart rate and detection of conditions such as arrhythmia. Peak location accuracy should therefore be considered when evaluating the performance of the source separation procedure.
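As an example of how QRS peak locations yield a heart rate, the sketch below converts made-up peak indices into RR intervals and beats per minute, using the 250 Hz sampling rate of the FECGSYN signals. The peak values are hypothetical, chosen so that consecutive beats are 100 samples (0.4 s) apart.

```matlab
fs = 250;                          % sampling rate of the FECGSYN signals (Hz)
qrs = [120 220 320 420 520 620];   % made-up QRS peak locations (samples)

rr = diff(qrs)/fs;                 % RR intervals in seconds
instHR = 60./rr;                   % beat-to-beat heart rate (beats per minute)
meanHR = 60/mean(rr);              % average heart rate over the window
fprintf('Mean heart rate: %.1f bpm\n',meanHR)
```

A single misplaced or missed peak changes two RR intervals at once, which is one reason peak location accuracy matters as much as waveform shape.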

Recall that the main purpose of this network is to extract fetal ECG signals, which are the most difficult to obtain from the mixture. In the W-Net architecture, the primary target is the one estimated by the left U-Net branch, which corresponds to the first network output, the fECG prediction that the modelLoss function weights with fECGWeight. Overall, the network does a good job of estimating QRS peak locations and waveform shapes for both high- and low-SNR cases and for different measurement conditions.

There are extreme measurement cases where the combination of noise, fetal movement, and heart rate variation is too severe for the network. For example, plot the ECG estimates for an SNR of 6 dB, measurement case C3, and iteration 5. In this case, the network fails to predict an acceptable fECG waveform.

plotPredictedECGs(wNet,testDS,"06","C3","I5",N,M)

[Figure: five panels titled "aECG mixture, SNR = 06 dB, Case = C3, Iteration I5", "fECG target", "fECG predicted", "mECG target", and "mECG predicted".]

Plot the mean absolute deviation of the estimated fECG and mECG signals for all measurements of subject 10 using the computeErrorsForAllCases function.

computeErrorsForAllCases(wNet,testDS)

[Figure: two panels titled "fECG mean absolute errors" and "mECG mean absolute errors", showing the error versus SNR for cases C0, C1, and C3.]

The errors do not decrease monotonically with SNR because of the variability of all the different combinations of noise, fetal movement, and heart-rate irregularities.

Conclusion

This example implements a W-Net architecture suitable for source separation of a mixture of two signals. The example analyzes the performance of the network using synthetic signal mixtures composed of fetal and maternal ECG waveforms. The example shows that, in most scenarios, the network does a good job of separating the ECG signals and estimating correct waveform shapes and QRS peak locations.

References

[1] Goldberger, Ary L., Luis A. N. Amaral, Leon Glass, Jeffrey M. Hausdorff, Plamen Ch. Ivanov, Roger G. Mark, Joseph E. Mietus, George B. Moody, Chung-Kang Peng, and H. Eugene Stanley. “PhysioBank, PhysioToolkit, and PhysioNet.” Circulation 101, no. 23 (June 13, 2000): e215–20. https://doi.org/10.1161/01.CIR.101.23.e215.

[2] F. Andreotti, J. Behar, and G. D. Clifford. "Fetal ECG Synthetic Database." PhysioNet, Version 1.0.0, April 29, 2016.

[3] F. Andreotti, J. Behar, S. Zaunseder, J. Oster, and G. D. Clifford. "An Open-Source Framework for Stress-Testing Non-Invasive Foetal ECG Extraction Algorithms." Physiological Measurement, Volume 37, Number 5, 2016.

[4] K. J. Lee and B. Lee, "End-to-End Deep Learning Architecture for Separating Maternal and Fetal ECGs Using W-Net," IEEE Access, Volume 10, pp. 39782-39788, 2022.

[5] O. Ronneberger, P. Fischer, and T. Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation", MICCAI, May 18, 2015.

Appendix: Helper Functions

The functions listed in this section are only for use in this example. They may change or be removed in a future release.

getECGSegments

This function creates aECG mixtures from the mECG, fECG, and noise signals and breaks the ECG signals into segments of length segmentLength. Each segment is normalized and reshaped to CBT format with C and B equal to 1. When the input to the function contains QRS peak locations, the function groups the locations by segment using the start and end index of each segment. The grouped locations remain sample indices of the full-length signal.

function outputCell = getECGSegments(cellInput,segmentLength)

mECG = cellInput{1};
fECG = cellInput{2};
noise1 = cellInput{3};
noise2 = cellInput{4};
aECG = mECG + fECG + noise1 + noise2;

% Segment the data and keep indices so that we can also segment the QRS
% peak locations
[idxs,~] = buffer(1:size(mECG,1),segmentLength);
mECG = single(mECG(idxs)');
fECG = single(fECG(idxs)');
aECG = single(aECG(idxs)');

% Normalize
for idx = 1:size(mECG,1)
    mECG(idx,:) = rescale(mECG(idx,:),-1,1);
    mECG(idx,:) = mECG(idx,:) - median(mECG(idx,:));

    fECG(idx,:) = rescale(fECG(idx,:),-1,1);
    fECG(idx,:) = fECG(idx,:) - median(fECG(idx,:));
    
    aECG(idx,:) = rescale(aECG(idx,:),-1,1);
    aECG(idx,:) = aECG(idx,:) - median(aECG(idx,:));
end

numRows = size(mECG,1);

% CBT format C=1 B=numRows T=segmentLength
mECG = reshape(mECG,1,numRows,[]);
fECG = reshape(fECG,1,numRows,[]);
aECG = reshape(aECG,1,numRows,[]);

% Create cell array with individual elements --> CBT format C=1 B=1 T=segmentLength
mECGCell = mat2cell(mECG,1,ones(numRows,1),segmentLength)';
fECGCell = mat2cell(fECG,1,ones(numRows,1),segmentLength)';
aECGCell = mat2cell(aECG,1,ones(numRows,1),segmentLength)';

outputCell = [aECGCell mECGCell fECGCell];

if numel(cellInput) == 6
    mECG_QRSTmp = cellInput{5};
    fECG_QRSTmp = cellInput{6};

    segmentLimits = [idxs(1,:)' idxs(end,:)'];
    numSegments = size(segmentLimits,1);

    mECG_QRS = cell(numSegments,1);
    fECG_QRS = cell(numSegments,1);
    for idx = 1:numSegments
        mECG_QRS{idx} = mECG_QRSTmp(mECG_QRSTmp >= segmentLimits(idx,1) & ...
            mECG_QRSTmp <= segmentLimits(idx,2));
        fECG_QRS{idx} = fECG_QRSTmp(fECG_QRSTmp >= segmentLimits(idx,1) & ...
            fECG_QRSTmp <= segmentLimits(idx,2));
    end
    outputCell = [outputCell mECG_QRS fECG_QRS];
end
end
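The QRS grouping in getECGSegments keeps absolute sample indices, and plotPredictedECGs later subtracts (N-1)*1024 to convert them to positions within the plotted window. The sketch below reproduces both steps on made-up peak locations using integer arithmetic instead of the buffer limits; it is an illustration under that assumption, not the helper itself.

```matlab
segmentLength = 1024;
numSegments = 3;                      % a signal of 3*1024 samples
qrsAbs = [10 1000 1030 2050 3000];    % made-up QRS locations in the full signal

% Segment k covers samples (k-1)*segmentLength+1 through k*segmentLength
segIdx = floor((qrsAbs - 1)/segmentLength) + 1;

% Group the absolute locations by segment, as getECGSegments does
qrsPerSegment = cell(numSegments,1);
for k = 1:numSegments
    qrsPerSegment{k} = qrsAbs(segIdx == k);
end

% Convert to positions within each segment, the same offset that
% plotPredictedECGs applies when segment N is the first plotted segment
localQRS = qrsAbs - (segIdx - 1)*segmentLength;
```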

processMB

This function converts cell array inputs, containing ECG segments, to mini-batches with CBT format.

function [aECGbatch,mECGbatch,fECGbatch] = processMB(aECGCell,mECGCell,fECGCell)

aECGbatch = cat(2,aECGCell{:});
mECGbatch = cat(2,mECGCell{:});
fECGbatch = cat(2,fECGCell{:});
end
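The effect of processMB can be seen on a few small made-up segments: each cell holds a 1-by-1-by-T array, and concatenating the cells along dimension 2 stacks them along the batch (B) dimension of the CBT layout. The segment length of 4 is only for illustration.

```matlab
T = 4;
% Three made-up segments, each 1-by-1-by-T (C = 1, B = 1, T = 4)
segCell = {reshape(1:4,1,1,T); reshape(5:8,1,1,T); reshape(9:12,1,1,T)};

% Concatenating along dimension 2 stacks the segments along the batch dimension
batch = cat(2,segCell{:});
disp(size(batch))
```

The resulting array has size 1-by-3-by-4, and batch(1,2,:) holds the second segment, which is the layout that dlarray(...,"CBT") expects.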

modelLoss

This function feeds an aECG input to the network and computes the loss and its gradients with respect to the network learnable parameters.

function [loss,grads,state] = modelLoss(net,aECG,mECG,fECG,mECGWeight,fECGWeight)

[fECGpred,mECGpred,state] = net.forward(aECG);

loss = stripdims(fECGWeight*mean(abs(fECG-fECGpred),"all") + ...
    mECGWeight*mean(abs(mECG-mECGpred),"all"));

grads = dlgradient(loss,net.Learnables);

loss = double(gather(extractdata(loss)));
end

plotPredictedECGs

This function plots actual and predicted ECG signals for a specified measurement case, iteration, and SNR value. The function plots segments N to N+M.

function plotPredictedECGs(wNet,testDS,SNRstr,caseStr,iterStr,N,M)

% testDS is a datastore pointing to the test data
% SNRstr can be "12", "09", "06", "03"
% iterStr can be "I1", "I2", "I3", "I4", "I5"
% caseStr can be "C0", "C1", "C3"

dataIdx = N:N+M;
% Get a datastore with the requested case
idx = contains(string(testDS.UnderlyingDatastores{1}.Files), ...
    fullfile("snr"+SNRstr+"dB",iterStr+"_"+caseStr+".mat"));
ds = subset(testDS,idx);
data = read(ds);
data = data(dataIdx,:);
mECG_QRS = data(:,4);
fECG_QRS = data(:,5);
[aECGbatch,mECGbatch,fECGbatch] = processMB(data(:,1),data(:,2),data(:,3));

% Move the aECGbatch into a dlarray and call the predict method of the
% trained network to estimate the source signals
dlaECG = dlarray(aECGbatch,"CBT");
[dlpred_fECG,dlpred_mECG] = predict(wNet,dlaECG);
pred_fECG = extractdata(dlpred_fECG);
pred_mECG = extractdata(dlpred_mECG);

% Plot the results
aECG = squeeze(aECGbatch)';
mECG = squeeze(mECGbatch)';
fECG = squeeze(fECGbatch)';
pred_fECG = squeeze(pred_fECG)';
pred_mECG = squeeze(pred_mECG)';

aECG = aECG(:);
mECG = mECG(:);
fECG = fECG(:);
pred_mECG = pred_mECG(:);
pred_fECG = pred_fECG(:);

mECG_QRS = cat(1,mECG_QRS{:});
fECG_QRS = cat(1,fECG_QRS{:});

mECG_QRS = mECG_QRS - ((N-1)*1024-1) - 1;
fECG_QRS = fECG_QRS - ((N-1)*1024-1) - 1;

titleStr = "SNR = "+SNRstr+" dB, Case = "+caseStr+", Iteration "+iterStr;
figure
subplot(3,2,[1 2])
plot(aECG)
title("aECG mixture, "+titleStr)
minECG = min(aECG);
maxECG = max(aECG);
axis([1 length(aECG) minECG-abs(minECG*0.35) maxECG+maxECG*0.35])

minfECG = min(fECG);
maxfECG = max(fECG);
minPredfECG = gather(min(pred_fECG));
maxPredfECG = gather(max(pred_fECG));
minECG = min(minfECG,minPredfECG);
maxECG = max(maxfECG,maxPredfECG);

subplot(3,2,3)
plot(fECG)
xline(fECG_QRS,":k")
title("fECG target")
axis([1 floor(length(fECG)/2) minECG-abs(minECG*0.35) maxECG+maxECG*0.35])

subplot(3,2,4)
plot(pred_fECG)
xline(fECG_QRS,":k")
title("fECG predicted")
axis([1 floor(length(pred_fECG)/2) minECG-abs(minECG*0.35) maxECG+maxECG*0.35])

minmECG = min(mECG);
maxmECG = max(mECG);
minPredmECG = gather(min(pred_mECG));
maxPredmECG = gather(max(pred_mECG));
minECG = min(minmECG,minPredmECG);
maxECG = max(maxmECG,maxPredmECG);

subplot(3,2,5)
plot(mECG)
xline(mECG_QRS,":k")
title("mECG target")
axis([1 length(mECG) minECG-abs(minECG*0.35) maxECG+maxECG*0.35])

subplot(3,2,6)
plot(pred_mECG)
xline(mECG_QRS,":k")
title("mECG predicted")
axis([1 length(pred_mECG) minECG-abs(minECG*0.35) maxECG+maxECG*0.35])
end

computeErrorsForAllCases

This function computes the mean absolute error between actual fECG and mECG signals and predicted ones for all SNR values, measurement cases, and iterations of subject 10.

function computeErrorsForAllCases(wNet,testDS)

% testDS is a datastore pointing to the test data of subject 10

SNRVect = ["03" "06" "09" "12"];
caseVect = ["C0" "C1" "C3"];
files = string(testDS.UnderlyingDatastores{1}.Files);

% Preallocate the error matrices (rows = cases, columns = SNR values)
errMtxFecg = zeros(numel(caseVect),numel(SNRVect));
errMtxMecg = zeros(numel(caseVect),numel(SNRVect));

for SNRidx = 1:numel(SNRVect)
    SNRstr = SNRVect(SNRidx);
    for caseIdx = 1:numel(caseVect)
        caseStr = caseVect(caseIdx);

        % Select the files for all five iterations of this SNR and case
        idx = false(size(files));
        for iter = 1:5
            idx = idx | contains(files,fullfile("snr"+SNRstr+"dB","I"+iter+"_"+caseStr+".mat"));
        end

        ds = subset(testDS,idx);
        data = readall(ds);
        [aECGbatch,mECGbatch,fECGbatch] = processMB(data(:,1),data(:,2),data(:,3));

        dlaECG = dlarray(aECGbatch,"CBT");
        [dlpred_fECG,dlpred_mECG] = predict(wNet,dlaECG);

        pred_mECG = extractdata(dlpred_mECG);
        pred_fECG = extractdata(dlpred_fECG);

        errMtxFecg(caseIdx,SNRidx) = mean(abs(fECGbatch - pred_fECG),'all');
        errMtxMecg(caseIdx,SNRidx) = mean(abs(mECGbatch - pred_mECG),'all');
    end
end

figure
subplot(2,1,1)
plot([3 6 9 12],errMtxFecg');
title("fECG mean absolute errors")
xlabel("SNR (dB)")
ylabel("MAE")
legend("C0","C1","C3")
grid on
axis tight
subplot(2,1,2)
plot([3 6 9 12],errMtxMecg');
title("mECG mean absolute errors")
xlabel("SNR (dB)")
ylabel("MAE")
legend("C0","C1","C3")
grid on
axis tight
end

createWNet

This function implements a W-Net architecture and returns a layer graph.

function lgraph = createWNet(inputSize,filterSizeLeft,numFiltersLeft,filterSizeRight,numFiltersRight)

lgraph = layerGraph;

inputLayer = sequenceInputLayer(1,MinLength=inputSize,Name="inputMixture");
lgraph = addLayers(lgraph,inputLayer);

% Define left and right U-Net branches
% Layer name conventions - left means it belongs to left U-Net
%                        - ds means down sample, us means upsample branch,
%                          bridge is the final row in the autoencoder
%                        - i_j means ith row, jth layer

% Add left branch U-Net
lgraph = createUNet(lgraph,filterSizeLeft,numFiltersLeft,"left");
lgraph = connectLayers(lgraph,'inputMixture','conv1d_left_ds_1_1');

% Add right branch U-Net
lgraph = createUNet(lgraph,filterSizeRight,numFiltersRight,"right");
lgraph = connectLayers(lgraph,'inputMixture','conv1d_right_ds_1_1');

% Feed the pooled outputs of the right U-Net encoder into the second input
% of each subtraction layer in the left U-Net encoder (left minus right)
lgraph = connectLayers(lgraph,"avgpool1d_right_1_to_2","subtraction_2/in2");
lgraph = connectLayers(lgraph,"avgpool1d_right_2_to_3","subtraction_3/in2");
lgraph = connectLayers(lgraph,"avgpool1d_right_3_to_4","subtraction_4/in2");
lgraph = connectLayers(lgraph,"avgpool1d_right_4_to_5","subtraction_5/in2");

end

createUNet

This function adds one U-Net branch, left or right as selected by branchStr, to the layer graph. createWNet calls it twice to build the W-Net architecture.
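Each call to createUNet produces encoder rows whose convolutions use 1, 2, 4, 8, and 16 times numFilters filters, plus a bridge whose width is doubled for the right branch through numFiltScale. Because subtraction_2 through subtraction_5 subtract the right branch's pooled features from the left branch's element by element, both branches need the same channel count at those points, which in practice means numFiltersLeft should equal numFiltersRight. The following base-MATLAB sketch tabulates the widths; the filter counts are hypothetical and are not the values this example uses.

```matlab
% Channel widths of the two branches built by createUNet.
% The filter counts below are hypothetical values for illustration only.
numFiltersLeft  = 16;
numFiltersRight = 16;

rowScale = 2.^(0:4);                      % rows 1..5 use numFilters*[1 2 4 8 16]
leftChannels  = numFiltersLeft  * rowScale;
rightChannels = numFiltersRight * rowScale;

% Bridge (row 6): numFilters*16*numFiltScale, with numFiltScale = 1 for the
% left branch and 2 for the right branch
leftBridge  = numFiltersLeft  * 16 * 1;
rightBridge = numFiltersRight * 16 * 2;

% subtraction_k (k = 2..5) computes left minus right on the outputs of the
% pooling layers that leave row k-1, so rows 1..4 should have equal widths
subtractionOK = isequal(leftChannels(1:4), rightChannels(1:4));
fprintf("Left widths:  %s, bridge %d\n", mat2str(leftChannels), leftBridge);
fprintf("Right widths: %s, bridge %d\n", mat2str(rightChannels), rightBridge);
fprintf("Subtraction inputs compatible: %d\n", subtractionOK);
```

The filter sizes do not enter this check, because every convolution uses Padding="same", so filterSizeLeft and filterSizeRight can differ.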

function lgraph = createUNet(lgraph,filterSize,numFilters,branchStr)

% branchStr can be "left" or "right"
%
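% Layer name conventions - left/right: layer belongs to the left or right U-Net
%                        - ds: downsampling (encoder) path, us: upsampling
%                          (decoder) path, bridge: bottom row (row 6)
%                        - i_j: ith row, jth layer
%                        - a_to_b: layer that moves from row a to row b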

% The right branch doubles the number of filters in the bridge (row 6)
numFiltScale = 1 + double(branchStr == "right");
if branchStr == "left"
    branchStrOutput = "outputLayer_left_targetSignal";
else
    branchStrOutput = "outputLayer_right_secondarySignal";
end

unet = [
% Row 1 encoder branch
convolution1dLayer(filterSize, numFilters, Padding="same", Name="conv1d_"+branchStr+"_ds_1_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_1_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_1_1")

convolution1dLayer(filterSize, numFilters, Padding="same", Name="conv1d_"+branchStr+"_ds_1_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_1_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_1_2")

convolution1dLayer(filterSize, numFilters, Padding="same", Name="conv1d_"+branchStr+"_ds_1_3")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_1_3")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_1_3")

averagePooling1dLayer(2, Padding="same", Stride=2, Name="avgpool1d_"+branchStr+"_1_to_2")
];

% Row 2 encoder branch
if branchStr == "left"
    unet = [unet
        functionLayer(@minus,NumInputs=2,Formattable=true,Acceleratable=true,Name="subtraction_2");
        tanhLayer(Name="tanh_2")
        ];
end

unet = [unet
convolution1dLayer(filterSize, numFilters*2, Padding="same", Name="conv1d_"+branchStr+"_ds_2_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_2_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_2_1")

convolution1dLayer(filterSize, numFilters*2, Padding="same", Name="conv1d_"+branchStr+"_ds_2_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_2_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_2_2")

averagePooling1dLayer(2, Padding="same", Stride=2, Name="avgpool1d_"+branchStr+"_2_to_3")
];

% Row 3 encoder branch
if branchStr == "left"
    unet = [unet
        functionLayer(@minus,NumInputs=2,Formattable=true,Acceleratable=true,Name="subtraction_3");
        tanhLayer(Name="tanh_3")];
end

unet = [unet
convolution1dLayer(filterSize, numFilters*4, Padding="same", Name="conv1d_"+branchStr+"_ds_3_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_3_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_3_1")

convolution1dLayer(filterSize, numFilters*4, Padding="same", Name="conv1d_"+branchStr+"_ds_3_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_3_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_3_2")

averagePooling1dLayer(2, Padding="same", Stride=2, Name="avgpool1d_"+branchStr+"_3_to_4")];

% Row 4 encoder branch
if branchStr == "left"
    unet = [unet
        functionLayer(@minus,NumInputs=2,Formattable=true,Acceleratable=true,Name="subtraction_4");
        tanhLayer(Name="tanh_4")
        ];
end

unet = [unet
convolution1dLayer(filterSize, numFilters*8, Padding="same", Name="conv1d_"+branchStr+"_ds_4_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_4_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_4_1")

convolution1dLayer(filterSize, numFilters*8, Padding="same", Name="conv1d_"+branchStr+"_ds_4_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_4_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_4_2")

averagePooling1dLayer(2, Padding="same", Stride=2, Name="avgpool1d_"+branchStr+"_4_to_5")
];

% Row 5 encoder branch
if branchStr == "left"
    unet = [unet
        functionLayer(@minus,NumInputs=2,Formattable=true,Acceleratable=true,Name="subtraction_5");
        tanhLayer(Name="tanh_5")
        ];
end

unet = [unet
convolution1dLayer(filterSize, numFilters*16, Padding="same", Name="conv1d_"+branchStr+"_ds_5_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_5_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_5_1")

convolution1dLayer(filterSize, numFilters*16, Padding="same", Name="conv1d_"+branchStr+"_ds_5_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_ds_5_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_ds_5_2")

averagePooling1dLayer(2, Padding="same", Stride=2, Name="avgpool1d_"+branchStr+"_5_to_6")

% Row 6 - bridge
convolution1dLayer(filterSize, numFilters*16*numFiltScale, Padding="same", Name="conv1d_"+branchStr+"_bridge_6_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_bridge_6_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_bridge_6_1")

convolution1dLayer(filterSize, numFilters*16*numFiltScale, Padding="same", Name="conv1d_"+branchStr+"_bridge_6_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_bridge_6_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_bridge_6_2")

transposedConv1dLayer(filterSize, numFilters*16*numFiltScale, Stride=2, Cropping="same", Name="transconv1d_"+branchStr+"_us_6_to_5")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_6_to_5")

% Row 5 decoder branch
concatenationLayer(1, 2, Name="concat_"+branchStr+"_5")

convolution1dLayer(filterSize, numFilters*16, Padding="same", Name="conv1d_"+branchStr+"_us_5_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_5_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_5_1")

convolution1dLayer(filterSize, numFilters*16, Padding="same", Name="conv1d_"+branchStr+"_us_5_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_5_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_5_2")

transposedConv1dLayer(filterSize, numFilters*16, Stride=2, Cropping="same", Name="transconv1d_"+branchStr+"_us_5_to_4")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_5_to_4")

% Row 4 decoder branch
concatenationLayer(1, 2, Name="concat_"+branchStr+"_4")

convolution1dLayer(filterSize, numFilters*8, Padding="same", Name="conv1d_"+branchStr+"_us_4_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_4_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_4_1")

convolution1dLayer(filterSize, numFilters*8, Padding="same", Name="conv1d_"+branchStr+"_us_4_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_4_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_4_2")

transposedConv1dLayer(filterSize, numFilters*8, Stride=2, Cropping="same", Name="transconv1d_"+branchStr+"_us_4_to_3")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_4_to_3")

% Row 3 decoder branch
concatenationLayer(1, 2, Name="concat_"+branchStr+"_3")

convolution1dLayer(filterSize, numFilters*4, Padding="same", Name="conv1d_"+branchStr+"_us_3_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_3_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_3_1")

convolution1dLayer(filterSize, numFilters*4, Padding="same", Name="conv1d_"+branchStr+"_us_3_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_3_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_3_2")

transposedConv1dLayer(filterSize, numFilters*4, Stride=2, Cropping="same", Name="transconv1d_"+branchStr+"_us_3_to_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_3_to_2")

% Row 2 decoder branch
concatenationLayer(1, 2, Name="concat_"+branchStr+"_2")

convolution1dLayer(filterSize, numFilters*2, Padding="same", Name="conv1d_"+branchStr+"_us_2_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_2_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_2_1")

convolution1dLayer(filterSize, numFilters*2, Padding="same", Name="conv1d_"+branchStr+"_us_2_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_2_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_2_2")

transposedConv1dLayer(filterSize, numFilters*2, Stride=2, Cropping="same", Name="transconv1d_"+branchStr+"_us_2_to_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_2_to_1")

% Row 1 decoder branch
concatenationLayer(1, 2, Name="concat_"+branchStr+"_1")

convolution1dLayer(filterSize, numFilters, Padding="same", Name="conv1d_"+branchStr+"_us_1_1")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_1_1")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_1_1")

convolution1dLayer(filterSize, numFilters, Padding="same", Name="conv1d_"+branchStr+"_us_1_2")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_1_2")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_1_2")

convolution1dLayer(filterSize, numFilters, Padding="same", Name="conv1d_"+branchStr+"_us_1_3")
batchNormalizationLayer(Name="batchnorm_"+branchStr+"_us_1_3")
leakyReluLayer(0.01,Name="leakyrelu_"+branchStr+"_us_1_3")

convolution1dLayer(filterSize, 1, Padding="same",Name=branchStrOutput)
];

lgraph = addLayers(lgraph,unet);
lgraph = connectLayers(lgraph,"leakyrelu_"+branchStr+"_ds_5_2","concat_"+branchStr+"_5/in2");
lgraph = connectLayers(lgraph,"leakyrelu_"+branchStr+"_ds_4_2","concat_"+branchStr+"_4/in2");
lgraph = connectLayers(lgraph,"leakyrelu_"+branchStr+"_ds_3_2","concat_"+branchStr+"_3/in2");
lgraph = connectLayers(lgraph,"leakyrelu_"+branchStr+"_ds_2_2","concat_"+branchStr+"_2/in2");
lgraph = connectLayers(lgraph,"leakyrelu_"+branchStr+"_ds_1_3","concat_"+branchStr+"_1/in2");
end

See Also

Objects