- For the observation buffer: number of observations * number of observation channels * batch size.
- For the action buffer: number of actions * number of action channels * batch size.
- For the reward buffer: 1 * batch size.
agent.learn data type issue, Reinforcement Learning Toolbox
I am working on a reinforcement learning study. Currently, I am trying to finalize the agent and make it learn from its experiences. I cannot show all of the code, but I think this is the most important part:
%% Define action and observation specifications
ActionInfo = rlFiniteSetSpec([1 2 3]); % Actions that the agent is able to take
ObservationInfo = rlNumericSpec([30 10]); % This will eventually be the input to the neural network
% lots of code here ....
% Define everything needed for the experience
CurrentState = env.reset();
action = agent.getAction(CurrentState); % Get action from agent
[nextState, reward, isDone, ~] = env.step(action); % Interact with environment
% Collect experience
experience = struct(...
'Observation', {num2cell(CurrentState)}, ...
'Action', {num2cell(action)}, ...
'Reward', reward, ...
'NextObservation', {num2cell(nextState)}, ...
'IsDone', isDone);
% Train the agent with the experience
agent = agent.learn(experience); % Update agent with experience
To elaborate, CurrentState and nextState are 30 x 10 matrices of type double, action is a 1x1 cell, reward is a double, and isDone is logical. However, when I pass these to the experience without num2cell, agent.learn fails because of this part of the code in the batchExperienceArray.m file:
% batch observation, next observation
for ct = 1:numel(ObservationDimension)
BatchDim = numel(ObservationDimension{ct})+1;
% Observation
Observation = arrayfun(@(x) (x.Observation{ct}), ExpStructArray, 'UniformOutput', false);
ObservationArray{ct} = cat(BatchDim, Observation{:});
% NextObservation
NextObservation = arrayfun(@(x) (x.NextObservation{ct}), ExpStructArray, 'UniformOutput', false);
NextObservationArray{ct} = cat(BatchDim, NextObservation{:});
end
Action = [ExpStructArray.Action];
for ct = 1:numel(ActionDimension)
BatchDim = numel(ActionDimension{ct})+1;
ActionArray{ct} = cat(BatchDim,Action{ct,:});
end
Here the error is that brace indexing is not supported for this data type. When I pass the variables through num2cell, as in the code above, the error becomes:
Error using rl.function.AbstractFunction/validateInputData_
Input data dimensions must match the dimensions specified in the corresponding observation and action info specifications.
So the question is: how can I pass the experience to agent.learn correctly, without these errors? What am I missing here? If any more information is needed, let me know.
Answers (1)
Avadhoot
2024-3-19
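From the information provided in the question, I infer that you are having problems with the dimensions of the observation and action data. Batching is performed inside the learn function (the batchExperienceArray.m code you quoted), and the error you are facing is a mismatch between the dimensions of the batched input data and the observation and action info specifications. This most likely comes from how the experience structure is built before it is passed to learn. You mention that passing the variables without the num2cell conversion gives the error "brace indexing is not supported for the data type". This is because the batching code expects Observation and NextObservation to be cell arrays with one element per observation channel, and reads them as x.Observation{ct}. With num2cell, however, the 30 x 10 matrix is split into 300 scalar cells, so x.Observation{1} returns only the first scalar, and the batched observation no longer matches rlNumericSpec([30 10]). Each field should instead hold one cell per channel containing the full array, for example {CurrentState}. The action returned by getAction is already a 1x1 cell, so it does not need num2cell.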
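The following sketch reproduces both failure modes and the fix using base MATLAB only. It mirrors the indexing in the batchExperienceArray.m excerpt from the question (take x.Observation{ct} from every experience, then concatenate along dimension numel(dim)+1); it is an illustration, not the toolbox code itself. The data are random or fixed stand-ins with the shapes described in the question, and obsDim is a hypothetical stand-in for the toolbox's ObservationDimension.

```matlab
% Stand-in data with the shapes described in the question (random values).
CurrentState = rand(30, 10);   % observation, 30x10 double
nextState    = rand(30, 10);   % next observation, 30x10 double
action       = {2};            % getAction returns a 1x1 cell
reward       = 1.5;
isDone       = false;
obsDim       = {[30 10]};      % hypothetical stand-in for ObservationDimension

% 1) No conversion: the field holds a plain double, so x.Observation{1} fails.
raw = struct('Observation', CurrentState);
try
    obsRaw = raw.Observation{1}; %#ok<NASGU>
    rawFailed = false;
catch
    rawFailed = true;  % MATLAB reports that brace indexing is not supported
end

% 2) num2cell: the field is a 30x10 cell of scalars, and {1} picks only
%    CurrentState(1,1), so the stacked observation is 1x1 instead of 30x10.
bad = struct('Observation', {num2cell(CurrentState)});
obsBad = arrayfun(@(x) x.Observation{1}, bad, 'UniformOutput', false);
obsBad = cat(numel(obsDim{1}) + 1, obsBad{:});

% 3) Fix: one cell element per channel, each holding the full array.
%    The extra braces stop struct() from expanding the cell into a struct array.
good = struct('Observation', {{CurrentState}}, ...
              'Action', {action}, ...
              'Reward', reward, ...
              'NextObservation', {{nextState}}, ...
              'IsDone', isDone);
obsGood = arrayfun(@(x) x.Observation{1}, good, 'UniformOutput', false);
obsGood = cat(numel(obsDim{1}) + 1, obsGood{:});

fprintf('num2cell: %s, wrapped: %s\n', mat2str(size(obsBad)), mat2str(size(obsGood)));
```

Running it prints a size of [1 1] for the num2cell version and [30 10] for the wrapped version. In struct, a cell-valued argument creates one struct element per cell, which is why the matrices need the double braces {{CurrentState}} while the action, already a 1x1 cell, only needs {action}. A struct built like `good` is what would then be passed to agent.learn.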
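According to the MATLAB documentation, experiences should be stored in buffers whose last dimension is the batch size: number of observations * number of observation channels * batch size for the observation buffer, number of actions * number of action channels * batch size for the action buffer, and 1 * batch size for the reward buffer.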
The source of your error might be that you have not formatted the observations and actions according to the batch size. Consider formatting the buffers to the dimensions mentioned above.
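To see where the batch size ends up, the sketch below stacks several experiences with the same rule as the batchExperienceArray.m loop quoted in the question: each channel is concatenated along dimension numel(dim)+1, so the batch index is appended after the spec dimensions. B, obsDim, and actDim are illustrative assumptions (actDim = [1 1] assumes a scalar action from rlFiniteSetSpec([1 2 3])), the values are random, and the code only mimics the toolbox's batching in base MATLAB rather than calling the toolbox.

```matlab
% Illustrative batch of B experiences in the corrected layout (random values).
B      = 4;
obsDim = {[30 10]};   % from rlNumericSpec([30 10])
actDim = {[1 1]};     % assumed dimension of a scalar finite-set action

for k = B:-1:1        % fill backwards so the struct array is allocated once
    expArray(k) = struct('Observation', {{rand(30, 10)}}, ...
                         'Action', {{randi(3)}}, ...
                         'Reward', rand, ...
                         'NextObservation', {{rand(30, 10)}}, ...
                         'IsDone', false);
end

% Observations: take channel 1 from every experience, stack along dim 3.
obsCells = arrayfun(@(x) x.Observation{1}, expArray, 'UniformOutput', false);
obsBatch = cat(numel(obsDim{1}) + 1, obsCells{:});    % 30 x 10 x B

% Actions: [expArray.Action] is a channels-by-B cell, as in the quoted loop.
Action   = [expArray.Action];
actBatch = cat(numel(actDim{1}) + 1, Action{1, :});   % 1 x 1 x B

% Rewards: one scalar per experience.
rewardBatch = [expArray.Reward];                      % 1 x B

disp(size(obsBatch)); disp(size(actBatch)); disp(size(rewardBatch));
```

With B = 4, the observation batch is 30 x 10 x 4, the action batch is 1 x 1 x 4, and the reward row is 1 x 4, i.e. the batch size is always the trailing dimension.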
For more information on the training procedure, refer to the custom training loop examples in the Reinforcement Learning Toolbox documentation.
I hope this helps you identify the cause of the error.
Avadhoot
2024-3-19
For training a DQN agent, you can take a look at the following example: https://www.mathworks.com/help/reinforcement-learning/ug/model-based-reinforcement-learning-using-custom-training-loop.html