Problem with LSTM and PPO reinforcement learning

21 views (last 30 days)
I am using a CNN as the critic network in my RL setup, and it is working fine:
criticNetwork = [
    imageInputLayer(obsSize,Normalization="none")
    convolution2dLayer(8,16,Stride=1,Padding=1,WeightsInitializer="he")
    reluLayer
    convolution2dLayer(4,8,Stride=1,Padding="same",WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(368,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(256,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(64,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(1)
    ];
However, when I revise it to use an LSTM, I get the following error:
criticNetwork = [
    imageInputLayer(obsSize,Normalization="none")
    convolution2dLayer(8,16,Stride=1,Padding=1,WeightsInitializer="he")
    reluLayer
    convolution2dLayer(4,8,Stride=1,Padding="same",WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(368,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(256,WeightsInitializer="he")
    reluLayer
    lstmLayer(128)
    fullyConnectedLayer(1)
    ];
Error using rl.train.marl.MultiAgentTrainer/run
There was an error executing the ProcessExperienceFcn for block "rlAreaCoverage3S/Agent A (Red)".
Caused by:
Error using rl.function.AbstractFunction/evaluate
Unable to evaluate function model.
Error in rl.function.rlValueFunction/getValue (line 82)
[value, state, batchSize, sequenceLength] = evaluate(this, observation);
Error in rl.advantage.generalizedAdvantage (line 10)
NextStateValue = getValue(StateValueEstimator, Experiences.NextObservation);
Error in advantage (line 19)
[Advantage,TDTarget] =
rl.advantage.generalizedAdvantage(Experiences,ValueFunction,DiscountFactor,NameValueArgs.GAEFactor);
Error in rl.agent.rlPPOAgent.computeAdvantages (line 235)
[advantages, tdTargets] = advantage(batchExperiences,critic, ...
Error in rl.agent.rlPPOAgent/learn_ (line 79)
advantageData = this.computeAdvantages(...
Error in rl.agent.AbstractAgent/learn (line 29)
this = learn_(this,experience);
Error in rl.train.marl.MultiAgentTrainer>localProcessExpFcn (line 168)
learn(agent,exp);
Error in rl.env.internal.FunctionHandlePolicyExperienceProcessor/processExperience_ (line 31)
[this.Policy_,this.Data_] = feval(this.Fcn_,...
Error in rl.env.internal.ExperienceProcessorInterface/processExperienceInternal_ (line 137)
processExperience_(this,experience,getEpisodeInfoData(this));
Error in rl.env.internal.ExperienceProcessorInterface/processExperience (line 78)
stopsim = processExperienceInternal_(this,experience,simTime);
Error in rl.simulink.blocks.PolicyProcessExperience/stepImpl (line 45)
stopsim = processExperience(this.ExperienceProcessor_,experience,simTime);
Error in Simulink.Simulation.internal.DesktopSimHelper
Error in Simulink.Simulation.internal.DesktopSimHelper.sim
Error in Simulink.SimulationInput/sim
Error in rl.env.internal.SimulinkSimulator>localSim (line 259)
simout = sim(in);
Error in rl.env.internal.SimulinkSimulator>@(in)localSim(in,simPkg) (line 171)
simfcn = @(in) localSim(in,simPkg);
Error in MultiSim.internal.runSingleSim
Error in MultiSim.internal.SimulationRunnerSerial/executeImplSingle
Error in MultiSim.internal.SimulationRunnerSerial/executeImpl
Error in Simulink.SimulationManager/executeSims
Error in Simulink.SimulationManagerEngine/executeSims
Error in rl.env.internal.SimulinkSimulator/simInternal_ (line 172)
simInfo = executeSims(engine,simfcn,getSimulationInput(this));
Error in rl.env.internal.SimulinkSimulator/sim_ (line 78)
out = simInternal_(this,simPkg);
Error in rl.env.internal.AbstractSimulator/sim (line 30)
out = sim_(this,simData,policy,processExpFcn,processExpData);
Error in rl.env.AbstractEnv/runEpisode (line 144)
out = sim(simulator,simData,policy,processExpFcn,processExpData);
Error in rl.train.marl.MultiAgentTrainer/run (line 58)
out = runEpisode(...
Error in rl.train.TrainingManager/train (line 429)
run(trainer);
Error in rl.train.TrainingManager/run (line 218)
train(this);
Error in rl.agent.AbstractAgent/train (line 83)
trainingResult = run(trainMgr,checkpoint);
Error in LSTMTrainMultipleAgentsForAreaCoverageExampleORGFinalLATM (line 272)
result = train([agentA,agentB,agentC],env,trainOpts);
Caused by:
Not enough input arguments.
Error in rl.train.TrainingManager/train (line 429)
run(trainer);
Error in rl.train.TrainingManager/run (line 218)
train(this);
Error in rl.agent.AbstractAgent/train (line 83)
trainingResult = run(trainMgr,checkpoint);

Answers (1)

Emmanouil Tzorakoleftherakis
Hi,
When you set up an LSTM network, the input layer needs to be a sequenceInputLayer. See for example here:
3 comments
Emmanouil Tzorakoleftherakis
Both sequenceInputLayer and imageInputLayer are input layers. You can only have one input layer per neural network, which is why you are getting this error. Given that you want to use an LSTM, you should use the sequenceInputLayer. Take a look at this page to see how you can create LSTMs for images:
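A minimal sketch of that suggestion (not a verified configuration; the layer sizes are illustrative and `obsSize` is assumed to be the [height width channels] of the observation image):

```matlab
% Sketch only: replace imageInputLayer with sequenceInputLayer so the
% critic accepts sequences of image observations, which is what the
% LSTM layer needs. Layer sizes here are illustrative.
criticNetwork = [
    sequenceInputLayer(obsSize)          % obsSize assumed to be [H W C]
    convolution2dLayer(8,16,Stride=1,Padding=1,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(256,WeightsInitializer="he")
    reluLayer
    lstmLayer(128)                       % recurrent layer over the sequence
    fullyConnectedLayer(1)               % scalar state-value output
    ];
```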
Ali Farid 2023-10-29
Sir, there is a new error from the revised code below, and I guess it is related to dlnetwork:
Error using dlnetwork/initialize
Invalid network.
Error in dlnetwork (line 218)
net = initialize(net, dlX{:});
Caused by:
Example inputs: Incorrect number of example network inputs. 0 example network inputs provided but network has 2
inputs including 1 unconnected layer inputs.
Layer 'fold': Detected unsupported layer. The network must not contain unsupported layers.
Layer 'unfold': Detected unsupported layer. The network must not contain unsupported layers.
Layer 'unfold': Unconnected input. Each input must be connected to input data or to the output of another
layer.
Detected unconnected inputs:
input 'miniBatchSize'
The code is:
criticNetwork = [
    sequenceInputLayer(obsSize)
    sequenceFoldingLayer('Name','fold')
    convolution2dLayer(8,16,Stride=1,Padding=1,WeightsInitializer="he")
    reluLayer
    convolution2dLayer(4,8,Stride=1,Padding="same",WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(368,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(256,WeightsInitializer="he")
    reluLayer
    sequenceUnfoldingLayer('Name','unfold')
    flattenLayer('Name','flatten')
    lstmLayer(128)
    fullyConnectedLayer(1)
    ];
criticNetwork = dlnetwork(criticNetwork);
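The error messages indicate that dlnetwork does not support the sequenceFoldingLayer/sequenceUnfoldingLayer pair (and the unfold layer's 'miniBatchSize' input is left unconnected when the layers are stacked in a plain array). One possible rework, under the assumption that your release's dlnetwork can apply convolution layers directly to sequence data so the fold/unfold pair can simply be dropped (not verified against R2022a):

```matlab
% Sketch: remove the unsupported sequenceFolding/sequenceUnfolding layers
% and let the convolution layers operate directly on the sequence input.
criticNetwork = [
    sequenceInputLayer(obsSize)
    convolution2dLayer(8,16,Stride=1,Padding=1,WeightsInitializer="he")
    reluLayer
    convolution2dLayer(4,8,Stride=1,Padding="same",WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(256,WeightsInitializer="he")
    reluLayer
    lstmLayer(128)
    fullyConnectedLayer(1)
    ];
criticNetwork = dlnetwork(criticNetwork);
```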


Release

R2022a
