How is the number of parameters calculated when a multi-head self-attention layer is used in a CNN model?

74 views (last 30 days)
I have run the example in the following link in two cases:

Case 1: NumHeads = 4, NumKeyChannels = 784
Case 2: NumHeads = 8, NumKeyChannels = 392

Note that 4×784 = 8×392 = 3136, the size of the input feature vector to the attention layer. I calculated the number of model parameters in both cases and got 9.8 M in the first case and 4.9 M in the second.
I expected the number of learnable parameters to be the same. However, MATLAB reports different parameter counts.
My understanding from research papers is that the total parameter count should not depend on how the input is split across heads: as long as the input feature vector is the same and the product of the number of heads and the per-head size (number of channels) equals the input size, the parameter count should be identical.
Why does MATLAB’s selfAttentionLayer produce different parameter counts for these two configurations? Am I misinterpreting how the layer is implemented in this toolbox?
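
For reference, here is the arithmetic that reproduces my two measured counts, under the assumption I want to verify: that the layer allocates query/key/value projections of size inputDim × NumKeyChannels (plus biases) and an output projection back to inputDim, with NumKeyChannels being the total width shared across all heads rather than a per-head size. This is a back-of-envelope sketch, not the toolbox's actual code:

inputDim = 3136;
configs  = [4 784; 8 392];                    % [NumHeads NumKeyChannels]
for c = 1:size(configs, 1)
    numKey = configs(c, 2);
    qkv = 3 * (inputDim*numKey + numKey);     % Q, K, V weights + biases (assumed shapes)
    out = numKey*inputDim + inputDim;         % output projection weights + bias (assumed)
    fprintf('NumHeads = %d, NumKeyChannels = %d: %.1f M parameters\n', ...
        configs(c, 1), numKey, (qkv + out)/1e6)
end

This prints 9.8 M and 4.9 M, matching what MATLAB reports, so under this reading halving NumKeyChannels halves the parameter count regardless of NumHeads.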
11 comments
Hana Ahmed 2025-9-5, 18:01
Thank you very much for your reply. One final question, please: if we have 8 parallel heads and each head has three projection matrices, should we expect to see 24 projection matrices in the workspace, or only the three matrices of a single head?
Umar 2025-9-6, 4:37

Hi @Hana Ahmed,

Even though each of the 8 heads conceptually has its own Q/K/V matrices, MATLAB stores them as three concatenated matrices. Each matrix is sliced internally to compute per-head projections, which is why you see only 3 matrices in the workspace instead of 24.

Script

close all; clearvars; clc
numHeads = 8; d_k = 64; inputDim = 512; batchSize = 10;
X = randn(batchSize, inputDim);

% One concatenated query projection for all heads: [512 x 512]
W_Q = randn(inputDim, numHeads*d_k);
Q_full = X * W_Q;                     % [10 x 512], all heads side by side

% Slice per head: head i owns a contiguous block of d_k columns
Q_heads = zeros(batchSize, numHeads, d_k);
for i = 1:numHeads
    idx = (i-1)*d_k + 1 : i*d_k;
    Q_heads(:, i, :) = Q_full(:, idx);
end
disp(size(Q_full))     % 10 512
disp(size(Q_heads))    % 10 8 64

Results:

    10   512
    10     8    64
Explanation:

  • `Q_full` shows all 8 heads concatenated.
  • `Q_heads` shows per-head slices (64 channels each).
  • This is mathematically equivalent to having separate matrices per head and is memory-efficient.
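
To make the "mathematically equivalent" claim concrete, here is a quick illustrative check (a sketch reusing X, W_Q, Q_full, numHeads, and d_k from the script above, not toolbox code): multiplying by a separate per-head matrix gives the same values as slicing the concatenated projection.

for i = 1:numHeads
    idx    = (i-1)*d_k + 1 : i*d_k;
    W_head = W_Q(:, idx);             % head i's "own" projection matrix
    assert(max(abs(X*W_head - Q_full(:, idx)), [], 'all') < 1e-10)
end
disp('Sliced and separate per-head projections agree.')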


Answers (0)
