All the gradients in a custom DDPG agent class are zero!

Hi!
I'm creating a custom DDPG agent as a class. The train function of the Reinforcement Learning Toolbox runs. This is the class (the core of the code is in the learnImpl function):
classdef CustomDDPGAgent < rl.agent.CustomAgent
    %% Properties (set properties attributes accordingly)
    properties
        Actor
        Critic
        Options
        ActorOptimizer
        CriticOptimizer
        NumObservation
        NumAction
        TargetActor
        TargetCritic
        ReplayBuffer
        ExplorationNoise
        ObservationBuffer
        ActionBuffer
        RewardBuffer
        NextObservationBuffer
        pippo
        Counter
        BufferLength % New property to define buffer length
        % criticLoss
        % criticParams
        criticGradients
        actionsFromActor
        qValuesFromCritic
        actorGradients
        actorParams
        actnet
        criticnet
        experiences
    end
    properties (Access = private)
        NumUnseenExperiences % New property to track the number of unseen experiences
    end
    properties (Access = private, Transient)
        % Accelerated gradient function, not saved with the agent
        AccelGradFcnCrt = []
        AccelGradFcnAct = []
    end
    %% Necessary Functions
    %======================================================================
    % Implementation of public function
    %======================================================================
    methods
        function obj = CustomDDPGAgent(Actor, Critic, Options, SampleTime, TargetActor, TargetCritic, actnet, criticnet)
            % Call the abstract class constructor.
            obj = obj@rl.agent.CustomAgent();
            obj.ObservationInfo = Actor.ObservationInfo;
            obj.ActionInfo = Actor.ActionInfo;
            % Register sample time. For MATLAB environment, use -1.
            obj.SampleTime = SampleTime;
            % Register actor, critic, and agent options.
            obj.pippo = 0;
            obj.Actor = Actor;
            obj.Critic = Critic;
            obj.Options = Options;
            obj.actnet = actnet;
            obj.criticnet = criticnet;
            obj.ActorOptimizer = rlOptimizer(Options.OptimizerOptionsActor);
            obj.CriticOptimizer = rlOptimizer(Options.OptimizerOptionsCritic);
            obj.ExplorationNoise = Options.ExplorationNoise;
            % obj.criticLoss = [];
            % obj.criticParams = [];
            obj.criticGradients = [];
            obj.actionsFromActor = [];
            obj.qValuesFromCritic = [];
            obj.actorGradients = [];
            obj.actorParams = [];
            obj.experiences = [];
            obj.NextObservationBuffer = [];
            % Cache the number of observations and actions.
            obj.NumObservation = prod(obj.Actor.ObservationInfo.Dimension);
            obj.NumAction = prod(obj.Actor.ActionInfo.Dimension);
            % Initialize target networks and replay buffer.
            obj.TargetActor = TargetActor;
            obj.TargetCritic = TargetCritic;
            obj.BufferLength = Options.ExperienceBufferLength; % Set buffer length
            obj.NumUnseenExperiences = 0; % Initialize number of unseen experiences
            % Initialize buffer and counter.
            resetImpl(obj);
        end
    end
    %======================================================================
    % Implementation of abstract functions
    %======================================================================
    methods (Access = protected)
        function Action = getActionImpl(obj, Observation)
            % Compute an action using the policy given the current
            % observation.
            Action = getAction(obj.Actor, Observation);
        end
        function Action = getActionWithExplorationImpl(obj, Observation)
            % Compute an action using the exploration policy given the
            % current observation.
            % DDPG: Deterministic actors explore by adding noise to the action
            Action = getAction(obj.Actor, Observation);
            if iscell(Action)
                Action = cell2mat(Action);
            end
            % Generate Ornstein-Uhlenbeck noise.
            Noise = obj.ExplorationNoise.XPrev + ...
                obj.ExplorationNoise.MeanAttractionConstant * (obj.ExplorationNoise.Mu - obj.ExplorationNoise.XPrev) * obj.ExplorationNoise.Dt + ...
                obj.ExplorationNoise.CurrentStandardDeviation * randn(obj.NumAction, 1) * sqrt(obj.ExplorationNoise.Dt);
            Action = Action + Noise;
            % Store the noise for the next time step.
            obj.ExplorationNoise.XPrev = Noise;
        end
        function Action = learnImpl(obj, Experience)
            % Extract data from experience
            Obs = Experience{1};
            Action = Experience{2};
            Reward = Experience{3};
            NextObs = Experience{4};
            IsDone = Experience{5};
            % Save experience to replay buffer
            obj.NumUnseenExperiences = obj.NumUnseenExperiences + 1;
            if obj.NumUnseenExperiences <= obj.BufferLength
                obj.ObservationBuffer(:,:,obj.NumUnseenExperiences) = Obs{1};
                obj.ActionBuffer(:,:,obj.NumUnseenExperiences) = Action{1};
                obj.RewardBuffer(:,obj.NumUnseenExperiences) = Reward;
                obj.NextObservationBuffer(:,:,obj.NumUnseenExperiences) = NextObs{1};
            else
                % Buffer is full, overwrite oldest experience
                obj.ObservationBuffer(:,:,(mod(obj.NumUnseenExperiences - 1, obj.BufferLength) + 1)) = Obs{1};
                obj.ActionBuffer(:,:,(mod(obj.NumUnseenExperiences - 1, obj.BufferLength) + 1)) = Action{1};
                obj.RewardBuffer(:,(mod(obj.NumUnseenExperiences - 1, obj.BufferLength) + 1)) = Reward;
                obj.NextObservationBuffer(:,:,(mod(obj.NumUnseenExperiences - 1, obj.BufferLength) + 1)) = NextObs{1};
            end
            % Choose an action for the next state
            Action = getActionWithExplorationImpl(obj, NextObs);
            obj.Counter = obj.Counter + 1;
            % Update the standard deviation
            decayedStandardDeviation = obj.ExplorationNoise.CurrentStandardDeviation * (1 - obj.ExplorationNoise.StandardDeviationDecayRate);
            obj.ExplorationNoise.CurrentStandardDeviation = max(decayedStandardDeviation, obj.ExplorationNoise.MinimumStandardDeviation);
            BatchSize = obj.Options.MinibatchSize;
            % Learn from episodic data if enough samples are available
            if obj.NumUnseenExperiences >= BatchSize
                % Sample experiences from replay buffer
                obj.pippo = obj.pippo + 1;
                obj.criticGradients = dlfeval(@trainCritic, obj.criticnet, ...
                    {obj.NextObservationBuffer}, {obj.ObservationBuffer}, {obj.ActionBuffer}, ...
                    obj.RewardBuffer, IsDone, obj.TargetActor, obj.TargetCritic, ...
                    obj.Critic, obj.Options.DiscountFactor);
                obj.Counter = 1;
                obj.NumUnseenExperiences = 0;
            end
        end
    end
    %% Optional Functions
    %======================================================================
    % Implementation of optional function
    %======================================================================
    methods (Access = protected)
        function resetImpl(obj)
            % (Optional) Define how the agent is reset before training.
            resetBuffer(obj);
            obj.NumUnseenExperiences = 0;
            obj.Counter = 1;
            obj.ExplorationNoise.XPrev = zeros(size(obj.ExplorationNoise.XPrev));
        end
    end
    methods (Access = private)
        function resetBuffer(obj)
            obj.ObservationBuffer = dlarray(zeros(obj.NumObservation,1,obj.Options.MinibatchSize));
            obj.ActionBuffer = dlarray(zeros(obj.NumAction,1,obj.Options.MinibatchSize));
            obj.RewardBuffer = dlarray(zeros(1,obj.Options.MinibatchSize));
            obj.NextObservationBuffer = dlarray(zeros(obj.NumObservation,1,obj.Options.MinibatchSize));
        end
    end
end
function out = trainCritic(criticnet, nextStates, states, actions, rewards, isDone, TargetActor, TargetCritic, Critic, DiscountFactor)
    targetActions = getAction(TargetActor, nextStates);
    targetQ = getValue(TargetCritic, nextStates, targetActions);
    currentQ = getValue(Critic, states, actions);
    targets = rewards + DiscountFactor * targetQ .* ~isDone;
    criticLoss = 0.5 * mean((targets - currentQ).^2);
    disp(criticLoss)
    out = dlgradient(criticLoss, criticnet.Learnables);
end
For now I just want to check whether it correctly calculates the gradients of the critic parameters, so the class is not complete. However, when I run some training episodes, the gradients of all parameters always remain zero.
The error is not related to the hyperparameters, but to how the gradient is calculated.
What is the error?
Thank you all in advance!

Answers (1)

Avadhoot on 2024-04-10
As you mentioned, the error is likely in how the gradients are calculated. Debugging gradient calculations is not straightforward, but there are a few things you can check to ensure that the gradients are computed correctly.
1) Data types:
Ensure that all quantities involved in computing "criticLoss" inside "trainCritic" are "dlarray" objects with a proper underlying data type (e.g., single or double), and that the loss is evaluated inside a function called through "dlfeval" so that automatic differentiation can trace the operations. If an input is not a traced "dlarray", "dlgradient" cannot propagate gradients through it.
2) Loss function:
Check that all the operations performed inside the loss function are differentiable. In your case they are, so the issue probably does not stem from here.
3) Debug the gradient calculation:
Before calling "dlgradient", make sure that "criticLoss" actually depends on "criticnet.Learnables" within the traced computation. "dlgradient" returns zero for any learnable that does not influence the loss, so if the Q-values are computed through a different object rather than through a forward pass on "criticnet" itself, every gradient will come back as zero. You can check this by inspecting "criticnet.Learnables" before the "dlgradient" call and by displaying intermediate values inside the loss function to see where the trace is lost. A minimal sketch of the expected pattern follows below.
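For reference, here is a minimal, self-contained sketch of that dlfeval/dlgradient pattern (the network, sizes, and variable names are made up for illustration and are not taken from your agent). It demonstrates that gradients with respect to net.Learnables are only nonzero when the loss is produced by a forward pass through that same dlnetwork:

% Minimal sketch: gradients are nonzero only when the loss is traced
% through the dlnetwork whose Learnables are differentiated.
layers = [
    featureInputLayer(4)            % hypothetical observation size
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(1)];        % scalar Q-value output
net = dlnetwork(layers);

obs     = dlarray(rand(4,32,"single"),"CB");    % hypothetical observation batch
targets = dlarray(rand(1,32,"single"),"CB");    % hypothetical TD targets

grads = dlfeval(@modelLoss, net, obs, targets); % table with one gradient per learnable
disp(grads.Value{1}(1:3,1))                     % generally nonzero

function gradients = modelLoss(net, obs, targets)
    % The forward pass goes through 'net', so the automatic differentiation
    % trace reaches net.Learnables and dlgradient can propagate to them.
    q    = forward(net, obs);
    loss = 0.5 * mean((targets - q).^2, "all");
    gradients = dlgradient(loss, net.Learnables);
end

If the loss in your trainCritic is not computed from criticnet itself, the trace never reaches criticnet.Learnables and dlgradient returns zero for all of them, which matches the symptom you describe.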
I am looking forward to hearing what you find. I hope this helps.
