Efficient Computation of Matrix Gradient
Hi,
I am trying to compute the gradient of a matrix-valued function . I have computed the element-wise gradient, whose $(p,q)$ entry is $\sum_{i,j} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij}$, and have verified that it is correct numerically (for my purposes of gradient descent).
My MATLAB implementation of the above gradient is:
for p = 1:N
    for q = 1:N
        gradX(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
which I have also verified is correct numerically.
However, my issue is that N = 750, so this computation is extremely slow and impractical for gradient descent: on my desktop with 32 GB RAM and an Intel Xeon 3.7 GHz processor, one iteration takes around 10-15 minutes. I expect to need several hundred iterations for convergence.
I was wondering if there is any obvious way I am missing to speed up or parallelize it. I have tried parfor but have not had any luck.
Thank you and I very much appreciate any suggestions.
2 Comments
Bruno Luong
2024-4-9
Edited: Bruno Luong
2024-4-9
What is the typical size of w (or AXB)?
By the way, the first obvious optimization is to precompute the element-wise product w .* AXB once, outside the loops.
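The suggestion in this comment can be sketched as follows. This is only an illustration on small random data: the size N = 20 and the names C, outerPQ, gradX_hoisted, gradX_ref and relErr are assumptions made for the example, not part of the original post.

```matlab
% Minimal sketch: hoist the element-wise product out of the double loop.
% Test data is random and small (N = 20) so the check runs quickly.
N   = 20;
w   = rand(N);
AXB = rand(N) + 1i*rand(N);
A   = rand(N) + 1i*rand(N);
B   = rand(N) + 1i*rand(N);

C = w .* AXB;                 % computed once, instead of N^2 times
gradX_hoisted = zeros(N);
gradX_ref     = zeros(N);
for p = 1:N
    for q = 1:N
        outerPQ = conj(A(:,p)) * conj(B(q,:));               % N-by-N outer product
        gradX_hoisted(p,q) = sum(C .* outerPQ, 'all');        % uses precomputed C
        gradX_ref(p,q)     = sum(w .* outerPQ .* AXB, 'all'); % original expression
    end
end

% Relative difference between the two versions (expected to be at rounding level)
relErr = norm(gradX_hoisted - gradX_ref, 'fro') / norm(gradX_ref, 'fro');
disp(relErr)
```

Hoisting removes one N-by-N element-wise multiply per iteration, but each iteration still builds an N-by-N outer product, so the overall cost of the double loop does not change in order.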
Accepted Answer
Bruno Luong
2024-4-9
The best method:
N = 200; % 750 in the actual problem
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
% Reference: original double loop
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
% Method 3: two matrix products
tic
C = w .* AXB;        % element-wise weighting, computed once
gradX = A' * C * B'; % ' is the conjugate transpose
t2=toc
% Relative error against the reference loop
err = norm(gradX(:)-gradX_1(:),'inf') / norm(gradX_1(:))
fprintf('New code version 3 is %g times faster\n', t1/t2)
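To see why two matrix products reproduce the double loop, it may help to write the loop body in index notation. The following is a sketch of that step. The symbols $G$ (for gradX), $C = w \circ (AXB)$ (the element-wise product computed in the code above) and the index names $i, j$ are notation introduced here for illustration, and $(\cdot)^{H}$ denotes the conjugate transpose, which is what MATLAB's `'` operator computes.

```latex
\begin{aligned}
G_{pq} &= \sum_{i=1}^{N}\sum_{j=1}^{N} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij} \\
       &= \sum_{i=1}^{N}\sum_{j=1}^{N} (A^{H})_{pi}\,C_{ij}\,(B^{H})_{jq} \\
       &= \left(A^{H}\,C\,B^{H}\right)_{pq}
\end{aligned}
```

Each of the $N^2$ loop iterations does work proportional to $N^2$, so the double loop costs on the order of $N^4$ operations, whereas two dense $N \times N$ matrix products cost on the order of $N^3$.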
More Answers (1)
Bruno Luong
2024-4-9
I propose this version, with timing tested for N = 200:
N = 200; % 750 in the actual problem
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
% Reference: original double loop
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
% Version 1: precompute w .* AXB and use a vector inner product
gradX = zeros(N,N);
tic
C = w .* AXB;
C = reshape(C,1,[]); % row vector of length N^2
for p = 1:N
    Ap = A(:,p);
    for q = 1:N
        AB = Ap * B(q,:); % outer product, not yet conjugated
        AB = reshape(AB,1,[]);
        gradX(p,q) = C * AB'; % ' conjugates AB, matching conj(A).*conj(B)
    end
end
t2=toc
fprintf('New code version 1 is %g times faster\n', t1/t2)
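Version 1 contains no explicit conj call because the `'` operator in `C * AB'` is the conjugate transpose. The toy example below is a sketch of that point only: the vectors c and x and the variable names are made up for illustration and stand in for the row vectors C and AB.

```matlab
% Toy data, chosen only for illustration.
c = [1+2i, 3-1i];      % plays the role of the row vector C
x = [2-1i, 1+4i];      % plays the role of the row vector AB

viaCtranspose = c * x';              % ' is the conjugate transpose
viaExplicit   = sum(c .* conj(x));   % explicit conjugation, as in the original loop
viaPlain      = c * x.';             % .' transposes without conjugating

disp(abs(viaCtranspose - viaExplicit))  % zero up to rounding
disp(abs(viaPlain - viaExplicit))       % clearly nonzero: the conjugation is missing
```

Replacing `'` with `.'` in version 1 would therefore silently compute a different quantity for complex A and B.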