Is it possible to share common weights and bias among different LSTM layers?

I am building a network that looks like the figure below.
There are three LSTM layers, namely LSTM_common_1, LSTM_common_2 and LSTM_common_3.
Can I restrict their weights and biases so that all of the LSTM_common_x layers share the same set of weights and biases?
2020-02-05 17_50_15-Clipboard.png

Answers (1)

Conor Daly 2023-2-17
One way to share weights like this is to use nested layers -- layers which have learnable parameters defined by neural networks. The general idea is to create a layer which uses the shared sub-network (which in this case is just a single LSTM layer) as appropriate.
Here's an example for the case above:
classdef commonLSTMLayer < nnet.layer.Layer ...
        & nnet.layer.Formattable ...
        & nnet.layer.Acceleratable
    % Applies one shared LSTM to each row (channel) of the input separately.

    properties (Learnable)
        % Nested dlnetwork that holds the single, shared LSTM layer
        Network
    end

    methods
        function this = commonLSTMLayer(numHiddenUnits, numOutputs, args)
            arguments
                numHiddenUnits (1,1) {mustBePositive, mustBeInteger}
                numOutputs (1,1) {mustBePositive, mustBeInteger}
                args.OutputMode {mustBeMember(args.OutputMode, ["last","sequence"])} = "sequence"
                % Default added so that omitting Name does not error below
                args.Name {mustBeTextScalar} = ""
            end
            this.Name = args.Name;
            layer = lstmLayer(numHiddenUnits, OutputMode=args.OutputMode);
            % Defer initialization of the nested network's learnable parameters
            this.Network = dlnetwork(layer, Initialize=false);
            this.NumOutputs = numOutputs;
            this.OutputNames = "out" + (1:numOutputs);
        end

        function varargout = predict(this, X)
            % Output n is the shared LSTM applied to X(n,:,:), which is
            % the n-th channel when X has the format "CBT"
            varargout = cell(1,this.NumOutputs);
            for n = 1:this.NumOutputs
                varargout{n} = predict(this.Network, X(n,:,:));
            end
        end
    end
end
Using this layer we can construct the network as follows:
numInputChannels = 3;
numHiddenUnits = 64;

% Main branch: input -> shared LSTM (first output) -> fc1 -> cat/in1
layers = [
    sequenceInputLayer(numInputChannels)
    commonLSTMLayer(numHiddenUnits, numInputChannels, OutputMode="last", Name="lstm")
    fullyConnectedLayer(2, Name="fc1")
    concatenationLayer(1, 3, Name="cat")
    regressionLayer()
    ];
lg = layerGraph(layers);

% Branches for the second and third outputs of the shared LSTM layer
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc2"));
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc3"));
lg = connectLayers(lg, "lstm/out2", "fc2");
lg = connectLayers(lg, "lstm/out3", "fc3");
lg = connectLayers(lg, "fc2", "cat/in2");
lg = connectLayers(lg, "fc3", "cat/in3");

analyzeNetwork(lg)
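To see what sharing the weights means at the level of the underlying operation, here is a minimal sketch that applies the dlarray lstm function to each input channel with one common set of parameters. It is an illustration only and not part of the layer above: the sizes, the random parameter values, and the variable names (W, R, b, Y) are assumptions chosen for the example, and it needs Deep Learning Toolbox.

```matlab
% Sketch (assumed sizes and random values): one set of LSTM parameters
% is reused for every input channel, which is what weight sharing means.
rng(0);
numChannels    = 3;
numHiddenUnits = 4;
batchSize      = 2;
numTimeSteps   = 10;

% Make channel 2 an exact copy of channel 1 so the effect is visible.
data = randn(numChannels, batchSize, numTimeSteps);
data(2,:,:) = data(1,:,:);
X = dlarray(data, "CBT");

% Shared parameters. Each branch sees a single channel, so inputSize = 1.
inputSize = 1;
W  = dlarray(0.1*randn(4*numHiddenUnits, inputSize), "CU");      % input weights
R  = dlarray(0.1*randn(4*numHiddenUnits, numHiddenUnits), "CU"); % recurrent weights
b  = dlarray(zeros(4*numHiddenUnits, 1), "C");                   % bias
H0 = zeros(numHiddenUnits, 1);                                   % initial hidden state
C0 = zeros(numHiddenUnits, 1);                                   % initial cell state

% Apply the same W, R and b to each channel in turn.
Y = cell(1, numChannels);
for n = 1:numChannels
    Y{n} = lstm(X(n,:,:), H0, C0, W, R, b);
end

% Identical inputs give identical outputs because the parameters are shared.
sameOutput = isequal(extractdata(Y{1}), extractdata(Y{2}));
disp(sameOutput)
```

With three independent LSTM layers, identical inputs would generally produce different outputs, because each layer would start from and learn its own weights. Here the three branches can only differ through their input data.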

Release: R2019a