To find the policy function of your post-learning controller, you can use the trained reinforcement learning agent to evaluate actions for given states.
- generatePolicyFunction: This function generates a standalone policy evaluation function from a trained reinforcement learning agent. It is useful if you want to deploy the policy outside of the reinforcement learning environment or integrate it into a larger system (a short sketch is shown after this list).
- getAction: This method returns the action the agent selects for a given state. It is more straightforward for evaluating the policy in a simulation or analysis context.
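As a minimal sketch of the first option (assuming the trained agent is stored in a variable named agent; the names of the generated files and the expected observation shape can differ between toolbox releases):
% Generate a standalone policy evaluation function from the trained agent.
% By default this writes evaluatePolicy.m plus a MAT-file holding the policy data.
generatePolicyFunction(agent);
% The generated function can then be called directly with an observation,
% e.g. an angle of -pi rad and an angular velocity of 0 rad/s:
torque = evaluatePolicy([-pi; 0]);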
For your purpose of evaluating the policy function (torque) for specific states (angle and angular speed), using getAction is more appropriate. It allows you to directly query the agent for actions based on the states you specify.
For a better understanding of these functions, you can refer to the following documentation:
- https://www.mathworks.com/help/reinforcement-learning/ref/rl.policy.rlmaxqpolicy.generatepolicyfunction.html
- https://www.mathworks.com/help/reinforcement-learning/ref/rl.policy.rlmaxqpolicy.getaction.html
If you are using the "Soft Actor-Critic" (SAC) algorithm, the agent consists of both actor and critic networks. You can save the trained agent using the save function in MATLAB, which will include these networks. This saves the entire agent, including its policy (actor network) and value function (critic network).
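If you want to inspect those networks individually after training, a minimal sketch is given below (assuming the trained agent is in a variable named agent; note that a SAC agent can hold more than one critic, so getCritic may return several representations):
% Extract the actor (policy) and critic (value function) from the trained agent
actor = getActor(agent);
critic = getCritic(agent);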
A revised version of your code is given below:
% Load the trained agent (only needed if "agent" is not already in the workspace)
% load('trainedAgent.mat', 'agent');

% Define the state space (angle and angular velocity)
N = 5; % Number of divisions per dimension
Angle = linspace(-3.14, -4.71, N);   % Angle values [rad]
Velocity = linspace(0, -20, N);      % Angular velocity values [rad/s]
[AngleGrid, VelocityGrid] = meshgrid(Angle, Velocity);
State = [AngleGrid(:), VelocityGrid(:)]; % All combinations of states

% Preallocate the policy function output
F = zeros(size(State, 1), 1); % Policy function (torque predicted by the trained agent)

% Evaluate the policy for each state
for i = 1:size(State, 1)
    obs = {State(i, :)'};            % Observation as a column vector in a cell array
    action = getAction(agent, obs);  % Query the trained agent's policy
    F(i) = action{1};                % In recent releases getAction returns the action in a cell array
end

% Save the trained agent (the actor and critic networks are included)
save('trainedAgent.mat', 'agent');
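To visualize the resulting policy surface, a minimal sketch reusing the variables defined above:
% Reshape the evaluated torques back onto the angle/velocity grid and plot them
TorqueGrid = reshape(F, size(AngleGrid));
surf(AngleGrid, VelocityGrid, TorqueGrid)
xlabel('Angle (rad)')
ylabel('Angular velocity (rad/s)')
zlabel('Torque (policy output)')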
For a better understanding of the Soft Actor-Critic (SAC) algorithm, refer to the Reinforcement Learning Toolbox documentation on SAC agents.
Hope that helps!