How do I pass an additional variable out of a custom loss function during the gradient call in reinforcement learning neural network training?

I am writing my own custom RL agent with a custom loss function.
My loss is the sum of two terms: one that is a function of the probability of the actions, and one that represents the KL divergence between the target actor and the actor.
loss = logProbActions + eta * KLDivergence
This KL divergence is scaled by a factor, eta, that is updated every time the learn function is called.
The value of eta is calculated from the previous value of eta and the difference between the KL divergence and a constraint, i.e.
eta_next = eta - c*(constant_constraint - KLDivergence)
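The two updates above can be restated as a small MATLAB helper (the function and variable names here are illustrative, not from the original post):

```matlab
function [loss, etaNext] = klPenalizedLoss(logProbActions, klDivergence, eta, c, constraintKL)
% Hypothetical helper restating the two formulas from the question:
% a KL-penalized actor loss plus a dual-style update of the penalty weight.
loss = logProbActions + eta * klDivergence;          % combined scalar loss
etaNext = eta - c * (constraintKL - klDivergence);   % next value of the KL weight
end
```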
I need to pass the eta_next value out of the loss function so it can be stored in the agent object and supplied on the next call, but I do not know how to do that. To be clear, I set up the gradient call as follows:
ActorLoss = setLoss(ActorNetwork,@actorLossFunction);
ActorGradient = gradient(obj.actor,'loss-parameters',{stateBatch},actorLossData);
My loss function looks like this:
function [loss] = actorLossFunction(policy,lossData)
Ideally I'd pass an extra output back like this:
function [loss,eta_next] = actorLossFunction(policy,lossData)
but that is not possible; I receive errors stating I have too many output arguments.
There is no documentation that I can find for the "setLoss" function, and I cannot figure out how to wrap actorLossFunction in an anonymous function that captures the agent object, so that I could store the eta_next value in the agent itself.
Is there any way to do this short of redoing all the math in the agent object? I want to avoid that, as it would be very inefficient.

Accepted Answer

Daniel Egan
Daniel Egan 2021-7-14
Found an answer; not sure how good it is, but it works. I can pass in an object that inherits from the handle superclass:
classdef etaTracker < handle
which allows me to modify its properties inside the loss function, and those changes are reflected outside the gradient call. Not sure if this is the most efficient way to do it, but it works and is better than global variables in my opinion.
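A minimal sketch of this handle-class approach (the class, property, and field names here are illustrative assumptions, not from the original post). Because the class subclasses handle, the loss function receives a reference rather than a copy, so assignments made inside it persist after gradient returns:

```matlab
% etaTracker.m -- mutable container for carrying eta across gradient calls.
classdef etaTracker < handle
    properties
        Eta (1,1) double = 1;   % current KL penalty weight
    end
end
```

The tracker can then ride along inside the loss data, with the update written back in place; logProbActions, KLDivergence, c, and constraintKL stand in for values computed from policy and lossData:

```matlab
% Illustrative wiring: store the tracker in the loss data before the call.
tracker = etaTracker;
actorLossData.tracker = tracker;

function loss = actorLossFunction(policy, lossData)
    eta  = lossData.tracker.Eta;
    loss = logProbActions + eta * KLDivergence;
    % Write the updated eta back through the handle; the agent reads
    % tracker.Eta after gradient() returns.
    lossData.tracker.Eta = eta - c * (constraintKL - KLDivergence);
end
```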

More Answers (0)
