I'm using a Vision Transformer (ViT) in my code. How can I convert the 1-D (token sequence) output of a ViT layer into a 2-D feature map with format SSCB?
Hi Abdulrahman,
I could not run the code myself because visionTransformer requires Computer Vision Toolbox. To illustrate how to resolve your error, I adapted the code you took from the MathWorks example so that the encoder output is reshaped into a 24 x 24 x 768 feature map. Here is the updated code, step by step:
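% Get the pretrained Vision Transformer model (the default model takes 384 x 384 x 3 images)
net = visionTransformer;
% Create a dummy image input: 384 x 384 pixels, 3 channels, batch size 1
input = dlarray(rand(384,384,3),'SSCB');
% Obtain the output embedding from the last LayerNormalizationLayer
out = forward(net, input, 'Outputs', 'encoder_norm');
% Reshape the output patch embedding into an H x W x C x N feature map
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
% Remove the dlarray format so that plain indexing, reshape, and permute apply;
% this assumes the tokens (class token + patches) lie along the first dimension
out = stripdims(in);
% Remove the output embedding corresponding to the class token
out = out(2:end,:,:);
% Reshape the remaining patch embeddings into a square grid
WH = sqrt(size(out, 1)); % 576 patches -> 24 x 24 grid
C = size(out, 2); % Embedding dimension, 768 for the base model
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
% Convert to a formatted dlarray
out = dlarray(out, 'SSCB');
end
In this version, the dummy input is a 384 x 384 x 3 image, because that is the input size of the default visionTransformer model; a 24 x 24 x 768 array does not match the network's image input layer. With 16 x 16 patches, the image is split into 384/16 = 24 patches per side, so the encoder output contains 1 + 24 x 24 = 577 tokens with 768 channels each. reshapePatchEmbedding drops the class token and rearranges the remaining 576 x 768 embeddings into a 24 x 24 x 768 x N array with format SSCB. Please let me know if this helps resolve your issue.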
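The drop, reshape, and permute steps in reshapePatchEmbedding can be checked without Computer Vision Toolbox by running them on a plain numeric array. This is only a sketch under stated assumptions: a 384 x 384 input with 16 x 16 patches (24 x 24 = 576 patch tokens plus one class token), and patches flattened row by row, which is the ordering the permute step implies. Each fake token stores a code built from its grid row and column, and the channel count is reduced to 4 just to keep the check small.

```matlab
% Sanity check of the token-to-grid mapping on a plain numeric array
% (no toolboxes needed). Assumptions: 384 x 384 input, 16 x 16 patches,
% and patches flattened row by row (column index varies fastest).
imageSize  = 384;                    % assumed input height and width
patchSize  = 16;                     % assumed patch size (ViT-B/16)
gridSize   = imageSize / patchSize;  % 24 patches per side
numPatches = gridSize^2;             % 576 patch tokens
C = 4;                               % small channel count, only for the check

% Fake encoder output of size (1 + numPatches) x C: row 1 is the class
% token, row k + 1 holds patch k and stores the code 1000*row + col
tokens = zeros(1 + numPatches, C);
tokens(1, :) = -1;                   % class token marker
for row = 1:gridSize
    for col = 1:gridSize
        k = (row - 1)*gridSize + col;            % row-major patch index
        tokens(k + 1, :) = 1000*row + col;
    end
end

% Same steps as reshapePatchEmbedding, without dlarray
x = tokens(2:end, :, :);                      % drop the class token
x = reshape(x, gridSize, gridSize, C, []);    % W x H x C x N
x = permute(x, [2 1 3 4]);                    % H x W x C x N

% Location (row, col) of the feature map should hold the code of patch (row, col)
[rowIdx, colIdx] = ndgrid(1:gridSize, 1:gridSize);
expected = 1000*rowIdx + colIdx;
disp(isequal(x(:, :, 1), expected))           % displays 1 (true)
```

If feature maps from the real network look transposed relative to the input image, the encoder probably flattens patches column by column instead; in that case the reshape alone already yields H x W and the permute step should be dropped.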