How do RL algorithms work with RNNs?

Hi,
I noticed that MATLAB R2021a allows users to use RL algorithms, such as DDPG, with an RNN in the deep neural network structure. This is great, as it could benefit continuous control problems with time delays and time-dependent parameters.
However, I am wondering what algorithm MATLAB uses for the RL RNN learning process. RNNs learn through backpropagation through time (BPTT), so the sampled states for BPTT must form a contiguous series. On the other hand, RL algorithms such as DDPG learn by drawing random samples from the experience buffer, so the two do not integrate as naturally as with a conventional MLP network structure. How does MATLAB handle this? Is there any paper that I can reference?
Next, I am also curious about the RNN BPTT execution in MATLAB. In RL, an episode can have hundreds to thousands of time steps, and an RNN is usually expected to keep a memory of the states at each time step (referring to the unrolled structure) in order to learn the weights and biases for its internal state. Does the series terminate at the end of every episode to update the RNN? Will this consume significantly more memory?
Thank you very much.
1 Comment
Tech Logg Ding 2021-2-23
Bumping this question. After looking into the documentation, I have not found any information on how updates work with an RNN in the DNN. This paper (https://academic.oup.com/jigpal/article/18/5/620/751594?login=true) also describes that random episodes should be sampled as short series in order to train its LSTM network effectively. Does the RL Toolbox do this?


Accepted Answer

Takeshi Takahashi 2021-2-24
Hi,
rlDDPGAgent with an RNN first randomly samples B sequences (trajectories) from the experience buffer, where B is MiniBatchSize. It then randomly selects the starting point of each sampled sequence if the sequence is longer than L, where L is the SequenceLength you specified. The end point of the sequence is determined by the starting point and L, so that the length becomes L.
If some sampled sequences from the experience buffer are shorter than L, they are padded with fake samples so that every sequence in the batch has the same length L. Masking is applied to those padded samples, so the padding does not affect the BPTT.
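In pseudocode, the sampling-and-padding step described above might look roughly like the sketch below (the function name and data layout are hypothetical illustrations, not the toolbox's actual implementation; it assumes the buffer is a cell array of row-vector trajectories):
function [batch, mask] = sampleSequences(buffer, B, L)
    % Sketch only: randomly draw B trajectories, then truncate or pad each to length L.
    batch = cell(1, B);
    mask  = false(L, B);                      % true marks real (non-padded) time steps
    idx   = randi(numel(buffer), 1, B);       % B randomly chosen trajectories
    for k = 1:B
        traj = buffer{idx(k)};                % one stored trajectory (sequence of experiences)
        T = numel(traj);
        if T >= L
            s   = randi(T - L + 1);           % random starting point
            seq = traj(s:s+L-1);              % end point follows from s and L
        else
            seq = [traj, repmat(traj(end), 1, L - T)];   % pad a short trajectory with fake samples
        end
        batch{k} = seq;
        mask(1:min(T, L), k) = true;          % padded entries stay false and are masked out of BPTT
    end
end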
These length-L sequences form the mini-batch for BPTT. MiniBatchSize and SequenceLength control the size of the batch; larger values of either require more memory during BPTT.
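For reference, both options are set on the agent options object; a minimal sketch (the numeric values are only illustrative, not recommendations):
agentOpts = rlDDPGAgentOptions( ...
    'MiniBatchSize', 64, ...    % B: number of sequences sampled per learning step
    'SequenceLength', 20);      % L: truncated-BPTT length of each sampled sequence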
I hope this clarifies your question.
Thank you.
2 Comments
Tech Logg Ding 2021-2-28
Got it! Thank you very much! Do the other algorithms, such as TD3 and SAC, use the same sampling method?


More Answers (0)
