How to create a transformer network for sequence to sequence classification task?

I am currently trying to use MATLAB to classify time series with a transformer network. My code is below, but running it produces an error that I cannot resolve.
lgraph = [ ...
sequenceInputLayer(InputSize,Name="input")
positionEmbeddingLayer(InputSize,maxPosition,Name="pos-emb");
additionLayer(2, Name="embed_add");
selfAttentionLayer(numHeads,numKeyChannels) % self attention
additionLayer(2,Name="attention_add") % residual connection around attention
layerNormalizationLayer(Name="attention_norm") % layer norm
fullyConnectedLayer(feedforwardHiddenSize) % feedforward part 1
reluLayer % nonlinear activation
fullyConnectedLayer(attentionHiddenSize) % feedforward part 2
additionLayer(2,Name="feedforward_add") % residual connection around feedforward
layerNormalizationLayer() % layer norm
% selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal');
% selfAttentionLayer(numHeads,numKeyChannels);
indexing1dLayer("last")
fullyConnectedLayer(NumClass)
softmaxLayer
classificationLayer];
% Layers = layerGraph(lgraph);
% Layers = connectLayers(Layers,"input","add/in2");
net = dlnetwork(lgraph,Initialize=false);
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"pos-emb","embed_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
% net = connectLayers(net,"encoder1_out","attention2_add/in2");
% net = connectLayers(net,"attention2_norm","feedforward2_add/in2");
net = initialize(net);

Answers (1)

Prasanna 2024-9-9
Hi veritas,
The error you're encountering is caused by classificationLayer, which is not supported in a dlnetwork object. A dlnetwork does not use output layers such as classificationLayer: it is designed for training where the loss is computed outside the network (for example, in a custom training loop), so it does not need an explicit output layer. Instead, you should handle the loss calculation separately during training.
Here's how you can modify your setup to avoid using classificationLayer:
  • Remove the classificationLayer from your layer array, so the network ends with softmaxLayer.
  • With dlnetwork, you typically use a custom training loop where you manually compute the loss and update the model parameters.
  • Use a loss function such as cross-entropy directly in your training loop.
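A minimal sketch of these three changes, assuming R2023b or later (where positionEmbeddingLayer and indexing1dLayer are available). The sizes, the random mini-batch, and the helper name modelLoss are hypothetical placeholders, not values from the question. Besides dropping classificationLayer, two other details of the original code look like they need changing: in a layer array, pos-emb is already wired to embed_add/in1, so the second input of embed_add should presumably come from input (as in the commented-out connectLayers line), otherwise the input is never added to its position embedding; and the second feedforward layer must output InputSize channels for the residual addition to have matching sizes, so the sketch uses InputSize there.

```matlab
% Hypothetical sizes: replace with values that match your data.
InputSize = 12;               % channels (features) per time step
maxPosition = 256;            % longest sequence the position embedding supports
numHeads = 4;
numKeyChannels = 64;
feedforwardHiddenSize = 128;
NumClass = 5;

layers = [ ...
    sequenceInputLayer(InputSize,Name="input")
    positionEmbeddingLayer(InputSize,maxPosition,Name="pos-emb")
    additionLayer(2,Name="embed_add")
    selfAttentionLayer(numHeads,numKeyChannels)
    additionLayer(2,Name="attention_add")           % residual around attention
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize)      % feedforward part 1
    reluLayer
    fullyConnectedLayer(InputSize)                  % feedforward part 2, sized for the residual add
    additionLayer(2,Name="feedforward_add")         % residual around feedforward
    layerNormalizationLayer
    indexing1dLayer("last")                         % keep the last time step
    fullyConnectedLayer(NumClass)
    softmaxLayer];                                  % no classificationLayer

net = dlnetwork(layers,Initialize=false);
net = connectLayers(net,"input","embed_add/in2");   % input + its position embedding
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
net = initialize(net);

% One custom-training-loop iteration on a random placeholder mini-batch.
miniBatchSize = 8;
seqLength = 50;
X = dlarray(randn(InputSize,miniBatchSize,seqLength,"single"),"CBT");
labels = categorical(randi(NumClass,1,miniBatchSize),1:NumClass);
T = dlarray(onehotencode(labels,1),"CB");           % NumClass-by-miniBatchSize one-hot targets

averageGrad = [];
averageSqGrad = [];
iteration = 1;
[loss,gradients] = dlfeval(@modelLoss,net,X,T);
[net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
    averageGrad,averageSqGrad,iteration);
disp(extractdata(loss))

% Hypothetical helper: the cross-entropy loss lives here, not in an output layer.
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);                 % softmax scores, format "CB"
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
```

In a full custom loop, the dlfeval/adamupdate pair would run once per mini-batch inside loops over epochs and iterations, with iteration incremented each time.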
To do this without writing the training loop yourself, you can use the trainnet function to train dlnetwork objects and set the loss function to "crossentropy". For more information, refer to the documentation for trainnet and crossentropy.
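A sketch of the trainnet route, assuming R2023b or later: the loss is passed as the string "crossentropy" and the network simply ends at softmaxLayer. To keep it self-contained it uses random placeholder data and a cut-down stand-in network with hypothetical sizes and no residual branches; your own network, built without classificationLayer, would take the place of layers. Left padding is chosen because indexing1dLayer("last") reads the final time step, which with the default right padding would be padding for the shorter sequences in a mini-batch.

```matlab
% Hypothetical sizes and random placeholder data: replace with your own.
InputSize = 12;
NumClass = 5;
numObservations = 100;

% For trainnet, each sequence is a numTimeSteps-by-numChannels matrix
% (rows = time steps, columns = channels).
XTrain = cell(numObservations,1);
for n = 1:numObservations
    XTrain{n} = randn(randi([20 60]),InputSize,"single");
end
TTrain = categorical(randi(NumClass,numObservations,1),1:NumClass);

% Cut-down stand-in network: note there is no classificationLayer at the end.
layers = [
    sequenceInputLayer(InputSize)
    selfAttentionLayer(4,64)
    layerNormalizationLayer
    indexing1dLayer("last")
    fullyConnectedLayer(NumClass)
    softmaxLayer];

% Left padding keeps real data in the last time step, which indexing1dLayer("last") selects.
options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    MiniBatchSize=16, ...
    SequencePaddingDirection="left", ...
    Plots="none", ...
    Verbose=false);

% The loss is specified here instead of by an output layer.
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
```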
Hope this helps!
