What does function predict() in Deep Learning Toolbox do?

3 views (last 30 days)
Hi, I followed this example and made a small modification: instead of using the predict() function, I call predictAndUpdateState() to predict the targets one by one.
With this change I get a much worse prediction result (brown line) than with predict() (yellow line).
Can anyone explain this?
The only part that differs is:
% opt1: use only the feature variables as input
net = resetState(net);
YPred = [];
for i = 1:numel(XTest)
    [net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
    YPred(:,i) = cell2mat(temp);
end
y1 = YPred;
Full code:
[~,~,data] = xlsread('ET_1.xlsx');
data_mat = cell2mat(data);
XTrain = (data_mat(:,4:8))';    % columns 4-8 are the features
XTrain = num2cell(XTrain,1);    % one cell per time step
YTrain = (data_mat(:,3))';      % column 3 is the response
YTrain = num2cell(YTrain,1);
%% Define Network Architecture
featureDimension = size(XTrain{1},1);
numResponses = size(YTrain{1},1);
numHiddenUnits = 500;
layers = [ ...
    sequenceInputLayer(featureDimension)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(500) %%50
    dropoutLayer(0.1) %%0.5
    fullyConnectedLayer(numResponses)
    regressionLayer
    ];
maxepochs = 500;
miniBatchSize = 1;
options = trainingOptions('adam', ... %%adam
    'MaxEpochs',maxepochs, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',125, ...
    'LearnRateDropFactor',0.2, ...
    'Verbose',0, ...
    'Plots','training-progress');
%% Train the Network
net = trainNetwork(XTrain,YTrain,layers,options);
%% Test the Network
[~,~,data] = xlsread('ET_2.xlsx');
data_mat = cell2mat(data);
XTest = (data_mat(:,4:8))'; XTest = num2cell(XTest,1);
YTest = (data_mat(:,3))'; YTest = num2cell(YTest,1);
% opt1: use only the feature variables as input
net = resetState(net);
YPred = [];
for i = 1:numel(XTest)
    [net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
    YPred(:,i) = cell2mat(temp);
end
y1 = YPred;
% opt2: predict()
net = resetState(net);
YPred = predict(net, XTest);
y2 = (cell2mat(YPred)); %have to transpose as plot plots columns
%%
figure; hold all
yRef = (cell2mat(YTest)');
plot(yRef, '-o')
plot(y1, '-x')
plot(y2, '-s')
1 Comment
Song Decn 2021-5-10
% Opt1:
% yTrain = predict(net, xTrainStandardized);
% yTrain = cell2mat(yTrain);
% Opt2:
% yTrain = [];
% for i = 1:numel(xTrainStandardized)
%     [net, tmp] = predictAndUpdateState(net, xTrainStandardized(i));
%     yTrain(i) = cell2mat(tmp);
% end
% Opt3:
[net, tmp] = predictAndUpdateState(net, xTrainStandardized);
yTrain = cell2mat(tmp);
These three ways of calculating the responses give different values. Why?


Answers (1)

Vidip 2024-2-21
Edited: 2024-2-21
The reason you get worse results with ‘predictAndUpdateState’ in a loop than with ‘predict’ is how the LSTM network's state is handled between predictions. The ‘predict’ function does not update the network state: every sequence in the cell array is processed starting from the network's current (here: initial) state, so each sequence is treated as independent. That is appropriate when your test sequences are not temporally related. ‘predictAndUpdateState’, by contrast, returns the updated network, so when you call it in a loop without resetting the state, the internal state carries over from one prediction to the next.
This means the network's prediction for each data point is influenced by all the previous data points. If the entries of ‘XTest’ are supposed to be independent sequences, that accumulated state is historical context from unrelated sequences, and using it can make the predictions inaccurate.
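If you want the loop to reproduce the behavior of ‘predict’, a minimal sketch (assuming the ‘net’ and ‘XTest’ variables from the question) is to reset the state before every call:

```matlab
% Sketch: resetting the state before each call makes every entry of
% XTest start from the initial state, i.e. each entry is treated as an
% independent sequence, which is what predict() does.
YPred = [];
for i = 1:numel(XTest)
    net = resetState(net);  % discard state carried over from earlier steps
    [net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
    YPred(:,i) = cell2mat(temp);
end
```

Conversely, if the test data is one continuous time series, carrying the state forward (as in the original opt1 loop) is the intended use of ‘predictAndUpdateState’.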
For further information, refer to the documentation link below:
