How to add svd (singular value decomposition) to a custom layer
I am following this page https://www.mathworks.com/help/deeplearning/ug/define-custom-deep-learning-layer.html and trying to define a custom layer that contains the singular value decomposition function. However, when I check the validity of the layer, it shows the following error message:
" The function 'predict' threw an error:
Error using svd
First input must be single or double."
Is this because svd does not support the dlarray data type?
Is there any way to solve this problem?
Accepted Answer
Christine Tobler
2019-12-3
Yes, this is because svd is not supported for dlarray. In its first release (R2019b), dlarray supports about 80 basic methods; you can see them listed here. It would help us prioritize which new functions to add if you could give us some details about what you are using the SVD for.
Instead of relying on dlarray support in the predict method, you can use svd if you implement both a forward and a backward method for your custom layer.
More Answers (1)
Christine Tobler
2019-12-4
When dlarray supports a function, this means above all that it supports automatic differentiation of that function; for a custom layer, this corresponds to having a backward method.
It's possible to use extractdata on the input to svd. The only problem is that U, S, and V will then be treated as constants with respect to the trainable parameters: any dependence of the SVD input on the trainable parameters will not be passed on to U, S and V. If this is acceptable for your case, that would certainly be the simplest approach.
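A minimal sketch of the extractdata route, assuming the layer input is a single 2-D matrix (an image batch would first have to be reshaped into matrices). The variable names are illustrative only; inside a predict method the core of it is simply [U, S, V] = svd(extractdata(X)). The isa check is only there so the sketch also runs on a plain numeric array.

```matlab
% X stands in for the input that predict receives; in a real layer it
% would typically be a dlarray.
X = [4 1; 2 3; 0 1];

% extractdata (Deep Learning Toolbox) returns the underlying numeric array.
if isa(X, 'dlarray')
    Xdata = extractdata(X);
else
    Xdata = X;
end

% svd now receives single/double data, so the "First input must be single
% or double" error does not occur. U, S and V are plain numeric arrays and
% act as constants: no gradient flows back through them to X.
[U, S, V] = svd(Xdata, 'econ');

% Reconstruction error of the factorization (near machine precision).
reconErr = norm(U*S*V' - Xdata);
```

If the layer has to return a dlarray, the numeric result could be wrapped again with dlarray, but it would still carry no derivative information back to X.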
Otherwise, you would have to write a custom layer whose backward method computes the derivative of the SVD. If your layer only computes the SVD of its input, this would look roughly like this (adapted from the example here):
function [Z1, Z2, Z3, memory] = forward(layer, X1)
% (Optional) Forward input data through the layer at training
% time and output the result and a memory value.
%
% Inputs:
% layer - Layer to forward propagate through
% X1, ..., Xn - Input data
% Outputs:
% Z1, ..., Zm - Outputs of layer forward function
% memory - Memory value for custom backward propagation
% Compute the SVD of the layer input: Z1 = U, Z2 = S, Z3 = V.
[Z1, Z2, Z3] = svd(X1);
% No extra information is passed on to the backward method.
memory = [];
end
function dLdX1 = backward(layer, X1, Z1, Z2, Z3, dLdZ1, dLdZ2, dLdZ3, memory)
% (Optional) Backward propagate the derivative of the loss
% function through the layer.
%
% Inputs:
% layer - Layer to backward propagate through
% X1, ..., Xn - Input data
% Z1, ..., Zm - Outputs of layer forward function
% dLdZ1, ..., dLdZm - Gradients propagated from the next layers
% memory - Memory value from forward function
% Outputs:
% dLdX1, ..., dLdXn - Derivatives of the loss with respect to the
% inputs
% dLdW1, ..., dLdWk - Derivatives of the loss with respect to each
% learnable parameter
% Layer backward function goes here.
end
Here the backward function takes the input A (X1), the factors U, S, V (Z1, Z2, Z3) and the derivatives of the loss with respect to U, S and V (dL/dU, dL/dS, dL/dV), and needs to compute the derivative of the loss with respect to A (dL/dA). I'm not quite sure how to compute this right now; some research would be needed.
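For reference, a sketch of one closed-form backward pass, assuming the forward method uses the economy-size svd(X1, 'econ') (not the full svd(X1) above) and that A has full rank with distinct singular values, since otherwise the formula divides by zero. It implements the matrix-backpropagation formula commonly used for the thin SVD on plain numeric arrays and checks it against central finite differences. The names makeF, svdBackward and lossFcn are hypothetical; in a layer, dLdX1 would correspond to svdBackward(Z1, Z2, Z3, dLdZ1, dLdZ2, dLdZ3).

```matlab
% F(i,j) = 1/(s_j^2 - s_i^2) for i ~= j and 0 on the diagonal
% (adding eye only avoids 0/0 on the diagonal).
makeF = @(s) (1 - eye(numel(s))) ./ (s(:)'.^2 - s(:).^2 + eye(numel(s)));

% dL/dA from thin SVD factors U (m-by-k), S (k-by-k), V (n-by-k) and the
% incoming gradients dU, dS, dV of the same sizes.
svdBackward = @(U, S, V, dU, dS, dV) ...
    U * ((makeF(diag(S)) .* (U'*dU - dU'*U)) * S ...
        + diag(diag(dS)) ...
        + S * (makeF(diag(S)) .* (V'*dV - dV'*V))) * V' ...
    + (eye(size(U, 1)) - U*U') * dU / S * V' ...
    + U / S * dV' * (eye(size(V, 1)) - V*V');

% Test problem: a tall 4-by-3 matrix with distinct singular values.
A = [2 1 0; -1 3 1; 0.5 -2 4; 1 0 1];

% Hypothetical loss that is invariant to the column signs of U and V
% (which svd does not fix): L = w'*s + <C, U*D1*U'> + <E, V*D2*V'>.
w  = [1; -2; 0.5];
D1 = diag([1 2 3]);
D2 = diag([-1 0.5 2]);
C  = reshape(1:16, 4, 4) / 10;
E  = magic(3) / 10;
lossFcn = @(U, S, V) w' * diag(S) + sum(sum(C .* (U*D1*U'))) ...
    + sum(sum(E .* (V*D2*V')));

% Analytic gradients of this loss with respect to U, S and V.
[U, S, V] = svd(A, 'econ');
dU = (C + C') * U * D1;
dS = diag(w);
dV = (E + E') * V * D2;
G  = svdBackward(U, S, V, dU, dS, dV);

% Central finite differences of the loss with respect to each entry of A.
h = 1e-6;
Gfd = zeros(size(A));
for idx = 1:numel(A)
    Ap = A; Ap(idx) = Ap(idx) + h;
    Am = A; Am(idx) = Am(idx) - h;
    [Up, Sp, Vp] = svd(Ap, 'econ');
    [Um, Sm, Vm] = svd(Am, 'econ');
    Gfd(idx) = (lossFcn(Up, Sp, Vp) - lossFcn(Um, Sm, Vm)) / (2*h);
end
relErr = norm(G - Gfd) / norm(Gfd);
```

The loss deliberately depends on U and V only through U*D1*U' and V*D2*V', so the finite-difference check is not disturbed by svd flipping the sign of a singular vector between calls.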
Damien T
2021-1-14
Edited: Damien T
2021-1-14
Hi Christine,
The derivative of the SVD is fairly easy to implement; see reference [1] below. I am personally interested in having the determinant supported by automatic differentiation. Its backward pass is also trivial to implement (although it requires a matrix inversion). Right now I'm using your suggestion, i.e. a custom layer with an explicit backward() function, but this raises two questions:
- In order to evaluate my custom layer, do I need to encapsulate it in a dlgraph/dlnetwork (with an additional input layer) so that I can then call .predict() on this network? Is there anything simpler? It would be nice if a call like y = myLayer.predict(x); were enough.
- I am using the second-order derivatives of the AD in the R2021a prerelease (it's a great addition: I am making great use of them). But does this support layers with a custom backward() function?
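The determinant case mentioned above can be sketched the same way, assuming a square, invertible input. By Jacobi's formula, the derivative of det(A) with respect to A is det(A)*inv(A)', so the backward pass needs only the upstream gradient and one linear solve (done here with backslash rather than an explicit inv). The name detBackward is hypothetical; in a custom layer it would be called from backward with X1, Z1 and dLdZ1.

```matlab
% Forward: z = det(A).
% Backward (Jacobi's formula): dL/dA = dL/dz * det(A) * inv(A)'.
% A' \ eye(n) equals inv(A') = inv(A)'.
detBackward = @(A, z, dLdz) dLdz * z * (A' \ eye(size(A, 1)));

A = [3 1 0; 2 5 1; -1 0 4];   % invertible test matrix
z = det(A);
dLdz = 0.7;                   % arbitrary upstream gradient for the check
G = detBackward(A, z, dLdz);

% Central finite differences; det is linear in each single entry, so this
% is accurate up to rounding.
h = 1e-6;
Gfd = zeros(size(A));
for idx = 1:numel(A)
    Ap = A; Ap(idx) = Ap(idx) + h;
    Am = A; Am(idx) = Am(idx) - h;
    Gfd(idx) = dLdz * (det(Ap) - det(Am)) / (2*h);
end
relErr = norm(G - Gfd) / norm(Gfd);
```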