Predict and update using LSTM

Tanmoy Chatterjee 2021-8-11
Hi,
I am trying to solve a time series forecasting problem using an LSTM in MATLAB. My previous questions, posed in https://uk.mathworks.com/matlabcentral/answers/890637-data-preparation-for-time-forecasting-using-lstm?s_tid=srchtitle, were answered really well. Thanks for that.
(Q1) In this context, I wanted to ask if there are any differences between 'predict' and 'predictAndUpdateState' in the prediction step using LSTM, other than the point that 'predict' returns a sequence of predictions while 'predictAndUpdateState' makes predictions one step at a time? I am asking this because 'predict' also updates the network state between each prediction.
(Q2) I am training the LSTM model on the first 900 seconds (training set) and forecasting the response for the next 100 seconds (test set). If I use YPred = predict(net,XTest), this updates 'net' according to XTest, which is the test input. In the forecasting problem as formulated, I do not have the test set, so I shouldn't be using it. Instead, I should update the model and base my predictions only on XTrain and YTrain, as in the code below. However, the predictions from this approach are not at all good. Is it okay to use XTest for prediction, and if not, how can I improve the predictions of the following code?
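for n = 1:numObs
    % Prime the network state with the n-th training sequence
    [net, Y] = predictAndUpdateState(net, XTrain{n});
    Y = Y(:, end);              % last prediction, used to seed the forecast
    Yseq = [];
    for t = 1:numSteps
        % Closed loop: feed the previous prediction back in as the next input
        [net, Y] = predictAndUpdateState(net, Y);
        Yseq = cat(2, Yseq, Y);
    end
    YTest{n} = Yseq;            % forecast for observation n
    net = resetState(net);      % clear the state before the next observation
end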
I am using the following network configuration:
numHiddenUnits = 100;
options = trainingOptions('adam', ...
    'MaxEpochs',200, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',125, ...
    'LearnRateDropFactor',0.2, ...
    'MiniBatchSize',100, ...
    'Verbose',1, ...
    'Plots','training-progress');

Answers (1)

Aneela 2024-2-22
Hi Tanmoy Chatterjee,
Addressing your first query –
predict – Computes the network's output by processing the entire input sequence in one call and returns the corresponding sequence of predictions.
  • The LSTM state is updated from one time step to the next within the input sequence, but the updated state is not returned, so it is not kept after the function call.
predictAndUpdateState – Makes a prediction and also returns the network with its updated state.
  • This lets you carry the network state across multiple calls to the function, which is crucial when generating predictions in a loop where each prediction becomes the input for the next time step.
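The difference can be checked with a small experiment. The sketch below is not from the original thread: it assumes Deep Learning Toolbox and trains a tiny network on a toy sine wave (the data, layer sizes, and variable names are illustrative assumptions). It then compares one predict call over the whole sequence against a step-by-step loop with predictAndUpdateState that starts from the same reset state; in this open-loop setting the two outputs should agree up to floating-point tolerance.

```matlab
% Toy next-step prediction data (assumed, for illustration only)
t = 0:0.1:20;
x = sin(t);
XTrain = x(1:end-1);   % 1-by-T input sequence
YTrain = x(2:end);     % 1-by-T target sequence, shifted by one step

layers = [ ...
    sequenceInputLayer(1)
    lstmLayer(16)
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 20, 'Verbose', 0);
net = trainNetwork(XTrain, YTrain, layers, options);

% (a) Whole sequence in one call. The state is propagated internally,
%     but the variable net is left unchanged afterwards.
net = resetState(net);
YAll = predict(net, XTrain);

% (b) One time step per call. The updated network is returned and reused,
%     so the state carries over from one call to the next.
net = resetState(net);
YStep = zeros(size(YAll));
for k = 1:numel(XTrain)
    [net, YStep(:, k)] = predictAndUpdateState(net, XTrain(:, k));
end

% Largest absolute difference between the two approaches
maxDiff = max(abs(YAll(:) - YStep(:)));
disp(maxDiff)
```

Approach (b) is the one to use when inputs arrive one at a time, or when each prediction has to be fed back in as the next input.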
Addressing your second query –
If you do not have a test set and would like to improve the network's predictions, consider the following:
  • Adjust the hyperparameters of the LSTM network, such as the number of layers, the number of hidden units in each layer, the learning rate, and the mini-batch size.
  • To reduce overfitting, consider adding dropout or L2 regularization to the network.
  • Try combining the LSTM with another forecasting model such as ARIMA. For example, use ARIMA to make the initial forecasts and then use the LSTM to model the errors of the ARIMA predictions (learning from past errors).
If you do have test data (XTest), it is appropriate to use it with the predict function to evaluate the model's performance after training:
  • When using predict, the network weights are not updated based on the test set; only the internal state of the network is updated, so the predictions are still made from the learned parameters.
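As a rough sketch of the dropout and L2 regularization suggestions, assuming a single-channel sequence-to-sequence regression network: numHiddenUnits and most training options are taken from the question, while numFeatures, numResponses, the dropout probability, and the L2 factor are illustrative assumptions that would need tuning.

```matlab
numFeatures    = 1;     % assumed: one input channel
numResponses   = 1;     % assumed: one output channel
numHiddenUnits = 100;   % as in the question

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    dropoutLayer(0.2)                    % assumed dropout probability
    fullyConnectedLayer(numResponses)
    regressionLayer];

options = trainingOptions('adam', ...
    'MaxEpochs',200, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'L2Regularization',1e-3, ...         % assumed value; the default is 1e-4
    'MiniBatchSize',100, ...
    'Verbose',1);
```

Dropout is added here as a separate layer after the LSTM layer, and the L2 penalty is set through the training options rather than on the layer itself.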
Refer to the following links for further guidance on the predict and predictAndUpdateState functions.
predict –
1 Comment
Imola Fodor 2024-2-26
Hi, I believe the output in open-loop forecasting is expected to be the same with predict and with predictAndUpdateState (in a loop), right? By open loop I mean that I don't have a previously predicted state as the successive input. I need predictAndUpdateState so that I don't have to continuously store the previous inputs explicitly, which is why I would ideally avoid using the predict function.

Release

R2020a
