REINFORCE algorithm: unable to compute gradients on the latest toolbox version

I have been trying to implement the REINFORCE algorithm using a custom training loop.
The LSTM actor network takes 50 timesteps of data for three states as input, so each observation has dimension 3x50.
To compute gradients, I pass the input data in the following format:
num_states x batch_size x N_TIMESTEPS = (3x1)x50x50.
In Reinforcement Learning Toolbox version 1.3, the following line works perfectly:
% actor: custom actor network; actorLossFunction: custom loss function; lossData: custom variable
actorGradient = gradient(actor,@actorLossFunction,{reshape(observationBatch,[3 1 50 50])},lossData);
However, when I run the same code with the latest RL Toolbox version, 2.2, I get the following error:
------------------------------------------------------------------------------------------------------------------------------------------------------
Error using rl.representation.rlAbstractRepresentation/gradient
Unable to compute gradient from representation.
Error in simpleRLTraj (line 184)
actorGradient= gradient(actor,@actorLossFunction,{reshape(observationBatch,[3 1 50 50])},lossData);
Caused by:
Error using extractBinaryBroadcastData
dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
------------------------------------------------------------------------------------------------------------------------------------------------------
I tried tracing the error back, but it gets more complicated. Why does code that works perfectly on the earlier version of RL Toolbox raise an error on the new one?

Accepted Answer

Joss Knight, 2022-4-5 (edited 2022-4-5)
What do the following return?
underlyingType(observationBatch)
underlyingType(lossData)
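A quick way to answer that question for every value involved is a small helper that prints each input's underlying type and whether dlarray can wrap it. This is a hypothetical sketch (the name isDlarrayCompatible is not part of any toolbox); it encodes only the rule quoted in the error message, namely full (non-sparse) double, single, or logical data, including gpuArrays of those types, and it is intended for plain arrays and gpuArrays. Note that a struct such as lossData will be reported as incompatible; that alone is not a problem, since what matters is the type of whatever actually reaches dlarray (for example, the value your loss function returns).

```matlab
function ok = isDlarrayCompatible(varargin)
% Hypothetical helper: for each input, print its underlying type and
% return true if dlarray can wrap it (full double, single, or logical
% data, including gpuArrays of those types).
ok = false(1, nargin);
for k = 1:nargin
    x = varargin{k};
    % underlyingType reports the element class of a gpuArray rather than 'gpuArray'.
    t = underlyingType(x);
    % Short-circuit: issparse is only evaluated for supported numeric/logical types.
    ok(k) = any(strcmp(t, {'double', 'single', 'logical'})) && ~issparse(x);
    fprintf('Input %d: %s (dlarray-compatible: %d)\n', k, t, ok(k));
end
end
```

For example, isDlarrayCompatible(rand(3, 50, 50), {3}) reports the numeric batch as compatible and the cell as not.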
Bhooshan V, 2022-4-6
I found the issue. The output of the neural network is a cell array, not a double array.
As a result, the loss computed from it was also a cell array.
A cell array cannot be converted to a dlarray with the dlarray() function, which gradient() presumably calls somewhere internally.
For example:
dlarray({3})
Error using dlarray
dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
I have resolved the error. Thank you for helping me realize this.
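The failure and the fix described above can be reproduced in a few lines. This is a minimal sketch assuming Deep Learning Toolbox is installed (for dlarray); the network output, its values, and the loss expression are hypothetical stand-ins, not the poster's actual actor or actorLossFunction. The point is only that the cell must be unwrapped before the output is used, so that every downstream value, including the loss, stays numeric.

```matlab
% Hypothetical network output: a probability vector wrapped in a 1x1 cell,
% mimicking what the poster found the actor returning.
netOutput = {single([0.2; 0.8])};

% Wrapping the cell directly reproduces the error from the thread.
try
    dlarray(netOutput);
catch err
    disp(err.message)
end

% Fix: unwrap the cell before computing anything from the output.
if iscell(netOutput)
    policy = netOutput{1};
else
    policy = netOutput;
end

% Hypothetical loss; because policy is numeric, the loss is numeric too.
loss = -sum(log(policy));
dlLoss = dlarray(loss);             % succeeds, unlike dlarray(netOutput)
disp(class(extractdata(dlLoss)))    % underlying data is single
```

The iscell guard keeps the same loss function working whether or not a given toolbox release wraps the output in a cell.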




Release

R2022a
