Computing Hessian by dlgradient

MAHSA YOUSEFI 2022-2-4
Answered: Yash 2023-12-18
Hi everyone.
I am using a training loop for my model in which gradients are computed by dlgradient. As you know, dlgradient (through dlfeval) returns a table in which the layers, parameters (weights and biases) and gradient values are stored. We also know that dlgradient takes the "loss" as a scalar, together with dlnet.Learnables, and that the loss is computed from the data samples dlX and targets dlY. I am interested in computing the Hessian for a small network using dlX and dlY. In fact, I am going to compute a sub-sampled Hessian if I use a mini-batch dlX (so storing this matrix is not a problem). However, I do not know how to apply dlgradient a second time to compute the Hessian. If someone knows, I would be thankful.

Answers (1)

Yash 2023-12-18
Hi Mahsa,
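To compute the Hessian with dlgradient, differentiate twice inside a single function that is evaluated by dlfeval. dlgradient takes a traced scalar and the variables to differentiate with respect to, so first call it on the scalar loss with the 'EnableHigherDerivatives' option set to true; this keeps the returned gradients on the trace. Then call dlgradient again on each scalar element of the gradient, with 'RetainData' set to true because the same trace is reused. Each of these second calls returns one row of the Hessian, that is, the second-order partial derivatives of the loss with respect to one parameter paired with every other parameter. Not every layer or operation necessarily supports higher-order derivatives in every release, so check the dlgradient documentation for the release you are using.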
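In symbols, this is a sketch of what the two differentiation passes compute. The notation is introduced here for illustration only: θ stands for all learnables stacked into one vector of length n, and L for the scalar loss on the mini-batch.

```latex
g(\theta) = \nabla_\theta L(\theta) \in \mathbb{R}^n, \qquad
H_{ij} = \frac{\partial^2 L}{\partial \theta_i \, \partial \theta_j}
       = \frac{\partial g_i}{\partial \theta_j}, \qquad i, j = 1, \dots, n.
```

Row i of H is therefore the gradient of the scalar g_i, which is why one extra dlgradient call per parameter is needed. When L is evaluated on a mini-batch, H is the sub-sampled Hessian.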
Here is a code sketch you can use as a reference to understand what I want to convey.
It assumes that dlnet is your dlnetwork, dlX and dlY are your formatted data samples and targets, and mse is your loss function.
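% Call as: [loss, H] = dlfeval(@modelHessian, dlnet, dlX, dlY);
function [loss, H] = modelHessian(dlnet, dlX, dlY)
% Compute the scalar loss from the network output
loss = mse(forward(dlnet, dlX), dlY);
% Compute the gradients for each parameter and keep them on the trace
grads = dlgradient(loss, dlnet.Learnables, 'RetainData', true, 'EnableHigherDerivatives', true);
% Flatten the gradient table into a single column vector
g = cellfun(@(v) reshape(stripdims(v), [], 1), grads.Value, 'UniformOutput', false);
g = vertcat(g{:});
% Compute the Hessian matrix, one row per scalar parameter
n = numel(g);
H = zeros(n, n);
for i = 1:n
% Differentiate the i-th gradient entry with respect to all learnables
row = dlgradient(g(i), dlnet.Learnables, 'RetainData', true);
r = cellfun(@(v) reshape(extractdata(v), [], 1), row.Value, 'UniformOutput', false);
H(i, :) = vertcat(r{:})';
end
end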
The grads variable contains the gradients for each parameter, and the H variable contains the Hessian matrix.
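Before applying this to a network, it may help to check the double-differentiation pattern on a function whose Hessian is known. The sketch below is only an illustration: the function names hessianCheck and quadHessian, the matrix A and the vectors b and x are all made up for this check. It uses f(x) = 0.5*x'*A*x + b'*x with a symmetric A, for which the Hessian is A itself. It assumes a release in which dlgradient supports the 'EnableHigherDerivatives' option.

```matlab
function H = hessianCheck()
% Hypothetical self-check: for symmetric A, the Hessian of
% f(x) = 0.5*x'*A*x + b'*x is A.
A = [4 1 0; 1 3 -1; 0 -1 2];
b = [1; -2; 0.5];
x = dlarray([0.3; -0.7; 1.1]);
H = dlfeval(@quadHessian, x, A, b);
% Largest absolute deviation from A; expected to be close to zero
disp(max(abs(H - A), [], 'all'));
end

function H = quadHessian(x, A, b)
% Scalar objective built from traced dlarray operations
f = 0.5 * (x' * A * x) + b' * x;
% First derivatives, kept on the trace for a second differentiation
g = dlgradient(f, x, 'RetainData', true, 'EnableHigherDerivatives', true);
n = numel(x);
H = zeros(n);
for i = 1:n
    % Row i of the Hessian is the gradient of g(i) with respect to x
    row = dlgradient(g(i), x, 'RetainData', true);
    H(i, :) = extractdata(row)';
end
end
```

Save this as hessianCheck.m and run hessianCheck. If the displayed deviation is small, the same pattern can be moved into the network loss function. The cost grows with the number of parameters, since each Hessian row needs its own dlgradient call, so this is practical only for small networks.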
Hope this helps!

Release

R2021a
