Time-Frequency Feature Embedding with Deep Metric Learning

This example shows how to use deep metric learning with a supervised contrastive loss to construct feature embeddings based on a time-frequency analysis of electroencephalographic (EEG) signals. The learned time-frequency embeddings reduce the dimensionality of the time-series data by a factor of 16. You can use these embeddings to classify EEG time-series from persons with and without epilepsy using a support vector machine classifier.

Deep Metric Learning

Deep metric learning attempts to learn a nonlinear feature embedding, or encoder, that reduces the distance (a metric) between examples from the same class and increases the distance between examples from different classes. Loss functions that work in this way are often referred to as contrastive. This example uses supervised deep metric learning with a particular contrastive loss function called the normalized temperature-scaled cross-entropy loss [3],[4],[8]. The figure shows the general workflow for this supervised deep metric learning.


Positive pairs refer to training samples with the same label, while negative pairs refer to training samples with different labels. A distance, or similarity, matrix is formed from the positive and negative pairs. In this example, the cosine similarity matrix is used. From these distances, losses are computed and aggregated (reduced) to form a single scalar-valued loss for use in gradient-descent learning.
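As a rough illustration of the pairing step described above, the following sketch builds a cosine similarity matrix and the positive-pair and negative-pair masks for a toy mini-batch. The embeddings, labels, and variable names are made up for illustration and are not part of the example.

```matlab
% Toy mini-batch: 4 embeddings (columns) of dimension 3 with class labels.
% All values are hypothetical.
E = [1 2 0 -1;
     0 1 1  0;
     1 0 2  1];
labels = [0 0 1 1];

% Normalize each column to unit length, then form the M-by-M
% cosine similarity matrix.
En = E ./ sqrt(sum(E.^2,1));
S = En' * En;

% Positive pairs share a label (self-pairs excluded); negative pairs differ.
isPositive = (labels' == labels) & ~eye(numel(labels));
isNegative = (labels' ~= labels);

disp(S)
disp(isPositive)
disp(isNegative)
```

Each diagonal entry of S equals 1 because every embedding is perfectly similar to itself, which is why self-pairs are excluded from the positive pairs.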

Deep metric learning is also applicable in weakly supervised, self-supervised, and unsupervised contexts. A wide variety of distance measures (metrics), losses, reducers, and regularizers are employed in deep metric learning.

Data — Description, Attribution, and Download Instructions

The data used in this example is the Bonn EEG Data Set. The data is currently available at EEG Data Download and Ralph Andrzejak's EEG data download page. See Ralph Andrzejak's EEG data for legal conditions on the use of the data. The authors have kindly permitted the use of the data in this example.

The data in this example were first analyzed and reported in:

Andrzejak, Ralph G., Klaus Lehnertz, Florian Mormann, Christoph Rieke, Peter David, and Christian E. Elger. "Indications of Nonlinear Deterministic and Finite-Dimensional Structures in Time Series of Brain Electrical Activity: Dependence on Recording Region and Brain State." Physical Review E 64, no. 6 (2001).

The data consists of five sets of 100 single-channel EEG recordings. The resulting single-channel EEG recordings were selected from 128-channel EEG recordings after visually inspecting each channel for obvious artifacts and satisfying a weak stationarity criterion. See the linked paper for details.

The original paper designates these five sets as A-E. Each recording is 23.6 seconds in duration sampled at 173.61 Hz. Each time series contains 4097 samples. The conditions are as follows:

A -- Normal subjects with eyes open

B -- Normal subjects with eyes closed

C -- Seizure-free recordings from patients with epilepsy. Recordings from the hippocampus in the hemisphere opposite the epileptogenic zone.

D -- Seizure-free recordings from patients with epilepsy. Recordings from the epileptogenic zone.

E -- Recordings from patients with epilepsy showing seizure activity.

The zip files corresponding to this data are labeled as (A), (B), (C), (D), and (E).

The example assumes you have downloaded and unzipped the zip files into folders named Z, O, N, F, and S respectively. In MATLAB® you can do this by creating a parent folder and using that as the OUTPUTDIR variable in the unzip command. This example uses the folder designated by MATLAB as tempdir as the parent folder. If you choose to use a different folder, adjust the value of parentDir accordingly. The following code assumes that all the .zip files have been downloaded into parentDir. Unzip the files by folder into a subfolder called BonnEEG.

parentDir = tempdir;
dataDir = fullfile(parentDir,'BonnEEG');

Creating In-Memory Data and Labels

The individual EEG time series are stored as .txt files in each of the Z, O, N, F, and S folders under dataDir. Create a tabularTextDatastore to read the data.

tds = tabularTextDatastore(dataDir,'IncludeSubfolders',true,'FileExtensions','.txt');

The zip files were created on macOS, so unzipping them may create a __MACOSX folder that contains extra files. If those files exist, remove them from the datastore.

extraTXT = contains(tds.Files,'__MACOSX');
tds.Files(extraTXT) = [];

Create labels for the data based on the first letter of the text file name.

labels = filenames2labels(tds.Files,'ExtractBetween',[1 1]);

Each read of the tabular text datastore creates a table containing the data. Create a cell array of all signals reshaped as row vectors so they conform with the deep learning networks used in the example.

ii = 1;
eegData = cell(numel(labels),1);
while hasdata(tds)
    tsTable = read(tds);
    ts = tsTable.Var1;
    eegData{ii} = reshape(ts,1,[]);
    ii = ii+1;
end

Time-Frequency Feature Embedding Deep Network

Here we construct a deep learning network that creates an embedding of the input signal based on a time-frequency analysis.

TFnet = [sequenceInputLayer(1,'MinLength',4097,'Name',"input")
    'FrequencyLimits',[0 0.23])
TFnet = dlnetwork(TFnet);

After the input layer, the network obtains the continuous wavelet transform (CWT) of the data using the analytic Morlet wavelet. The output of cwtLayer (Wavelet Toolbox) is the magnitude of the CWT, or scalogram. Unlike the analyses in [1], [2], and [7], this network uses no preprocessing bandpass filter. Instead, the CWT is obtained only over the frequency range of [0, 0.23] cycles/sample, which is equivalent to [0, 39.93] Hz for the sample rate of 173.61 Hz. This is the approximate range of the bandpass filter applied to the data before analysis in [1]. After the network obtains the scalogram, it cascades a series of 2-D convolutional, batch normalization, and ReLU layers. The final layer is a fully connected layer with 256 output units. This results in an approximately 16-fold reduction in the size of the input, from 4097 samples to 256. See [7] for another scalogram-based analysis of this data and [2] for another wavelet-based analysis using the tunable Q-factor wavelet transform.
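The arithmetic behind the frequency range and the reduction factor quoted above can be checked with a short sketch. The variable names are chosen here for illustration; the numeric values come from the text.

```matlab
Fs = 173.61;              % sample rate in Hz, from the data description
fNorm = [0 0.23];         % CWT frequency limits in cycles/sample
fHz = fNorm*Fs;           % equivalent limits in Hz

numSamples = 4097;        % samples per recording
embeddingSize = 256;      % output units of the final fully connected layer
reduction = numSamples/embeddingSize;
recordLength = numSamples/Fs;   % duration of each recording in seconds

fprintf('Frequency limits: [%.2f, %.2f] Hz\n',fHz);
fprintf('Reduction factor: %.2f\n',reduction);
fprintf('Recording length: %.1f s\n',recordLength);
```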

Differentiating Normal, Pre-seizure, and Seizure EEG

Given the five conditions present in the data, there are multiple meaningful and clinically informative ways to partition the data. One relevant way is to group the Z and O labels (non-epileptic subjects with eyes open and closed) as "Normal". Similarly, the two conditions recorded in the persons with epilepsy without overt seizure activity (N and F) may be grouped as "Pre-seizure". Finally, we designate the recordings obtained in epileptic subjects with seizure activity as "Seizure". To create labels, which may be cast to numeric values during training, designate these three classes as:

  • 0 -- "Normal"

  • 1 -- "Pre-seizure"

  • 2 -- "Seizure"

Partition the data into training and test sets. First, create the new labels in order to partition the data. Examine the number of examples in each class.

labelsPS = labels;
labelsPS = removecats(labelsPS,{'F','N','O','S','Z'});
labelsPS(labels == categorical("Z") | labels == categorical("O")) = categorical("0");
labelsPS(labels == categorical("N") | labels == categorical("F")) = categorical("1");
labelsPS(labels == categorical("S")) = categorical("2");
labelsPS(isundefined(labelsPS)) = [];
summary(labelsPS)
     0      200 
     1      200 
     2      100 

The resulting classes are unbalanced with twice as many signals in the "Normal" and "Pre-seizure" categories as in the "Seizure" category. Partition the data for training the encoder and the hold-out test set. Allocate 80% of the data to the training set and 20% to the test set.

idxPS = splitlabels(labelsPS,[0.8 0.2]);
TrainDataPS = eegData(idxPS{1});
TrainLabelsPS = labelsPS(idxPS{1});
testDataPS = eegData(idxPS{2});
testLabelsPS = labelsPS(idxPS{2});

Training the Encoder

To train the encoder, set trainEmbedder to true. To skip the training and load a pretrained encoder and corresponding embeddings, set trainEmbedder to false and go to the Test Data Embeddings section.

trainEmbedder = true;

Because this example uses a custom loss function, you must use a custom training loop. To manage data through the custom training loop, use a signalDatastore (Signal Processing Toolbox) with a custom read function that normalizes the input signals to have zero mean and unit standard deviation.

if trainEmbedder
    sdsTrain = signalDatastore(TrainDataPS,MemberNames = string(TrainLabelsPS));
    transTrainDS = transform(sdsTrain,@(x,info)helperReadData(x,info),'IncludeInfo',true);
end
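The normalization applied by the custom read function can be sketched on a small made-up signal. This sketch mirrors the arithmetic in helperReadData, which normalizes by the standard deviation computed with N rather than N-1; the signal values are hypothetical.

```matlab
x = [2 4 4 4 5 5 7 9];     % hypothetical row-vector signal
mu = mean(x,2);            % mean along the time dimension
sigma = std(x,1,2);        % standard deviation normalized by N
z = (x - mu)./sigma;       % zero mean, unit standard deviation
disp(z)
```

For this signal the mean is 5 and the standard deviation is 2, so the normalized values are the deviations from 5 divided by 2.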

Train the network by measuring the normalized temperature-scaled cross-entropy loss between embeddings obtained from identical classes (corresponding to positive pairs) and disparate classes (corresponding to negative pairs) in each mini-batch. The custom loss function computes the cosine similarity between every pair of training examples in the mini-batch, obtaining an M-by-M similarity matrix, where M is the mini-batch size. The function computes the normalized temperature-scaled cross-entropy for the similarity matrix with the temperature parameter equal to 0.07. The function calculates the scalar loss as the mean of the losses over all positive pairs in the mini-batch.
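One way to write the loss described above, following the pairBasedLoss helper function listed at the end of this example, is shown here. The notation is introduced only for this sketch: s_{ij} is the cosine similarity between embeddings i and j, y_i is the class label of example i, \tau is the temperature (0.07), and P is the set of positive pairs (i,p) with y_p = y_i and p \neq i.

```latex
\ell_{i,p} = -\log
  \frac{\exp(s_{ip}/\tau)}
       {\exp(s_{ip}/\tau) + \sum_{n \,:\, y_n \neq y_i} \exp(s_{in}/\tau)},
\qquad
L = \frac{1}{|P|} \sum_{(i,p) \in P} \ell_{i,p}
```

The helper function evaluates the same ratio after subtracting the largest exponent for numerical stability and adds a tiny constant inside the logarithm, so its result can differ from this expression by a negligible amount.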

Specify Training Options

The model parameters are updated based on the loss using an Adam optimizer.

Train the encoder for 150 epochs with a mini-batch size of 50, a learning rate of 0.001, and an L2-regularization rate of 0.01.

if trainEmbedder
    NumEpochs = 150;
    minibatchSize = 50;
    learnRate = 0.001;
    l2Regularization = 1e-2;
end
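To see how the learning rate and the L2-regularization rate enter a parameter update, consider this sketch of a single textbook Adam step on a toy weight vector. The weights, gradient, decay factors, and offset are assumptions for illustration, and the adamupdate function used in the training loop may differ in its details.

```matlab
% One textbook Adam step on a toy parameter vector (values are made up).
w = [1 -2];                 % current weights
g = [0.5 -0.25];            % gradient of the loss with respect to w
learnRate = 0.001;
l2Regularization = 1e-2;
beta1 = 0.9;                % assumed gradient decay factor
beta2 = 0.999;              % assumed squared-gradient decay factor
epsilon = 1e-8;             % assumed offset to avoid division by zero
iteration = 1;
m = zeros(size(w));         % running average of gradients
v = zeros(size(w));         % running average of squared gradients

g = g + l2Regularization*w;             % add the L2 penalty gradient
m = beta1*m + (1-beta1)*g;
v = beta2*v + (1-beta2)*g.^2;
mHat = m/(1-beta1^iteration);           % bias correction
vHat = v/(1-beta2^iteration);
wNew = w - learnRate*mHat./(sqrt(vHat)+epsilon);
disp(wNew)
```

On the first iteration the bias-corrected step is close to learnRate times the sign of the regularized gradient, so each weight moves by about 0.001.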

Calculate the number of iterations per epoch and the total number of iterations to display training progress.

if trainEmbedder
    numObservations = numel(TrainDataPS);
    numIterationsPerEpoch = floor(numObservations./minibatchSize);
    numIterations = NumEpochs*numIterationsPerEpoch;
end
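As a quick check of these quantities, the following sketch repeats the calculation with the sizes used in this example: 500 recordings in total, an 80% training split, a mini-batch size of 50, and 150 epochs. The variable numSignals is introduced here only for illustration.

```matlab
numSignals = 500;                            % five sets of 100 recordings
numObservations = round(0.8*numSignals);     % 80% of the data is used for training
minibatchSize = 50;
NumEpochs = 150;
numIterationsPerEpoch = floor(numObservations/minibatchSize);
numIterations = NumEpochs*numIterationsPerEpoch;
fprintf('%d iterations per epoch, %d iterations in total\n', ...
    numIterationsPerEpoch,numIterations);
```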

Create a minibatchqueue object to manage data flow through the custom training loop.

if trainEmbedder
    numOutputs = 2;
    mbqTrain = minibatchqueue(transTrainDS,numOutputs,...
        'MiniBatchSize',minibatchSize,...
        'MiniBatchFcn',@processMB,...
        'MiniBatchFormat',{'CBT','B'});
end

Train the encoder.

if trainEmbedder
    progress = "final-loss";
    if progress == "training-progress"
        lineLossTrain = animatedline;
        ylim([0 inf])
        grid on
    end
    % Initialize some training loop variables
    trailingAvg = [];
    trailingAvgSq = [];
    iteration = 1;
    lossByIteration = zeros(numIterations,1);

    % Loop over epochs and time the epochs
    start = tic;

    for epoch = 1:NumEpochs
        % Shuffle the mini-batches each epoch
        shuffle(mbqTrain)

        % Loop over mini-batches
        while hasdata(mbqTrain)
            % Get the next mini-batch and the corresponding class labels
            [dlX,Y] = next(mbqTrain);
            % Evaluate the model gradients and contrastive loss
            [gradients, loss, state] = dlfeval(@modelGradcontrastiveLoss,TFnet,dlX,Y);
            if progress == "final-loss"
                lossByIteration(iteration) = loss;
            end
            % Update the gradients with the L2-regularization rate
            idx = TFnet.Learnables.Parameter == "Weights";
            gradients(idx,:) = ...
                dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), TFnet.Learnables(idx,:));
            % Update the network state
            TFnet.State = state;
            % Update the network parameters using an Adam optimizer
            [TFnet,trailingAvg,trailingAvgSq] = adamupdate(...
                TFnet,gradients,trailingAvg,trailingAvgSq,iteration,learnRate);

            % Display the training progress
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            if progress == "training-progress"
                addpoints(lineLossTrain,iteration,loss)
                title("Epoch: " + epoch + ", Elapsed: " + string(D))
                drawnow
            end
            iteration = iteration + 1;
        end

        disp("Training loss after epoch " + epoch + ": " + loss);
    end
    if progress == "final-loss"
        plot(lossByIteration)
        grid on
        title('Training Loss by Iteration')
    end
end
Training loss after epoch 1: 1.4759
Training loss after epoch 2: 1.5684
Training loss after epoch 3: 1.0331
Training loss after epoch 4: 1.1621
Training loss after epoch 5: 0.70297
Training loss after epoch 6: 0.29956
Training loss after epoch 7: 0.42671
Training loss after epoch 8: 0.23963
Training loss after epoch 9: 0.021723
Training loss after epoch 10: 0.50336
Training loss after epoch 11: 0.34225
Training loss after epoch 12: 0.63325
Training loss after epoch 13: 0.31603
Training loss after epoch 14: 0.25883
Training loss after epoch 15: 0.52879
Training loss after epoch 16: 0.27623
Training loss after epoch 17: 0.070335
Training loss after epoch 18: 0.073039
Training loss after epoch 19: 0.2657
Training loss after epoch 20: 0.10312
Training loss after epoch 21: 0.33435
Training loss after epoch 22: 0.24089
Training loss after epoch 23: 0.083583
Training loss after epoch 24: 0.33138
Training loss after epoch 25: 0.006466
Training loss after epoch 26: 0.44036
Training loss after epoch 27: 0.028106
Training loss after epoch 28: 0.14215
Training loss after epoch 29: 0.018414
Training loss after epoch 30: 0.018228
Training loss after epoch 31: 0.026751
Training loss after epoch 32: 0.026275
Training loss after epoch 33: 0.13545
Training loss after epoch 34: 0.029467
Training loss after epoch 35: 0.0088911
Training loss after epoch 36: 0.12077
Training loss after epoch 37: 0.1113
Training loss after epoch 38: 0.14529
Training loss after epoch 39: 0.10718
Training loss after epoch 40: 0.10141
Training loss after epoch 41: 0.018227
Training loss after epoch 42: 0.0086456
Training loss after epoch 43: 0.025808
Training loss after epoch 44: 0.00021023
Training loss after epoch 45: 0.0013423
Training loss after epoch 46: 0.0020328
Training loss after epoch 47: 0.012152
Training loss after epoch 48: 0.00025792
Training loss after epoch 49: 0.0010626
Training loss after epoch 50: 0.0015668
Training loss after epoch 51: 0.00048469
Training loss after epoch 52: 0.00073284
Training loss after epoch 53: 0.00043141
Training loss after epoch 54: 0.0009649
Training loss after epoch 55: 0.00014656
Training loss after epoch 56: 0.00024468
Training loss after epoch 57: 0.00092313
Training loss after epoch 58: 0.00022878
Training loss after epoch 59: 6.3505e-05
Training loss after epoch 60: 5.0711e-05
Training loss after epoch 61: 0.0006025
Training loss after epoch 62: 0.00010356
Training loss after epoch 63: 0.00018479
Training loss after epoch 64: 0.00042666
Training loss after epoch 65: 6.88e-05
Training loss after epoch 66: 0.00019625
Training loss after epoch 67: 0.00064875
Training loss after epoch 68: 0.00017705
Training loss after epoch 69: 0.00086301
Training loss after epoch 70: 0.00044735
Training loss after epoch 71: 0.00099668
Training loss after epoch 72: 3.7804e-05
Training loss after epoch 73: 9.1751e-05
Training loss after epoch 74: 2.6748e-05
Training loss after epoch 75: 0.0012345
Training loss after epoch 76: 0.00019493
Training loss after epoch 77: 0.00058993
Training loss after epoch 78: 0.0024207
Training loss after epoch 79: 7.1345e-05
Training loss after epoch 80: 0.00015598
Training loss after epoch 81: 9.3623e-05
Training loss after epoch 82: 8.9839e-05
Training loss after epoch 83: 0.0024844
Training loss after epoch 84: 0.0001383
Training loss after epoch 85: 0.00027976
Training loss after epoch 86: 0.17246
Training loss after epoch 87: 0.61378
Training loss after epoch 88: 0.41423
Training loss after epoch 89: 0.35526
Training loss after epoch 90: 0.081963
Training loss after epoch 91: 0.09392
Training loss after epoch 92: 0.026856
Training loss after epoch 93: 0.18554
Training loss after epoch 94: 0.04293
Training loss after epoch 95: 0.0002686
Training loss after epoch 96: 0.0071139
Training loss after epoch 97: 0.0028931
Training loss after epoch 98: 0.029305
Training loss after epoch 99: 0.0080128
Training loss after epoch 100: 0.0018248
Training loss after epoch 101: 0.00012145
Training loss after epoch 102: 7.6166e-05
Training loss after epoch 103: 0.0001156
Training loss after epoch 104: 8.262e-05
Training loss after epoch 105: 0.00023958
Training loss after epoch 106: 0.00016227
Training loss after epoch 107: 0.00025268
Training loss after epoch 108: 0.0022929
Training loss after epoch 109: 0.00029386
Training loss after epoch 110: 0.00029691
Training loss after epoch 111: 0.00033467
Training loss after epoch 112: 5.31e-05
Training loss after epoch 113: 0.00013522
Training loss after epoch 114: 1.4335e-05
Training loss after epoch 115: 0.0015768
Training loss after epoch 116: 2.4165e-05
Training loss after epoch 117: 0.00031281
Training loss after epoch 118: 3.4592e-05
Training loss after epoch 119: 7.1151e-05
Training loss after epoch 120: 0.00020099
Training loss after epoch 121: 1.7647e-05
Training loss after epoch 122: 0.00010945
Training loss after epoch 123: 0.0012003
Training loss after epoch 124: 4.5947e-05
Training loss after epoch 125: 0.00043231
Training loss after epoch 126: 7.3228e-05
Training loss after epoch 127: 2.3522e-05
Training loss after epoch 128: 0.00014366
Training loss after epoch 129: 0.00010692
Training loss after epoch 130: 0.00066842
Training loss after epoch 131: 9.2536e-06
Training loss after epoch 132: 0.0007364
Training loss after epoch 133: 3.0709e-05
Training loss after epoch 134: 5.4056e-05
Training loss after epoch 135: 3.3361e-05
Training loss after epoch 136: 8.1937e-05
Training loss after epoch 137: 0.00012198
Training loss after epoch 138: 3.9838e-05
Training loss after epoch 139: 0.00025224
Training loss after epoch 140: 4.9974e-05
Training loss after epoch 141: 8.302e-05
Training loss after epoch 142: 2.009e-05
Training loss after epoch 143: 7.2674e-05
Training loss after epoch 144: 4.8355e-05
Training loss after epoch 145: 0.0008231
Training loss after epoch 146: 0.00017177
Training loss after epoch 147: 3.4427e-05
Training loss after epoch 148: 0.0095201
Training loss after epoch 149: 0.026009
Training loss after epoch 150: 0.071619

Test Data Embeddings

Obtain the embeddings for the test data. If you set trainEmbedder to false, you can load the trained encoder and embeddings obtained using the helperEmbedTestFeatures function.

if trainEmbedder
    finalEmbeddingsTable = helperEmbedTestFeatures(TFnet,testDataPS,testLabelsPS);
else
    load('TFnet.mat'); %#ok<*UNRCH>
end

Use a support vector machine (SVM) classifier with a Gaussian kernel to classify the embeddings.

template = templateSVM(...
    'KernelFunction', 'gaussian', ...
    'PolynomialOrder', [], ...
    'KernelScale', 4, ...
    'BoxConstraint', 1, ...
    'Standardize', true);
classificationSVM = fitcecoc(...
    finalEmbeddingsTable, ...
    "EEGClass", ...
    'Learners', template, ...
    'Coding', 'onevsone');

Show the final test performance of the trained encoder. The recall and precision performance for all three classes is excellent. The learned feature embeddings provide nearly 100% recall and precision for the normal (0), pre-seizure (1), and seizure (2) classes. Each embedding represents a reduction in the input size from 4097 samples to 256 samples.

predLabelsFinal = predict(classificationSVM,finalEmbeddingsTable);
testAccuracyFinal = sum(predLabelsFinal == testLabelsPS)/numel(testLabelsPS)*100
testAccuracyFinal = 100
hf = figure;
confusionchart(hf,testLabelsPS,predLabelsFinal,'RowSummary','row-normalized',...
    'ColumnSummary','column-normalized');
set(gca,'Title','Confusion Chart -- Trained Embeddings')
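To make the recall and precision figures concrete, this sketch computes them from a confusion matrix. The matrix below is hypothetical, with one made-up misclassification, and does not reproduce the results of this example. Rows correspond to true classes and columns to predicted classes.

```matlab
% Hypothetical 3-class confusion matrix (rows: true class, columns: predicted).
C = [40  0  0;
      1 39  0;
      0  0 20];
recall = diag(C)./sum(C,2);        % fraction of each true class recovered
precision = diag(C)./sum(C,1)';    % fraction of each predicted class that is correct
accuracy = sum(diag(C))/sum(C(:))*100;
disp([recall precision])
disp(accuracy)
```

Recall corresponds to the row-normalized summary of a confusion chart, and precision corresponds to the column-normalized summary.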

For completeness, test the cross-validation accuracy of the feature embeddings. Use five-fold cross validation.

partitionedModel = crossval(classificationSVM, 'KFold', 5);
[validationPredictions, validationScores] = kfoldPredict(partitionedModel);
validationAccuracy =  ...
    (1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError'))*100
validationAccuracy = single

The cross-validation accuracy is also excellent at near 100%. Note that this example uses all 256 embedding features in the SVM model, but the embeddings returned by the encoder are amenable to further reduction using techniques such as neighborhood component analysis, minimum redundancy maximum relevance (MRMR), or principal component analysis. See Introduction to Feature Selection (Statistics and Machine Learning Toolbox) for more details.
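As one possible illustration of such a reduction, the following sketch applies principal component analysis through the singular value decomposition to a small synthetic feature matrix. The data, the 99% variance threshold, and the variable names are assumptions for illustration and are not taken from the embeddings in this example.

```matlab
% Synthetic 50-by-4 feature matrix with strongly correlated columns.
t = (1:50)';
X = [sin(t), 2*sin(t) + 0.01*cos(3*t), cos(t), 0.5*cos(t)];

Xc = X - mean(X,1);                    % center each feature
[~,Sv,V] = svd(Xc,'econ');             % principal directions are the columns of V
explained = diag(Sv).^2/sum(diag(Sv).^2);

% Keep the smallest number of components explaining at least 99% of the variance.
k = find(cumsum(explained) >= 0.99,1);
scores = Xc*V(:,1:k);                  % reduced features
fprintf('Kept %d of %d components\n',k,size(X,2));
```

With real embeddings, you could pass the reduced features to the same SVM workflow and compare the cross-validation accuracy against the full 256-feature model.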


Summary

In this example, a time-frequency convolutional network was used as the basis for learning feature embeddings using a deep metric model. Specifically, the normalized temperature-scaled cross-entropy loss with cosine similarities was used to obtain the embeddings. The embeddings were then used with an SVM with a Gaussian kernel to achieve near-perfect test performance. There are a number of ways to optimize this deep metric network that are not explored in this example. For example, the size of the embeddings can likely be reduced further without affecting performance. Additionally, there are a large number of similarity measures (metrics), loss functions, regularizers, and reducers that are not explored in this example. Finally, the resulting embeddings are compatible with any machine learning algorithm. An SVM was used in this example, but you can explore the feature embeddings in the Classification Learner app and may find that another classification algorithm is more robust for your application.


References

[1] Andrzejak, Ralph G., Klaus Lehnertz, Florian Mormann, Christoph Rieke, Peter David, and Christian E. Elger. "Indications of Nonlinear Deterministic and Finite-Dimensional Structures in Time Series of Brain Electrical Activity: Dependence on Recording Region and Brain State." Physical Review E 64, no. 6 (2001).

[2] Bhattacharyya, Abhijit, Ram Pachori, Abhay Upadhyay, and U. Acharya. "Tunable-Q Wavelet Transform Based Multiscale Entropy Measure for Automated Classification of Epileptic EEG Signals." Applied Sciences 7, no. 4 (2017): 385.

[3] Chen, Ting, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. "A Simple Framework for Contrastive Learning of Visual Representations." (2020).

[4] He, Kaiming, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. "Momentum Contrast for Unsupervised Visual Representation Learning." (2020).

[6] Musgrave, Kevin. "PyTorch Metric Learning"

[7] Türk, Ömer, and Mehmet Siraç Özerdem. "Epilepsy Detection by Using Scalogram Based Convolutional Neural Network from EEG Signals." Brain Sciences 9, no. 5 (2019): 115.

[8] Van den Oord, Aaron, Yazhe Li, and Oriol Vinyals. "Representation Learning with Contrastive Predictive Coding." (2019).

function [grads,loss,state] = modelGradcontrastiveLoss(net,X,T)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
[y,state] = net.forward(X);
loss = contrastiveLoss(y,T);
grads = dlgradient(loss,net.Learnables);
loss = double(gather(extractdata(loss)));
end

function [out,info] = helperReadData(x,info)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
mu = mean(x,2);
stdev = std(x,1,2);
z = (x-mu)./stdev;
out = {z,info.MemberName};
end

function [dlX,dlY] = processMB(Xcell,Ycell)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
Xcell = cellfun(@(x)reshape(x,1,1,[]),Xcell,'uni',false);
Ycell = cellfun(@(x)str2double(x),Ycell,'uni',false);
dlX = cat(2,Xcell{:});
dlY = cat(1,Ycell{:});
end

function testFeatureTable = helperEmbedTestFeatures(net,testdata,testlabels)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
testFeatures = zeros(length(testlabels),256,'single');
for ii = 1:length(testdata)
    yhat = predict(net,dlarray(reshape(testdata{ii},1,1,[]),'CBT'));
    yhat = extractdata(gather(yhat));
    testFeatures(ii,:) = yhat;
end
testFeatureTable = array2table(testFeatures);
testFeatureTable = addvars(testFeatureTable,testlabels,...
    'NewVariableNames',"EEGClass");
end

function loss = contrastiveLoss(features,targets)
% This function is only for use in the "Time-Frequency Feature
% Embedding with Deep Metric Learning" example. It may change or be removed
% in a future release.
% Replicates code in PyTorch Metric Learning 
% Python algorithms due to Kevin Musgrave

% Copyright 2022, The Mathworks, Inc. 
    loss = infoNCE(features,targets);
end

function loss = infoNCE(embed,labels)
    ref_embed = embed;
    [posR,posC,negR,negC] = convertToPairs(labels);
    dist = cosineSimilarity(embed,ref_embed);
    loss = pairBasedLoss(dist,posR,posC,negR,negC);
end

function [posR,posC,negR,negC] = convertToPairs(labels)
    Nr = length(labels);
    % The following provides a logical matrix which indicates whether
    % the corresponding element (i,j) of the similarity matrix of
    % features comes from the same class or not. At each (i,j)
    % coming from the same class we have a 1, at each (i,j) from a
    % different class we have 0. Of course the diagonal is 1s.
    labels = stripdims(labels);
    matches = (labels == labels');
    % Logically negate the matches matrix to obtain differences.
    differences = ~matches;
    % Set the diagonal of the matches matrix to false to avoid biasing
    % the learning. Later when we identify the positive and
    % negative indices, these diagonal elements will not be picked
    % up.
    matches(1:Nr+1:end) = false;
    [posR,posC,negR,negC] = getAllPairIndices(matches,differences);
end

function dist = cosineSimilarity(emb,ref_embed)
    emb = stripdims(emb);
    ref_embed = stripdims(ref_embed);
    normEMB = emb./sqrt(sum(emb.*emb,1));
    normREF = ref_embed./sqrt(sum(ref_embed.*ref_embed,1));
    dist = normEMB'*normREF;
end

function loss = pairBasedLoss(dist,posR,posC,negR,negC)
    % Return a zero loss if the mini-batch has no positive or negative pairs
    if any([isempty(posR),isempty(posC),isempty(negR),isempty(negC)])
        loss = dlarray(zeros(1,1,'like',dist));
        return
    end
    Temperature = 0.07;
    dtype = underlyingType(dist);
    idxPos = sub2ind(size(dist),posR,posC);
    pos_pair = dist(idxPos);
    pos_pair = reshape(pos_pair,[],1);
    idxNeg = sub2ind(size(dist),negR,negC);
    neg_pair = dist(idxNeg);
    neg_pair = reshape(neg_pair,[],1);
    pos_pair = pos_pair./Temperature;
    neg_pair = neg_pair./Temperature;
    % For each positive pair, select the negatives that share its anchor row
    n_per_p = negR' == posR;
    neg_pairs = neg_pair'.*n_per_p;
    neg_pairs(n_per_p==0) = -realmax(dtype);
    % Subtract the largest exponent for numerical stability
    maxNeg = max(neg_pairs,[],2);
    maxPos = max(pos_pair,[],2);
    maxVal = max(maxPos,maxNeg);
    numerator = exp(pos_pair-maxVal);
    denominator = sum(exp(neg_pairs-maxVal),2)+numerator;
    logexp = log((numerator./denominator)+realmin(dtype));
    loss = mean(-logexp,'all');
end

function [posR,posC,negR,negC] = getAllPairIndices(matches,differences)
    % Here we just get the row and column indices of the anchor
    % positive and anchor negative elements.
    [posR, posC] = find(extractdata(matches));
    [negR,negC] = find(extractdata(differences));
end
