attention
Syntax
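The calling forms used in the examples and argument descriptions on this page are:

[Y,weights] = attention(queries,keys,values,numHeads)
[Y,weights] = attention(queries,keys,values,numHeads,Name=Value)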
Description
The attention operation focuses on parts of the input using weighted multiplication operations.
Examples
Apply Attention Operation
Specify the sizes of the queries, keys, and values.
querySize = 100;
valueSize = 120;
numQueries = 64;
numValues = 80;
numObservations = 32;
Create random arrays containing the queries, keys, and values. For the queries, specify the dlarray format "CBT" (channel, batch, time).
queries = dlarray(rand(querySize,numObservations,numQueries),"CBT");
keys = dlarray(rand(querySize,numObservations,numValues));
values = dlarray(rand(valueSize,numObservations,numValues));
Specify the number of attention heads.
numHeads = 5;
Apply the attention operation.
[Y,weights] = attention(queries,keys,values,numHeads);
View the sizes and format of the output.
size(Y)
ans = 1×3
120 32 64
dims(Y)
ans = 'CBT'
View the sizes and format of the weights.
size(weights)
ans = 1×4
80 64 5 32
dims(weights)
ans = 0×0 empty char array
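Because the softmax normalization produces, for each query, a probability distribution over the keys, the weights should sum to one along the first dimension. A quick check (the variable s is illustrative and not part of the original example):

% Sum the attention weights over the keys dimension (first dimension).
% Each sum should be 1 up to floating-point round-off.
s = sum(weights,1);
max(abs(extractdata(s) - 1),[],"all")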
Create Multihead Self Attention Function
You can use the attention function to implement the multihead self-attention operation [1] that focuses on parts of the input.
Create the multiheadSelfAttention function, listed in the Multihead Self Attention Function section of the example. The multiheadSelfAttention function takes as input the data X, the number of heads, and the learnable weights for the queries, keys, values, and output data, and returns the multihead attention values.
The X input must be an unformatted dlarray object, where the first dimension corresponds to the input channels, the second dimension corresponds to the time or spatial dimension, and the third dimension corresponds to the batch dimension.
Create an array of sequence data.
numChannels = 10;
numObservations = 128;
numTimeSteps = 100;

X = rand(numChannels,numObservations,numTimeSteps);
X = dlarray(X);
size(X)
ans = 1×3
10 128 100
Specify the number of heads for multihead attention.
numHeads = 8;
Initialize the learnable parameters for multihead attention.
- The learnable query, key, and value weights must be (numChannels*numHeads)-by-numChannels arrays.
- The learnable output weights must be a (numChannels*numHeads)-by-(numChannels*numHeads) array.
outputSize = numChannels*numHeads;

WQ = rand(outputSize,numChannels);
WK = rand(outputSize,numChannels);
WV = rand(outputSize,numChannels);
WO = rand(outputSize,outputSize);
Apply the multihead self attention operation.
Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO);
View the size of the output. The output has size (numChannels*numHeads)-by-numObservations-by-numTimeSteps.
size(Y)
ans = 1×3
80 128 100
Multihead Self Attention Function
The multiheadSelfAttention function takes as input the data X, the number of heads, and the learnable weights for the queries, keys, values, and output data, and returns the multihead attention values.
- The X input must be an unformatted dlarray object, where the first dimension corresponds to the input channels, the second dimension corresponds to the time or spatial dimension, and the third dimension corresponds to the batch dimension.
- The learnable query, key, and value weight matrices are (numChannels*numHeads)-by-numChannels matrices.
- The learnable output weights matrix is a (numChannels*numHeads)-by-(numChannels*numHeads) matrix.
function Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO)

% Project the input into queries, keys, and values for each observation.
queries = pagemtimes(WQ,X);
keys = pagemtimes(WK,X);
values = pagemtimes(WV,X);

% Apply multihead attention to the unformatted data.
A = attention(queries,keys,values,numHeads,DataFormat="CTB");

% Combine the attention heads using the output weights.
Y = pagemtimes(WO,A);

end
Create Luong Attention Function
You can use the attention function to create a function that applies the Luong attention operation to its input. Create the luongAttention function, listed at the end of the example.
Specify the array sizes.
numHiddenUnits = 100; latentSize = 16;
Create random arrays containing the input data.
hiddenState = dlarray(rand(numHiddenUnits,1));
Z = dlarray(rand(latentSize,1));
weights = dlarray(rand(numHiddenUnits,latentSize));
Apply the luongAttention function.
[context,attentionScores] = luongAttention(hiddenState,Z,weights);
View the sizes of the outputs.
size(context)
ans = 1×2
16 1
size(attentionScores)
ans = 1×2
1 1
Luong Attention Function
The luongAttention function returns the context vector and attention scores according to the Luong "general" scoring [2]. This operation is equivalent to dot-product attention with the queries, keys, and values specified as the hidden state, the weighted latent representation, and the latent representation, respectively.
function [context,attentionScores] = luongAttention(hiddenState,Z,weights)

numHeads = 1;

% Use the hidden state as the queries, the weighted latent representation
% as the keys, and the latent representation as the values.
queries = hiddenState;
keys = pagemtimes(weights,Z);
values = Z;

[context,attentionScores] = attention(queries,keys,values,numHeads, ...
    Scale=1,DataFormat="CBT");

end
Input Arguments
queries — Queries
dlarray object
Queries, specified as a dlarray object.
queries can have at most one "S" (spatial) or "T" (time) dimension. Any dimensions in queries labeled "U" (unspecified) must be singleton. If queries is an unformatted dlarray object, then specify the data format using the DataFormat option.
The size of the "C" (channel) dimension in keys must match the size of the corresponding dimension in queries.
The size of the "B" (batch) dimension in queries, keys, and values must match.
keys — Keys
dlarray object | numeric array
Keys, specified as a dlarray object or a numeric array.
If keys is a formatted dlarray object, then its format must match the format of queries. If keys is not a formatted dlarray object, then the function uses the same format as queries.
The size of any "S" (spatial) or "T" (time) dimensions in keys must match the size of the corresponding dimension in values.
The size of the "C" (channel) dimension in keys must match the size of the corresponding dimension in queries.
The size of the "B" (batch) dimension in queries, keys, and values must match.
values — Values
dlarray object | numeric array
Values, specified as a dlarray object or a numeric array.
If values is a formatted dlarray object, then its format must match the format of queries. Otherwise, the function uses the same format as queries.
The size of any "S" (spatial) or "T" (time) dimensions in keys must match the size of the corresponding dimension in values.
The size of the "B" (batch) dimension in queries, keys, and values must match.
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose Name in quotes.
Example: attention(queries,keys,values,numHeads,DataFormat="CBT") applies the attention operation for unformatted data and specifies the data format "CBT" (channel, batch, time).
DataFormat — Dimension order of unformatted data
character vector | string scalar
Dimension order of unformatted input data, specified as a character vector or string scalar FMT that provides a label for each dimension of the data.
When you specify the format of a dlarray object, each character provides a label for each dimension of the data and must be one of these options:
- "S" — Spatial
- "C" — Channel
- "B" — Batch (for example, samples and observations)
- "T" — Time (for example, time steps of sequences)
- "U" — Unspecified
You can use the labels "C" and "B" at most once, and at most one dimension labeled either "S" or "T".
You must specify DataFormat when the input data is not a formatted dlarray object.
Data Types: char | string
Scale — Multiplicative factor for scaled dot-product attention
"auto" (default) | numeric scalar
Multiplicative factor for scaled dot-product attention [1], specified as one of these values:
"auto"
— Multiply the dot-product by , where dk denotes the number of channels in the keys divided by the number of heads.Numeric scalar — Multiply the dot-product by the specified scale factor.
Data Types: single | double | char | string
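For example, with the arrays from the Apply Attention Operation example on this page (keys with 100 channels and numHeads equal to 5), the "auto" option corresponds to an explicit scale of 1/sqrt(100/5). A minimal sketch, assuming queries, keys, values, and numHeads are defined as in that example:

% Default "auto" scaling.
Y1 = attention(queries,keys,values,numHeads);

% Equivalent explicit scale: 1/sqrt(dk), where dk is the number of
% channels in the keys divided by the number of heads.
dk = size(keys,1)/numHeads;
Y2 = attention(queries,keys,values,numHeads,Scale=1/sqrt(dk));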
PaddingMask — Mask indicating padding values
dlarray object | logical array | binary-valued numeric array
Mask indicating which elements of the input correspond to padding values, specified as a dlarray object, a logical array, or a binary-valued numeric array.
The function prevents attention to elements of the input key-value pairs when the corresponding element in PaddingMask is 0, and allows attention when the corresponding element is 1.
If PaddingMask is a formatted dlarray object, then its format must match that of keys. If PaddingMask is not a formatted dlarray object, then the function uses the same format as keys. The size of the "S" (spatial), "T" (time), and "B" (batch) dimensions in PaddingMask must match the size of the corresponding dimensions in keys and values.
The default value is a logical array of ones with the same size as keys.
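For example, to ignore trailing padded time steps in the keys and values from the Apply Attention Operation example on this page, you can zero out the corresponding positions of a mask. A sketch, assuming queries, keys, values, and numHeads are defined as in that example and that the last 10 time steps are padding (an assumption made here for illustration):

% Binary mask with the same size as keys: 1 where attention is allowed,
% 0 for padded positions. Here keys inherits the "CBT" format of queries,
% so the third dimension is the time dimension.
mask = ones(size(keys));
mask(:,:,end-9:end) = 0;

Y = attention(queries,keys,values,numHeads,PaddingMask=mask);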
AttentionMask — Attention mask
"none" (default) | "causal" | numeric array | logical array
Attention mask indicating which elements to include when applying the attention operation, specified as one of these values:
"none"
— Do not prevent attention to elements with respect to their positions. IfAttentionMask
is"none"
, then the software prevents attention usingPaddingMask
only."causal"
— Prevent elements in position M in the"S"
(spatial) or"T"
(time) dimension ofqueries
from providing attention to the elements in positions n, where n is greater than M in the corresponding dimension ofkeys
andvalues
. Use this option for auto-regressive models.Logical or numeric array — Prevent attention to elements of
keys
andvalues
when the corresponding element in the array is0
, whereAttentionMask
is a Nk-by-Nq matrix or a Nk-by-Nq-by-numObservations
array, Nk is the size of the"S"
(spatial) or"T"
(time) dimension ofkeys
, Nq is the size of the corresponding dimension inqueries
, andnumObservations
is the size of the"B"
dimension inqueries
.
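For example, causal masking can be requested by name or supplied as an explicit logical array. A minimal sketch, assuming the queries, keys, values, and numHeads from the Apply Attention Operation example on this page:

% Built-in causal masking.
Ycausal = attention(queries,keys,values,numHeads,AttentionMask="causal");

% Equivalent explicit mask: allow query position m to attend to key
% position n only when n <= m. The time dimension is the third dimension
% because the data uses the "CBT" format.
Nk = size(keys,3);
Nq = size(queries,3);
mask = (1:Nk)' <= (1:Nq);
Yexplicit = attention(queries,keys,values,numHeads,AttentionMask=mask);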
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | logical | char | string
DropoutProbability — Dropout probability
0 (default) | scalar in the range [0, 1)
Dropout probability for the attention weights, specified as a scalar in the range [0, 1).
Data Types: single | double
Output Arguments
Y — Result of attention operation
dlarray object
Result of attention operation, returned as a dlarray object.
If queries is a formatted dlarray object, then Y is a formatted dlarray object with the same dimension labels as queries. The size of the "C" (channel) dimension of Y is the same as the size of the corresponding dimension in values. The size of the "S" (spatial) or "T" (time) dimension of Y is the same as the size of the corresponding dimension in queries.
If queries is not a formatted dlarray object, then Y is an unformatted dlarray object.
weights — Attention weights
unformatted dlarray object
Attention weights, returned as an unformatted dlarray object.
weights is an Nk-by-Nq-by-numHeads-by-numObservations array, where Nk is the size of the "S" (spatial) or "T" (time) dimension of keys, Nq is the size of the corresponding dimension in queries, and numObservations is the size of the "B" (batch) dimension in queries.
Algorithms
Dot-Product Attention
The attention operation focuses on parts of the input using weighted multiplication operations.
The single-head dot-product attention operation is given by

attention(Q,K,V) = V · dropout(softmax(mask(λKᵀQ, M)), p)

where:
- Q denotes the queries.
- K denotes the keys.
- V denotes the values.
- λ denotes the scaling factor.
- M is a mask array of ones and zeros.
- p is the dropout probability.
The mask operation includes and excludes the values of the matrix multiplication by setting values of the input to -∞ for zero-valued mask elements. The mask is the union of the padding and attention masks. The softmax function normalizes the value of the input data across the channel dimension such that it sums to one. The dropout operation sets elements to zero with probability p.
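To make these steps concrete, the following is a from-scratch sketch of the single-head operation on numeric arrays, without masking or dropout. The function name and variables are illustrative and are not part of the attention function interface:

function [Y,W] = simpleDotProductAttention(Q,K,V)
% Q: numChannels-by-Nq queries, K: numChannels-by-Nk keys,
% V: valueSize-by-Nk values (numeric matrices).

% Scaled dot-product scores, one column per query.
lambda = 1/sqrt(size(K,1));
scores = lambda*(K'*Q);              % Nk-by-Nq

% Softmax over the keys for each query (subtract the max for stability).
expScores = exp(scores - max(scores,[],1));
W = expScores ./ sum(expScores,1);   % attention weights, columns sum to 1

% Weighted combination of the values.
Y = V*W;                             % valueSize-by-Nq
end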
Multihead Self-Attention
The multihead self-attention operation for the input X is given by

multiheadSelfAttention(X) = WO · concat(head_1, …, head_h)

where the head outputs are concatenated along the channel dimension and:
- h is the number of heads.
- WQ is a learnable projection matrix for the queries.
- WK is a learnable projection matrix for the keys.
- WV is a learnable projection matrix for the values.
- WO is a learnable projection matrix for the output.
Each weight matrix is composed of concatenated weight matrices Wi for each head. Each head_i denotes the output of the head operation given by

head_i = attention(WQ_i X, WK_i X, WV_i X)

where WQ_i, WK_i, and WV_i denote the weight matrices for the queries, keys, and values of head i.
References
[1] Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." Advances in neural information processing systems 30 (December 2017): 6000-6010. https://papers.nips.cc/paper/7181-attention-is-all-you-need.
[2] Luong, Minh-Thang, Hieu Pham, and Christopher D. Manning. "Effective approaches to attention-based neural machine translation." arXiv preprint arXiv:1508.04025 (2015).
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
When at least one of these input arguments is a gpuArray object or a dlarray object with underlying data of type gpuArray, this function runs on the GPU:
- queries
- keys
- values
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
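A minimal sketch of the GPU workflow, reusing the sizes from the Apply Attention Operation example on this page (assumes a supported GPU and Parallel Computing Toolbox):

% Create the inputs on the GPU. The function then runs on the GPU.
queries = dlarray(gpuArray(rand(querySize,numObservations,numQueries)),"CBT");
keys = dlarray(gpuArray(rand(querySize,numObservations,numValues)));
values = dlarray(gpuArray(rand(valueSize,numObservations,numValues)));

[Y,weights] = attention(queries,keys,values,numHeads);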
Version History
Introduced in R2022b
See Also
padsequences
| dlarray
| dlgradient
| dlfeval
| lstm
| gru
| embed
Topics
- Define Custom Training Loops, Loss Functions, and Networks
- Train Network Using Model Function
- Sequence-to-Sequence Translation Using Attention
- Image Captioning Using Attention
- Multilabel Graph Classification Using Graph Attention Networks
- Language Translation Using Deep Learning
- List of Functions with dlarray Support