Deep learning: predict and forward give very inconsistent results with batch normalisation

11 views (last 30 days)
Hi there,
I have a dlnetwork with batchNormalizationLayer layers. I call [Uf, statef] = forward(dnn_net, XT); the statef returned by forward should contain the updated mean and variance of the BN layers. I then set dnn_net.State = statef so that the network state holds the mean and variance just computed by forward. Next I call Up = predict(dnn_net, XT) on the same data XT used in the forward call. When I compare mse(Up, Uf), the error turns out to be quite large (e.g. 2.xx). This is totally not expected. Please help.
I have created a simple test script to show the problem below:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output using forward; statef contains the batch
% normalization layers' updated mean and variance
[Uf,statef] = forward(dnn_net,XT);
dnn_net.State = statef; % update dnn_net.State with the mean/variance from the forward call
Up = predict(dnn_net,XT);
% The outputs of predict and forward should be the same in this case
% (because dnn_net.State was updated with the same mean/variance computed
% by the forward call). However, they are not: the error is quite large.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'b-');
err = mse(Uf(1,:),Up(1,:))
err =
  1x1 dlarray
    0.3134

Answers (1)

Matt J 2024-8-25
Edited: Matt J 2024-8-25
To get agreement, you need to allow the trained mean and variance to converge:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output using forward repeatedly; each call returns the
% batch normalization layers' updated running mean and variance in statef
for i = 1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % update dnn_net.State with the running mean/variance from the forward call
end
Up = predict(dnn_net,XT);
% Once the running statistics have converged to the batch statistics, the
% outputs of predict and forward agree to numerical precision.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'--b');
err = mse(Uf(1,:),Up(1,:))
err =
  1x1 dlarray
    2.6776e-12
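Why repeating forward helps (editor's note, hedged): forward normalizes with the statistics of the current mini-batch and only moves the running estimates stored in State part of the way toward those batch statistics on each call, while predict normalizes with the running estimates. After a single forward call the running estimates are still far from the batch statistics, so the two outputs differ; iterating the update lets the running estimates converge. The snippet below is a standalone illustration of that moving-average behavior, not the toolbox's internal code; the update convention and the decay value 0.1 are assumptions (compare the MeanDecay/VarianceDecay properties of batchNormalizationLayer).
% Standalone illustration (assumed convention): a running mean updated by an
% exponential moving average converges geometrically to the batch mean when
% the same batch is presented repeatedly.
decay       = 0.1;   % assumed decay value (cf. MeanDecay/VarianceDecay)
batchMean   = 1.7;   % statistic of the fixed mini-batch
runningMean = 0;     % initial running estimate held in the network State
gap = zeros(1,500);
for k = 1:500
    runningMean = (1 - decay)*runningMean + decay*batchMean; % EMA update
    gap(k) = abs(runningMean - batchMean);
end
semilogy(gap); xlabel('forward call'); ylabel('|running mean - batch mean|')
% The gap shrinks by a factor of (1 - decay) per call, which is why looping
% forward 500 times above makes predict match forward to round-off.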
6 comments
Tom 2024-8-26
Hi Matt,
Thanks a lot for your explanation. Is there a way to set up the BN layers so that they use a supplied mean and variance for normalization when calling forward()? In my application it is important that the BN layers treat the mean and variance as constants when differentiating U with respect to X and T. Thanks in advance.
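One possible direction (an editor's hedged sketch, not an answer given in the thread): write the forward pass as a model function and apply the normalization explicitly with the batchnorm function, passing the supplied mean and variance in as ordinary arrays so that dlgradient treats them as constants. The field names below are hypothetical, and the fixed-statistics batchnorm syntax should be checked against your release.
% Hedged sketch: normalization with fixed, user-supplied statistics so that
% they act as constants under dlgradient. Field names are hypothetical.
function U = modelWithFixedBN(params, stats, XT)
    Z = fullyconnect(XT, params.fc1.Weights, params.fc1.Bias);
    Z = batchnorm(Z, params.bn1.Offset, params.bn1.Scale, ...
                  stats.bn1.Mean, stats.bn1.Variance); % fixed mean/variance
    Z = gelu(Z);
    U = fullyconnect(Z, params.fc2.Weights, params.fc2.Bias);
end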
