Using transformer neural network for classification task

19 次查看(过去 30 天)
numChannels = inputSize;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
layers = [
sequenceInputLayer(numChannels,Name="input")
positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb");
additionLayer(2, Name="add")
selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal')
selfAttentionLayer(numHeads,numKeyChannels)
indexing1dLayer("last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, "input", "add/in2");
maxEpochs = 100;
miniBatchSize = 32;
learningRate = 0.001;
solver = 'adam';
shuffle = 'every-epoch';
gradientThreshold = 10;
executionEnvironment = "auto"; % chooses local GPU if available, otherwise CPU
options = trainingOptions(solver, ...
'Plots','training-progress', ...
'MaxEpochs', maxEpochs, ...
'MiniBatchSize', miniBatchSize, ...
'Shuffle', shuffle, ...
'InitialLearnRate', learningRate, ...
'GradientThreshold', gradientThreshold, ...
'ExecutionEnvironment', executionEnvironment);
The input size is 12, so there are 12 features.
numClasses is 4, so I am classifying it into 4 class.
But it gives the following error when I try to run it
"
Error in test123_20240727 (line 195)
net=trainNetwork(XTrain, YTrain, layers, options);
Caused by:
Layer 'add': Unconnected input. Each layer input must be connected to the output of another layer.
"
line 195 is "net=trainNetwork(XTrain, YTrain, layers, options);"
Can anyone help me with this?
  7 个评论
Umar
Umar 2024-7-29
Hi @ haohaoxuexi1,
If you are still having issues with modifying your code, please let us know. We will be happy to help you out.
haohaoxuexi1
haohaoxuexi1 2024-7-29
@Umar Hi Umar, I am good at the moment. Will let u know if I have further question.

请先登录,再进行评论。

采纳的回答

Joss Knight
Joss Knight 2024-7-29

You've passed layers instead of lgraph to trainNetwork.

  2 个评论
Umar
Umar 2024-7-29
@Joss Knight, Thanks for jumping in. Please advice how to use lgraph to trainNetwork by providing code snippet. Again, thanks for your cooperation.
Joss Knight
Joss Knight 2024-8-13
net=trainNetwork(XTrain, YTrain, lgraph, options);
instead of
net=trainNetwork(XTrain, YTrain, layers, options);

请先登录,再进行评论。

更多回答(0 个)

类别

Help CenterFile Exchange 中查找有关 Image Data Workflows 的更多信息

产品


版本

R2024a

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by