Complex number gradient using 'dlgradient' in conjunction with neural networks
Dr. Veerababu Dharanalakota on 7 Apr 2023
Commented: Walter Roberson on 7 Apr 2023
Hello All,
I am trying to find the gradient of a function f, defined as a complex-valued constant C multiplied by the output of a feedforward neural network with input vector x (real-valued) and parameters θ (real-valued). The output of the neural network is a real-valued array. However, because of the complex constant C, the function f is complex-valued. I would like to find its gradient with respect to the input vector x.
![](https://www.mathworks.com/matlabcentral/answers/uploaded_files/1348249/image.png)
![](https://www.mathworks.com/matlabcentral/answers/uploaded_files/1348254/image.png)
![](https://www.mathworks.com/matlabcentral/answers/uploaded_files/1348259/image.png)
I tried to follow the method described in https://in.mathworks.com/help/deeplearning/ref/dlarray.dlgradient.html. My modified version is given below:
clc;
clear;
x = linspace(1,10,5); % Real-valued input array
x = dlarray(x,"CB"); % Convert to a dlarray with format "CB" (channel, batch)
[y, grad] = dlfeval(@gradFun,x);
grad = extractdata(grad)
% Complex-valued function of the real input x
function y = complexFun(x)
y = (2+3j)*x.^2;
end
% Evaluate the function and compute the gradient of the sum of its real part
function [y,grad] = gradFun(x)
y = complexFun(x);
y = real(y);
grad = dlgradient(sum(y,"all"),x,'EnableHigherDerivatives',true);
end
This approach successfully calculates the gradient of the complex-valued function (giving, as expected, the conjugate output). I then tried the same approach with the real-valued part of complexFun replaced by the neural network. When I did this, I encountered the following error:
![](https://www.mathworks.com/matlabcentral/answers/uploaded_files/1348264/image.png)
![](https://www.mathworks.com/matlabcentral/answers/uploaded_files/1348269/image.png)
![](https://www.mathworks.com/matlabcentral/answers/uploaded_files/1348274/image.png)
"Encountered complex value when computing gradient with respect to an output of fullyconnect. Convert all outputs of fullyconnect to real".
I would be grateful if anyone could show a way to fix the error and calculate the gradients.
Thank you,
Dr. Veerababu Dharanalakota
Accepted Answer
Walter Roberson on 7 Apr 2023
The derivative of C*f(x) can be calculated using the product rule: dC/dx*f(x) + C*df/dx. But when C is a constant, whether real or complex valued, dC/dx is 0. Therefore the derivative of C*f(x) is C*df/dx. The same logic applies to second derivatives.
Therefore the gradient of C*f(x) is C times the gradient of f(x). If f(x) is real valued, as stated, and C is complex valued, then the gradient of C*f(x) is complex valued unless the gradient of f(x) is 0, and dlgradient refuses to work with complex values in that case.
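The argument above, written out for a constant C = a + ib (a and b real; this notation is introduced here only for illustration) and a real-valued f:

$$
\frac{\partial}{\partial x}\bigl[C\,f(x)\bigr]
= \frac{\partial C}{\partial x}\,f(x) + C\,\frac{\partial f}{\partial x}
= C\,\frac{\partial f}{\partial x}
= a\,\frac{\partial f}{\partial x} + i\,b\,\frac{\partial f}{\partial x}
$$

Both a ∂f/∂x and b ∂f/∂x are real, so the only complex step is the final multiplication by C, which can be done outside dlgradient.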
So take the dlgradient of f(x) and multiply the result by C. That should at least postpone the problem.
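A minimal MATLAB sketch of this suggestion follows. It assumes the network is written as a model function built from fullyconnect (as the error message suggests); the two-layer architecture, the struct params, and the helper names model and networkGradient are hypothetical stand-ins for the actual network with parameters θ. The complex constant is applied only after dlgradient has returned, so no complex value enters the traced computation.

```matlab
% Sketch: gradient of f(x) = C * net(x) with respect to x.
% The two-layer network below is a hypothetical stand-in; replace with your own.
rng(0);                                   % reproducible random weights
C = 2 + 3j;                               % complex-valued constant
numIn = 1; numHidden = 8; numOut = 1;

% Hypothetical real-valued parameters (theta)
params.fc1.Weights = dlarray(randn(numHidden, numIn));
params.fc1.Bias    = dlarray(zeros(numHidden, 1));
params.fc2.Weights = dlarray(randn(numOut, numHidden));
params.fc2.Bias    = dlarray(zeros(numOut, 1));

% Real-valued input: 1 channel, 5 observations
x = dlarray(linspace(1, 10, 5), "CB");

% Differentiate the real-valued network output only
gradN = dlfeval(@networkGradient, x, params);
gradN = extractdata(gradN)                % real-valued dN/dx at each input point

% Apply the constant afterwards: d(C*N)/dx = C * dN/dx
gradF = C * gradN                         % complex-valued df/dx

function Y = model(X, params)
    % Feedforward network: fullyconnect -> tanh -> fullyconnect (all real)
    Y = fullyconnect(X, params.fc1.Weights, params.fc1.Bias);
    Y = tanh(Y);
    Y = fullyconnect(Y, params.fc2.Weights, params.fc2.Bias);
end

function gradN = networkGradient(X, params)
    % dlgradient needs a real scalar; C never enters the traced computation
    Y = model(X, params);
    gradN = dlgradient(sum(Y, "all"), X);
end
```

Because each column of x is a separate observation (the "B" dimension), differentiating sum(Y,"all") gives the derivative of the network output at each input point. With more than one output channel, the same call would return the sum of the per-channel gradients, so each channel would need its own dlgradient call.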