Complex number gradient using 'dlgradient' in conjunction with neural networks

16 views (last 30 days)
Hello All,
I am trying to find the gradient of a function f(x; θ) = C·N(x; θ), where C is a complex-valued constant, N is a feedforward neural network, x is the input vector (real-valued), and θ are the parameters (real-valued). The output of the neural network is a real-valued array; however, because of the complex constant C, the function f becomes complex-valued. I would like to find its gradient with respect to the input vector x.
I tried to follow the method described at https://in.mathworks.com/help/deeplearning/ref/dlarray.dlgradient.html, reproduced below with modifications:
clc;
clear all;
x = linspace(1,10,5); % Real-valued array
x = dlarray(x,"CB"); % Converting to deeplearning array
[y, grad] = dlfeval(@gradFun,x);
grad = extractdata(grad)
grad =
4.0000 - 6.0000i  13.0000 - 19.5000i  22.0000 - 33.0000i  31.0000 - 46.5000i  40.0000 - 60.0000i
% Complex-function
function y = complexFun(x)
y = (2+3j)*x.^2;
end
% Function to calculate complex gradient
function [y,grad] = gradFun(x)
y = complexFun(x);
y = real(y);
grad = dlgradient(sum(y,"all"),x,'EnableHigherDerivatives',true);
end
The method successfully calculates the gradient of a complex-valued function (giving, of course, the conjugate). I then tried the same approach with the toy function replaced by the neural-network-based function f(x) = C·N(x; θ). When I did this, I encountered the following error:
"Encountered complex value when computing gradient with respect to an output of fullyconnect. Convert all outputs of fullyconnect to real".
I would be grateful if anyone could show a way to fix the error and calculate the gradients.
Thank you,
Dr. Veerababu Dharanalakota

Accepted Answer

Walter Roberson, 2023-04-07
The derivative of C*f(x) can be calculated using the product rule: dC/dx*f(x) + C*df/dx. When C is constant, dC/dx is 0 regardless of whether C is real or complex valued. Therefore the derivative of C*f(x) is C*df/dx. The same logic applies to second derivatives.
Therefore the gradient of C*f(x) is C times the gradient of f(x). And if f(x) is real valued as indicated, and C is complex valued, then unless the gradient is 0, the gradient of C*f(x) will be complex valued, which dlgradient refuses to work with.
So take the dlgradient of f(x) and multiply the result by C. That should at least postpone the problem.
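A minimal sketch of this workaround (a toy quadratic stands in for the real-valued network output; C and the data are example values, not from the original post):

```matlab
% Sketch: trace only the real-valued function through autodiff,
% then scale the resulting gradient by the complex constant C.
C = 2 + 3i;                          % complex constant (example value)
x = dlarray(linspace(1,10,5),"CB");  % real-valued input

[~, gradf] = dlfeval(@realGradFun, x);
gradCf = C * extractdata(gradf);     % gradient of C*f(x) = C * (gradient of f)

function [y, grad] = realGradFun(x)
    % Stand-in for the real-valued network output; a toy function
    % keeps the example self-contained.
    y = x.^2;
    grad = dlgradient(sum(y,"all"), x);
end
```

Because everything inside the traced function is real, dlgradient never sees a complex value; C enters only after extraction.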
2 comments
Dr. Veerababu Dharanalakota
Thank you, Walter.
I am able to calculate the gradient fx as well as the second derivative fxx, and constructed a function g = fxx + α·f, where α is a real-valued constant. Now I would like to minimize g with respect to the parameters θ. Since g is a complex-valued function, I split it into real and imaginary parts and cast the loss function and its gradient with respect to the parameters θ as follows:
g = fxx+alpha*f;
gr = real(g); % Real-part of g
gi = imag(g); % Imaginary-part of g
zeroTarget_r = zeros(size(gr),"like",gr); % Zero targets for the real-part
loss_r = l2loss(gr, zeroTarget_r); % Real-part loss function
zeroTarget_i = zeros(size(gi),"like",gi); % Zero targets for the imaginary-part
loss_i = l2loss(gi, zeroTarget_i); % Imaginary-part loss function
loss = loss_r+loss_i; % Total loss function (real-valued)
gradients = dlgradient(loss,parameters);
The loss function as well as the parameters are real-valued, so ideally I should be able to calculate the gradients. However, it is throwing a similar error (instead of outputs, now it is inputs):
"Encountered complex value when computing gradient with respect to an input to fullyconnect. Convert all inputs to fullyconnect to real".
I checked the individual loss values and the parameter values. They are purely real.
I hope your insight might help to resolve this issue as well.
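One way to apply the same factoring idea here (a sketch, not a verified fix): since f = C·N and hence fxx = C·Nxx, we have g = fxx + α·f = C·(Nxx + α·N), so the real and imaginary parts of g are real scalars times a purely real quantity. Forming the loss from that real quantity keeps every value flowing through fullyconnect real (N, Nxx, alpha, and parameters denote the same quantities as in the snippet above):

```matlab
% Sketch: factor C out of g so the traced graph stays real-valued.
% g = fxx + alpha*f = C*(Nxx + alpha*N), with N and Nxx purely real.
h  = Nxx + alpha*N;                          % real-valued combination
gr = real(C)*h;                              % real part of g
gi = imag(C)*h;                              % imaginary part of g
loss = l2loss(gr, zeros(size(gr),"like",gr)) + ...
       l2loss(gi, zeros(size(gi),"like",gi));
gradients = dlgradient(loss, parameters);
```

The key difference from the snippet above is that real(C) and imag(C) multiply a real dlarray, so no complex intermediate ever enters the autodiff trace.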
