Why do I get "Encountered complex value" error when training a custom network with the "fft" function?

I'm training a neural network with "fullyConnectedLayer" and a custom loss function that calls the "fft" function. "fft" should be supported for "dlarray" operations, but when I call the "dlgradient" function to calculate the gradients, I get this error:
Error using dlarray/dlgradient
Encountered complex value when computing gradient with respect to an output of fullyconnect. Convert all outputs of fullyconnect to real.

 Accepted Answer

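When the "fft" function is used in training, it can introduce complex values into the gradients during backpropagation, even though the loss itself is real. "fullyConnectedLayer" does not support complex-valued gradients. To keep the gradients that reach it real-valued, call the "real" function on the network output in the custom loss function before passing that output to "fft". The output is already real, so this does not change the forward computation; its effect is in the backward pass, where it drops the imaginary part of the gradient before the gradient reaches "fullyconnect". For example: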
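function [loss, gradients] = testLoss(net, input, target)
% Forward pass through the dlnetwork
y = forward(net, input);
% Remove the dimension labels (format) from the dlarray
y = stripdims(y);
% Cast "y" as real-valued before calling "fft" on it, so the gradient
% flowing back into fullyconnect stays real-valued
y = real(y);
yfft = fft(y, 4096);   % 4096-point FFT along the first non-singleton dimension
yabs = abs(yfft);      % magnitude spectrum (real-valued)
loss = sum(yabs, 'all');
% dlgradient requires this function to be evaluated with dlfeval
gradients = dlgradient(loss, net.Learnables);
end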
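"dlgradient" only traces operations inside a function evaluated with "dlfeval", so a loss function like "testLoss" is meant to be called through "dlfeval" rather than directly. The following is a minimal sketch that assumes "testLoss" is saved on the MATLAB path (for example as testLoss.m). The layer sizes, batch size, and random data are hypothetical and only for illustration.

```matlab
% Hypothetical small network: 8 input features -> 16 hidden -> 8 outputs
layers = [
    featureInputLayer(8)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(8)];
net = dlnetwork(layers);

% Hypothetical random batch: 8 features x 32 observations ("CB" format)
input  = dlarray(rand(8, 32, 'single'), 'CB');
target = dlarray(rand(8, 32, 'single'), 'CB');  % unused by testLoss, passed to match its signature

% Evaluate the loss and gradients with automatic differentiation enabled
[loss, gradients] = dlfeval(@testLoss, net, input, target);
```

The returned "gradients" is a table with the same layout as net.Learnables, so it can be passed to an update function such as "adamupdate" in a custom training loop.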
