- Convert Inputs to dlarray: tw and xw must be of type dlarray before they are passed to model.net_u, so that the operations applied to them can be traced for gradient computation.
- Define a Gradient Computation Function: The computeGradients function runs the model and computes the gradients. dlgradient can only differentiate a real scalar, so each output is first reduced with sum(..., 'all'). As long as each row of a, u, and p depends only on the same row of tw and xw, the gradient of the sum still contains the pointwise derivatives, which is also what tf.GradientTape.gradient returns for a non-scalar target.
- Use dlfeval: computeGradients is called through dlfeval, which makes the automatic differentiation engine trace the operations inside computeGradients.
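As a minimal sketch of the last two points, the following uses a hypothetical three-element input and the element-wise function x.^3 as a stand-in for model.net_u. Summing the output gives dlgradient the scalar it needs, and because each output element depends only on the matching element of x, the gradient of the sum is the element-wise derivative 3*x.^2.

```matlab
% Hypothetical sample data: one value per row, like the 39600-by-1 inputs
x = dlarray([0.5; 1.0; 2.0]);

% dlgradient needs a scalar, so the element-wise output x.^3 is summed first.
% Row i of x.^3 depends only on x(i), so the gradient of the sum is 3*x.^2.
gradFun = @(x) dlgradient(sum(x.^3, 'all'), x);

% dlfeval traces the dlarray operations performed inside gradFun
dydx = dlfeval(gradFun, x);
disp(extractdata(dydx))   % [0.75; 3; 12]
```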
Difficulty converting a TensorFlow model to MATLAB
I have part of a TensorFlow code that I need to translate to MATLAB, but I have not managed to do it. I have looked through the Deep Learning Toolbox but could not resolve the issue. Any help with this question would be much appreciated.
My TensorFlow code (Python) is the following:
def get_r(model, tw, xw, a_data_n, mean_a, mean_u, kP, dim1, dim2, w1, stdt, stdx):
    # A tf.GradientTape is used to compute derivatives in TensorFlow
    with tf.GradientTape(persistent=True) as tape:  # Records the operations on watched tensors so gradients can be computed afterwards
        tape.watch(tw)  # Needed to 'follow' the time, for automatic differentiation with respect to time
        tape.watch(xw)  # Needed to 'follow' the position, for automatic differentiation with respect to position
        a, u, p = model.net_u(tw, xw)
    Px = tape.gradient(p, xw)
    At = tape.gradient(a, tw)
    ux = tape.gradient(u, xw)
    ut = tape.gradient(u, tw)
My MATLAB code is the following (note that model.net_u cannot accept dlarray inputs, so the conversion to dlarray has to happen after the call to model.net_u):
function get_r(model, tw, xw, a_data_n, mean_a, mean_u, kP, dim1, dim2, w1, stdt, stdx)
    % Compute derivatives using the MATLAB automatic differentiation functionality
    % Run the model
    [a, u, p] = model.net_u(tw, xw);
    a = dlarray(a);
    tw = dlarray(tw);
    xw = dlarray(xw);
    % Compute gradients by iterating over each element
    At = dlgradient(a, tw);
    Px = dlgradient(p, xw);
    ux = dlgradient(u, xw);
    ut = dlgradient(u, tw);
end
Also, my variables a, u, and p all have size 39600-by-1.
My error message is:
Error using dlarray/dlgradient (line 105)
Value to differentiate is not traced. It must be a traced real dlarray scalar. Use dlgradient inside a function called by dlfeval to trace the variables.
Error in get_r (line 12)
At = dlgradient(sum(a, 'all'), tw);
Can anyone show me how to fix the code, for example by adding dlfeval and whatever else is needed?
Additionally, when I add dlfeval, I sometimes get the following error. What does this error mean?
Error using deep.internal.dlfevalWithNestingCheck (line 14)
Nested dlfeval calls are not supported. To compute higher derivatives, set the 'EnableHigherDerivatives' option of the dlgradient function to true.
Error in dlfeval (line 31)
[varargout{1:nargout}] = deep.internal.dlfevalWithNestingCheck(fun,varargin{:});
Thanks for all suggestions!
Answers (1)
Abhishek Kumar Singh
2024-7-28
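The error messages you're encountering mean that dlgradient is being called in the wrong context. MATLAB's automatic differentiation only records operations on dlarray values, and only while the code runs inside a function called by dlfeval. In your code, tw and xw are converted to dlarray after model.net_u has already run on plain numeric arrays, so nothing links a, u, or p back to tw and xw, and dlgradient is also called outside dlfeval. The inputs therefore have to be dlarray before the forward pass, which means model.net_u must be built from operations that support dlarray (for example, a dlnetwork or functions from the list of dlarray-supported functions); otherwise its outputs cannot be differentiated at all. In addition, dlgradient needs a real scalar to differentiate, so each 39600-by-1 output has to be summed first.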
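To illustrate the required order, here is a minimal sketch in which uFun is a hypothetical stand-in for the u output of model.net_u, built only from dlarray-supported operations (sin and element-wise multiplication). The inputs are wrapped in dlarray first, the computation runs inside dlfeval, and one dlgradient call returns the derivatives with respect to both inputs, the analogs of ut and ux.

```matlab
% Hypothetical sample points standing in for the 39600-by-1 tw and xw
t = linspace(0, 1, 4)';
x = linspace(-1, 1, 4)';

% Wrap the inputs in dlarray BEFORE the forward pass so dlfeval can trace them
tw = dlarray(t);
xw = dlarray(x);

% Hypothetical stand-in for the u output of model.net_u; it uses only
% dlarray-supported operations, as the real model would have to
uFun = @(tw, xw) sin(xw) .* tw;

% Sum to a scalar, then differentiate with respect to both inputs in one call
gradU = @(tw, xw) dlgradient(sum(uFun(tw, xw), 'all'), tw, xw);
[ut, ux] = dlfeval(gradU, tw, xw);

% Analytically, du/dt = sin(x) and du/dx = t .* cos(x); the columns should match
disp([extractdata(ut), sin(x)])
disp([extractdata(ux), t .* cos(x)])
```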
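The second error means that dlfeval was called from a function that was itself already running under dlfeval, for example when get_r is called from a loss function that your training loop evaluates with dlfeval; such nested calls are not supported. If get_r is called on its own, here's how you can rewrite it so that all gradients are computed inside a single dlfeval call (this assumes model.net_u accepts dlarray inputs):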
function [At, Px, ux, ut] = get_r(model, tw, xw, a_data_n, mean_a, mean_u, kP, dim1, dim2, w1, stdt, stdx)
    % Convert inputs to dlarray so that operations on them can be traced
    tw = dlarray(tw);
    xw = dlarray(xw);
    % dlfeval evaluates computeGradients with tracing enabled;
    % model is passed through as an ordinary argument
    [At, Px, ux, ut] = dlfeval(@computeGradients, model, tw, xw);
end

function [At, Px, ux, ut] = computeGradients(model, tw, xw)
    % Run the model (model.net_u must support dlarray inputs)
    [a, u, p] = model.net_u(tw, xw);
    % dlgradient needs a real scalar, so sum each output first;
    % 'RetainData' keeps the trace for the subsequent dlgradient calls
    At = dlgradient(sum(a, 'all'), tw, 'RetainData', true);
    Px = dlgradient(sum(p, 'all'), xw, 'RetainData', true);
    ux = dlgradient(sum(u, 'all'), xw, 'RetainData', true);
    ut = dlgradient(sum(u, 'all'), tw);
end
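If get_r is instead called from inside a loss function that is already evaluated with dlfeval (for example, in a custom training loop), the inner dlfeval is what triggers the "Nested dlfeval calls are not supported" error. A sketch of the fix is to drop the inner dlfeval, call dlgradient directly, and pass 'EnableHigherDerivatives', true so that the derivatives can themselves be differentiated with respect to the network parameters. The minimal example below assumes a hypothetical scalar parameter w and a hypothetical model u = w*t.^2 in place of model.net_u: the residual is du/dt = 2*w*t, the loss is sum(residual.^2), and its gradient with respect to w is computed inside one outer dlfeval call.

```matlab
% Hypothetical trainable parameter and sample points
w  = dlarray(0.5);
t  = linspace(0, 1, 4)';
tw = dlarray(t);

% Hypothetical model u = w*t.^2. The residual du/dt = 2*w*t is computed with
% dlgradient directly (no inner dlfeval); 'EnableHigherDerivatives' keeps it
% traced so that it can be differentiated again with respect to w.
residual = @(w, tw) dlgradient(sum(w .* tw.^2, 'all'), tw, ...
    'EnableHigherDerivatives', true);

% Loss built from the residual, and its gradient with respect to w
lossFun = @(w, tw) sum(residual(w, tw).^2, 'all');
gradFun = @(w, tw) dlgradient(lossFun(w, tw), w);

% A single (outer) dlfeval call, as in a custom training loop
dLdw = dlfeval(gradFun, w, tw);
disp(extractdata(dLdw))   % 8*w*sum(t.^2) = 56/9, about 6.2222
```

In get_r this would mean replacing the dlfeval line with the direct dlgradient calls (with 'EnableHigherDerivatives', true) and calling get_r from the loss function that the training loop passes to dlfeval.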
Here's a brief explanation of the code snippet above: