Deep Learning Toolbox: Input format mismatch
I want to build an LSTM network with an attention mechanism.
According to a paper, I need to add the output of the LSTM layer to the processed hidden state.
But when I use an additionLayer, I get an error message saying that the input formats do not match.
I found that the LSTM output is in 10(C) x 1(B) x 1(T) format, while the hidden state is in 10(C) x 1(B) format.
I would expect them to be addable, because they have the same size in the first two dimensions and the output has a length of 1 in the third dimension.
What do I need to do to make them addable? Thanks!
Below is the network I designed according to the paper. It is a CNN-LSTM-Attention network used to predict time series.
% kim: input size, defined earlier in the script (not shown in the post)
sequence = sequenceInputLayer([kim,1],"Name",'sequence');
conv1 = convolution1dLayer(3,64,'Name','conv1','Padding','causal');
batchnorm1 = batchNormalizationLayer('Name','batchnorm1');
relu1 = reluLayer('Name','relu1');
conv2 = convolution1dLayer(3,64,'Name','conv2','Padding','causal');
batchnorm2 = batchNormalizationLayer('Name','batchnorm2');
relu2 = reluLayer('Name','relu2');
conv3 = convolution1dLayer(3,64,'Name','conv3','Padding','causal');
batchnorm3 = batchNormalizationLayer('Name','batchnorm3');
add1 = additionLayer(2,"Name",'add1');
relu3 = reluLayer('Name','relu3');
maxpool = maxPooling1dLayer(10,'Name','maxpool');
flatten1 = flattenLayer("Name",'flatten1');
% HasStateOutputs adds the 'hidden' and 'cell' outputs next to 'out'
lstm = lstmLayer(10,'Name','lstm','HasStateOutputs',1);
fc1 = fullyConnectedLayer(10,'Name','fc1');
tanh1 = tanhLayer('Name','tanh1');
softmax1 = softmaxLayer('Name','softmax1');
multiplication1 = multiplicationLayer(2,'Name','multiplication1');
fc2 = fullyConnectedLayer(10,'Name','fc2');
tanh2 = tanhLayer('Name','tanh2');
softmax2 = softmaxLayer('Name','softmax2');
multiplication2 = multiplicationLayer(2,'Name','multiplication2');
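% add2 sums lstm/out, multiplication1 and multiplication2; an additionLayer requires all inputs to have the same size
add2 = additionLayer(3,"Name",'add2');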
fc3 = fullyConnectedLayer(10,'Name','fc3');
softmax3 = softmaxLayer('Name','softmax3');
regression = regressionLayer('Name','regression');
lgraph = layerGraph;
lgraph = addLayers(lgraph,sequence);
lgraph = addLayers(lgraph,conv1);
lgraph = addLayers(lgraph,batchnorm1);
lgraph = addLayers(lgraph,relu1);
lgraph = addLayers(lgraph,conv2);
lgraph = addLayers(lgraph,batchnorm2);
lgraph = addLayers(lgraph,relu2);
lgraph = addLayers(lgraph,conv3);
lgraph = addLayers(lgraph,batchnorm3);
lgraph = addLayers(lgraph,add1);
lgraph = addLayers(lgraph,relu3);
lgraph = addLayers(lgraph,maxpool);
lgraph = addLayers(lgraph,flatten1);
lgraph = addLayers(lgraph,lstm);
lgraph = addLayers(lgraph,fc1);
lgraph = addLayers(lgraph,tanh1);
lgraph = addLayers(lgraph,softmax1);
lgraph = addLayers(lgraph,multiplication1);
lgraph = addLayers(lgraph,fc2);
lgraph = addLayers(lgraph,tanh2);
lgraph = addLayers(lgraph,softmax2);
lgraph = addLayers(lgraph,multiplication2);
lgraph = addLayers(lgraph,add2);
lgraph = addLayers(lgraph,fc3);
lgraph = addLayers(lgraph,softmax3);
lgraph = addLayers(lgraph,regression);
lgraph = connectLayers(lgraph,'sequence','conv1');
lgraph = connectLayers(lgraph,'conv1','batchnorm1');
lgraph = connectLayers(lgraph,'batchnorm1','relu1');
lgraph = connectLayers(lgraph,'batchnorm1','add1/in1');
lgraph = connectLayers(lgraph,'relu1','conv2');
lgraph = connectLayers(lgraph,'conv2','batchnorm2');
lgraph = connectLayers(lgraph,'batchnorm2','relu2');
lgraph = connectLayers(lgraph,'relu2','conv3');
lgraph = connectLayers(lgraph,'conv3','batchnorm3');
lgraph = connectLayers(lgraph,'batchnorm3','add1/in2');
lgraph = connectLayers(lgraph,'add1/out','relu3');
lgraph = connectLayers(lgraph,'relu3','maxpool');
lgraph = connectLayers(lgraph,'maxpool','flatten1');
lgraph = connectLayers(lgraph,'flatten1','lstm');
lgraph = connectLayers(lgraph,'lstm/out','add2/in1');
lgraph = connectLayers(lgraph,'lstm/hidden','fc1');
lgraph = connectLayers(lgraph,'fc1','tanh1');
lgraph = connectLayers(lgraph,'tanh1','softmax1');
lgraph = connectLayers(lgraph,'softmax1','multiplication1/in1');
lgraph = connectLayers(lgraph,'lstm/hidden','multiplication1/in2');
lgraph = connectLayers(lgraph,'multiplication1','add2/in2');
lgraph = connectLayers(lgraph,'lstm/cell','fc2');
lgraph = connectLayers(lgraph,'fc2','tanh2');
lgraph = connectLayers(lgraph,'tanh2','softmax2');
lgraph = connectLayers(lgraph,'softmax2','multiplication2/in1');
lgraph = connectLayers(lgraph,'lstm/cell','multiplication2/in2');
lgraph = connectLayers(lgraph,'multiplication2','add2/in3');
lgraph = connectLayers(lgraph,'add2','fc3');
lgraph = connectLayers(lgraph,'fc3','softmax3');
lgraph = connectLayers(lgraph,'softmax3','regression');
% Visualize the graph and check the input/output sizes of every layer
plot(lgraph)
analyzeNetwork(lgraph)
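The forward computation of the attention branches in the graph above can be traced on plain arrays to see which inputs reach add2. This is a minimal sketch for one observation, assuming random placeholder weights for fc1 and fc2 (W1, b1, W2, b2 are hypothetical) and a hypothetical sequence length of 100; it mirrors the layer wiring but is not the trained network.

```matlab
% Plain-array sketch of the attention branches for ONE observation.
% Assumptions (not from the post): random weights stand in for fc1/fc2,
% and the sequence length T = 100 is hypothetical.
numHidden = 10;                    % lstmLayer(10, ...)
T = 100;                           % hypothetical sequence length

Y = randn(numHidden, 1, T);        % lstm/out in "sequence" mode: C x B x T
h = Y(:, :, end);                  % lstm/hidden: the final hidden state equals the last output step
c = randn(numHidden, 1);           % lstm/cell: final cell state, C x B

% Softmax over the channel dimension, as softmaxLayer does for C x B data
softmaxC = @(z) exp(z - max(z, [], 1)) ./ sum(exp(z - max(z, [], 1)), 1);

W1 = randn(numHidden); b1 = randn(numHidden, 1);   % hypothetical fc1 parameters
W2 = randn(numHidden); b2 = randn(numHidden, 1);   % hypothetical fc2 parameters

m1 = softmaxC(tanh(W1*h + b1)) .* h;   % fc1 -> tanh1 -> softmax1 -> multiplication1
m2 = softmaxC(tanh(W2*c + b2)) .* c;   % fc2 -> tanh2 -> softmax2 -> multiplication2

% add2 receives these three inputs; their sizes are not all identical
fprintf('lstm/out: %s, multiplication1: %s, multiplication2: %s\n', ...
    mat2str(size(Y)), mat2str(size(m1)), mat2str(size(m2)));
```

Both multiplication layers see two inputs of size [10 1], so they are fine; only add2 mixes sizes.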
Accepted Answer
David Ho
2023-10-9
Hello 湃林,
It looks like the error comes from the three-input addition layer (add2): one of its inputs, the "out" port of the LSTM layer, has a time dimension, while the other two inputs (from the multiplication layers) do not. That means that if, for example, your sequences have length 100, the layer is trying to add a [10 1] matrix, a [10 1] matrix, and a [10 1 100] array, which is not a supported operation (addition and multiplication layers do not permit implicit expansion).
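The shape rule can be checked without Deep Learning Toolbox. This sketch uses random stand-in arrays with the sizes quoted above (a, b and y are hypothetical names); in base MATLAB, element-wise operators expand singleton dimensions, which is exactly what layers requiring identical input sizes refuse to do.

```matlab
% Stand-in arrays with the sizes from the answer
a = rand(10, 1);          % e.g. output of multiplication1: [10 1]
b = rand(10, 1);          % e.g. output of multiplication2: [10 1]
y = rand(10, 1, 100);     % e.g. lstm/out in "sequence" mode: [10 1 100]

% Base MATLAB (R2016b and later) expands singleton dimensions implicitly,
% so a + b is added to every one of the 100 time steps of y.
s = a + b + y;            % size(s) is [10 1 100]

% A layer that insists on identical input sizes rejects this combination.
sameSize = isequal(size(a), size(y));   % false
```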
One method of resolving the issue is to change the output mode of the LSTM layer from "sequence" to "last", so that it only outputs a single time step:
lstm = lstmLayer(10,'Name','lstm','HasStateOutputs',1,'OutputMode','last');
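In terms of array shapes, "last" mode keeps only the final time step of the LSTM output. A rough plain-array sketch, assuming Y stands in for the sequence-mode output and m1, m2 for the two attention branches (all hypothetical stand-ins):

```matlab
% Y stands in for the LSTM output in "sequence" mode (C x B x T)
Y = rand(10, 1, 100);
yLast = Y(:, :, end);            % keep only the final time step: [10 1]

m1 = rand(10, 1);                % stand-ins for the multiplication1/2 outputs
m2 = rand(10, 1);
out = yLast + m1 + m2;           % all three inputs now have the same size
```

One consequence worth noting: in "last" mode the out port carries the final hidden state, which is the same value the hidden port already provides.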
However, I'm not sure if this is the result you need. If the time steps are important and you want to perform the addition with implicit expansion, you can replace the addition layer with a functionLayer, keeping the name 'add2' so that the existing connectLayers calls still work:
add2 = functionLayer(@(x,y,z) x + y + z, 'Name', 'add2', 'Formattable', true);
The multiplication layers can stay as they are, because both of their inputs have the same size.
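Because functionLayer wraps an ordinary function handle, the handle's behavior on mismatched shapes can be tried on plain arrays before building the network. This sketch uses the hypothetical handles addFcn and mulFcn. Per the functionLayer documentation, when NumInputs is not specified it is taken from nargin of the function (worth confirming for your release), which is why the handle's argument count should match the number of connected inputs.

```matlab
% Function handles of the kind passed to functionLayer, tried on plain arrays
addFcn = @(x, y, z) x + y + z;     % hypothetical three-input sum
mulFcn = @(x, y, z) x .* y .* z;   % hypothetical three-input product

x = rand(10, 1);                   % [10 1], no time dimension
y = rand(10, 1);                   % [10 1], no time dimension
z = rand(10, 1, 100);              % [10 1 100], with a time dimension

sAdd = addFcn(x, y, z);            % implicit expansion gives [10 1 100]
sMul = mulFcn(x, y, z);            % implicit expansion gives [10 1 100]

numIn = nargin(addFcn);            % 3 input arguments
```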
Hopefully one of these options will give you the result you're looking for.
As a final note, if you want to implement the self-attention mechanism from the "Attention Is All You Need" paper, you may wish to look at Deep Learning Toolbox's implementation of the self-attention layer, selfAttentionLayer.
It handles the multiplication of the queries, keys, and values internally, so you don't need to implement it yourself.
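For reference, the computation that a self-attention layer performs internally can be sketched in a few lines of base MATLAB. This is a single-head, scaled dot-product version under stated assumptions: random projection matrices Wq, Wk, Wv, hypothetical sizes C = 10, T = 100, dk = 8, and data laid out as channels by time. It is not the toolbox's implementation and omits multi-head splitting, masking and the output projection.

```matlab
% Single-head scaled dot-product self-attention on one C x T sequence
C = 10;  T = 100;  dk = 8;               % hypothetical sizes
X = randn(C, T);                         % one sequence: channels x time

Wq = randn(dk, C);  Wk = randn(dk, C);  Wv = randn(dk, C);   % hypothetical projections
Q = Wq * X;                              % queries, dk x T
K = Wk * X;                              % keys,    dk x T
V = Wv * X;                              % values,  dk x T

scores = (K' * Q) / sqrt(dk);            % scores(i,j) = k_i . q_j / sqrt(dk), T x T
scores = scores - max(scores, [], 1);    % subtract column max to stabilize the softmax
A = exp(scores) ./ sum(exp(scores), 1);  % softmax over keys: each column sums to 1

Z = V * A;                               % column j = sum_i A(i,j) * v_i, dk x T
```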
Best regards,
David