Train MBPO Agent to Balance Continuous Cart-Pole System

This example shows how to train a model-based policy optimization (MBPO) agent to balance a continuous action space cart-pole system modeled in MATLAB®. For more information on MBPO agents, see Model-Based Policy Optimization (MBPO) Agent.

MBPO agents use an environment model to generate more experiences while training a base agent. In this example, the base agent is a soft actor-critic (SAC) agent.

The built-in MBPO agent is based on a model-based policy optimization algorithm in [1]. The original MBPO algorithm trains an ensemble of stochastic models. In contrast, this example trains an ensemble of deterministic models. For an example in which an MBPO agent is implemented using a custom training loop, see Model-Based Reinforcement Learning Using Custom Training Loop.

The following figure summarizes the algorithm used in this example. During training, the MBPO agent collects real experiences resulting from interactions with the environment. The MBPO agent uses these experiences to train its internal environment model. Then, it uses this model to generate experiences without interacting with the actual environment. Finally, the MBPO agent uses the real experiences and generated experiences to train the SAC base agent.
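
The data flow described above can be sketched on a toy problem. The following is a minimal illustration only, not the built-in agent's implementation: the one-dimensional linear dynamics, the fixed exploratory policy, the least-squares model, and the buffer sizes are all hypothetical stand-ins, and no base agent is actually trained.

```matlab
% Toy illustration of the MBPO data flow on a hypothetical 1-D linear system.
rng(0);
trueStep = @(s,a) 0.9*s + 0.5*a;               % stand-in "real" environment
policy   = @(s) -0.2*s + 0.1*randn(size(s));   % stand-in exploratory policy

realBuffer  = zeros(0,3);   % rows: [state, action, nextState]
modelBuffer = zeros(0,3);   % generated experiences, same layout
theta = [0; 0];             % model parameters: nextState ~ [s a]*theta

s = 1;
for epoch = 1:5
    % 1) Collect real experiences by interacting with the environment.
    for t = 1:20
        a = policy(s);
        sNext = trueStep(s,a);
        realBuffer(end+1,:) = [s a sNext];
        s = sNext;
    end

    % 2) Train the environment model on the real experiences
    %    (least squares here, in place of neural network training).
    theta = realBuffer(:,1:2) \ realBuffer(:,3);

    % 3) Generate one-step experiences with the model, starting from
    %    states sampled from the real buffer.
    idx = randi(size(realBuffer,1),50,1);
    s0 = realBuffer(idx,1);
    a0 = policy(s0);
    modelBuffer = [modelBuffer; s0, a0, [s0 a0]*theta];

    % 4) A base agent would now be trained on samples from both buffers.
end
disp(theta')
```

Because the toy dynamics are noiseless and linear, the fitted parameters match the true coefficients. A neural network model of the cart-pole system generally retains some prediction error.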

Cart-Pole MATLAB Environment

For this example, the reinforcement learning environment is a pole attached to an unactuated revolute joint on a cart. The cart has an actuated prismatic joint connected to a one-dimensional frictionless track. The training goal in this environment is to balance the pole by applying forces (actions) to the prismatic joint.

For this environment:

  • The upward balanced pendulum position is 0 radians and the downward hanging position is pi radians.

  • The pendulum starts upright with an initial angle between –0.05 radians and 0.05 radians.

  • The force action signal from the agent to the environment is from –10 N to 10 N.

  • The observations from the environment are the position and velocity of the cart, the pendulum angle, and the pendulum angle derivative.

  • The episode terminates if the pole is more than 12 degrees from vertical or if the cart moves more than 2.4 m from the original position.

  • A reward of +0.5 is provided for every time step that the pole remains upright. An additional reward is provided based on the distance between the cart and the origin. A penalty of –50 is applied when the pendulum falls.
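
The termination condition in the list above can be written as a short check. This is a sketch under stated assumptions, not the environment's own code: the name isTerminal is hypothetical, and the observation ordering (cart position, cart velocity, pole angle, pole angular velocity) is assumed from the order in which the observations are listed.

```matlab
% Hypothetical termination check built from the thresholds listed above.
% Assumed rows of obs: [cart position; cart velocity; pole angle; pole angular velocity]
thetaThreshold = 12*pi/180;   % 12 degrees, in radians
xThreshold     = 2.4;         % meters

% Works on a single 4-by-1 observation or on a 4-by-N batch of observations.
isTerminal = @(obs) abs(obs(1,:)) > xThreshold | abs(obs(3,:)) > thetaThreshold;

disp(isTerminal([0; 0; 0.05; 0]))   % small angle, cart near origin
disp(isTerminal([0; 0; 0.25; 0]))   % angle beyond 12 degrees
disp(isTerminal([2.5; 0; 0; 0]))    % cart beyond 2.4 m
```

Using the elementwise | operator rather than || lets the same check run on a batch of observations, which is convenient when a model generates many samples at once.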

For more information on this model, see Load Predefined Control System Environments.

Create a predefined environment interface for the cart-pole system.

env = rlPredefinedEnv("CartPole-Continuous");

The interface has a continuous action space where the agent can apply one force value ranging from –10 N to 10 N.

Obtain the observation and action specifications from the environment interface.

obsInfo = getObservationInfo(env);
numObservations = obsInfo.Dimension(1);
actInfo = getActionInfo(env);
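
As a quick check, you can inspect the specification objects before building networks around them. The following sketch only reads properties of the specification objects. Based on the environment description above, the observation is expected to have four elements and the action limits are expected to be –10 and 10.

```matlab
% Inspect the observation and action specifications of the predefined environment.
env = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

disp(obsInfo.Dimension)                          % size of the observation channel
disp([actInfo.LowerLimit actInfo.UpperLimit])    % force limits in newtons
```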

Fix the random generator seed for reproducibility.

rng(0);

Create MBPO Agent

An MBPO agent decides which action to take given observations using a base off-policy agent. The MBPO agent trains both the base agent and an environment model. The environment model consists of transition functions, a reward function, and an is-done function. This model is used to create more samples without interacting with the environment. This example uses the following steps to construct an MBPO agent.

  1. Define model-free off-policy agent.

  2. Define transition models.

  3. Define reward model.

  4. Define is-done model.

  5. Create neural network environment.

  6. Create MBPO agent.

1. Define Model-Free Off-Policy Agent

Create a SAC base agent with a default network structure. For more information on SAC agents, see Soft Actor-Critic (SAC) Agent. For an environment with a continuous action space, you can also use a DDPG or TD3 base agent. For discrete environments, you can use a DQN base agent.

agentOpts = rlSACAgentOptions;
agentOpts.MiniBatchSize = 256;
agentOpts.ExperienceBufferLength = 1e6;
agentOpts.NumEpoch = 1;
initOpts = rlAgentInitializationOptions(NumHiddenUnit=128);
baseagent = rlSACAgent(obsInfo,actInfo,initOpts,agentOpts);

baseagent.AgentOptions.ActorOptimizerOptions.LearnRate = 5e-4;
baseagent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
baseagent.AgentOptions.CriticOptimizerOptions(1).LearnRate = 5e-4;
baseagent.AgentOptions.CriticOptimizerOptions(1).GradientThreshold = 1;
baseagent.AgentOptions.CriticOptimizerOptions(2).LearnRate = 5e-4;
baseagent.AgentOptions.CriticOptimizerOptions(2).GradientThreshold = 1;

2. Define Transition Models

To model the environment, an MBPO agent trains one or more transition models. To model an environment effectively, you must consider two kinds of uncertainty: statistical uncertainty and modeling uncertainty. A stochastic transition function can model the statistical uncertainty better than a deterministic transition function. In this example, since the cart-pole environment is deterministic, you use deterministic transition functions.

It is challenging to have a perfect model, and a trained model usually has modeling uncertainty. One common approach to overcoming modeling uncertainty is to use multiple transition models. The original MBPO paper uses seven models [1]. For this example, to reduce computational cost, you use three models. The MBPO agent generates experiences using all three transition models. The following figure shows how an ensemble of transition models generates samples without interacting with the environment. In this figure, the models generate two trajectories with horizon = 2.
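
The following toy sketch illustrates how an ensemble can produce two trajectories with horizon = 2. It picks one ensemble member uniformly at random at every step, as described in [1]. Whether the built-in agent selects models in exactly this way is an assumption. The three linear models, the policy, and the start states are hypothetical stand-ins for the trained networks.

```matlab
% Toy ensemble rollout: two trajectories, horizon = 2.
rng(0);

% Three hypothetical "transition models" with slightly different parameters.
models = {@(s,a) 0.90*s + 0.50*a, ...
          @(s,a) 0.88*s + 0.52*a, ...
          @(s,a) 0.91*s + 0.49*a};
policy = @(s) -0.2*s;          % hypothetical stand-in policy

horizon = 2;
startStates = [1.0; -0.5];     % in MBPO, these come from real experiences
numTraj = numel(startStates);
experiences = zeros(numTraj*horizon,3);   % rows: [state, action, nextState]

row = 0;
for k = 1:numTraj
    s = startStates(k);
    for h = 1:horizon
        m = randi(numel(models));    % pick one ensemble member for this step
        a = policy(s);
        sNext = models{m}(s,a);      % predict without touching the environment
        row = row + 1;
        experiences(row,:) = [s a sNext];
        s = sNext;
    end
end
disp(experiences)
```

Mixing ensemble members within a rollout means that no single model's systematic error accumulates along the whole trajectory.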

Create three deterministic transition functions. To do so, create a deep neural network using the createDeterministicTransitionNetwork helper function. Then, use the neural network to create an rlContinuousDeterministicTransitionFunction object. When creating a transition function object, you must specify the action and observation input/output names for the neural network.

net1 = createDeterministicTransitionNetwork(4,1);
transitionFcn1 = rlContinuousDeterministicTransitionFunction(net1,...
    obsInfo,...
    actInfo,...
    ObservationInputNames="state",...
    ActionInputNames="action",...
    NextObservationOutputNames="nextObservation");

net2 = createDeterministicTransitionNetwork(4,1);
transitionFcn2 = rlContinuousDeterministicTransitionFunction(net2,...
    obsInfo,...
    actInfo,...
    ObservationInputNames="state",...
    ActionInputNames="action",...
    NextObservationOutputNames="nextObservation");

net3 = createDeterministicTransitionNetwork(4,1);
transitionFcn3 = rlContinuousDeterministicTransitionFunction(net3,...
    obsInfo,...
    actInfo,...
    ObservationInputNames="state",...
    ActionInputNames="action",...
    NextObservationOutputNames="nextObservation");

3. Define Reward Model

An MBPO agent also contains a reward model for the environment. If you know a ground-truth reward function, you can specify it using a custom function. In this example, the ground-truth reward function is defined in the cartPoleRewardFunction helper function. To use this reward function, set useGroundTruthReward to true.

You can also specify a neural-network-based reward function that the MBPO agent can train. In this example, you can use such a reward function by setting useGroundTruthReward to false. The deep neural network for the reward function is defined in the createRewardNetworkActionNextObs helper function. To define a reward function using the neural network, create an rlContinuousDeterministicRewardFunction object.

useGroundTruthReward = true;
if useGroundTruthReward
    rewardFcn = @cartPoleRewardFunction;
else
    % This neural network uses action and next observation as inputs.
    rewardnet = createRewardNetworkActionNextObs(4,1);
    rewardFcn = rlContinuousDeterministicRewardFunction(rewardnet,...
        obsInfo,...
        actInfo, ...
        ActionInputNames="action",...
        NextObservationInputNames="nextState");
end

4. Define Is-Done Model

An MBPO agent also contains an is-done model for computing the termination signal for the environment. If you know a ground-truth termination signal, you can specify it using a custom function. In this example, the ground-truth termination signal is defined in the cartPoleIsDoneFunction helper function. To use this is-done function, set useGroundTruthIsDone to true.

You can also specify a neural-network-based is-done function that the MBPO agent can train. In this example, you can use such an is-done function by setting useGroundTruthIsDone to false. The deep neural network for the is-done function is defined in the createIsDoneNetwork helper function. To define an is-done function using the neural network, create an rlIsDoneFunction object.

useGroundTruthIsDone = true;
if useGroundTruthIsDone
    isdoneFcn = @cartPoleIsDoneFunction;
else
    % This neural network uses only the next observation as input.
    isdoneNet = createIsDoneNetwork(4);
    isdoneFcn = rlIsDoneFunction(isdoneNet,...
        obsInfo,...
        actInfo,...
        NextObservationInputNames="nextState");
end

5. Create Neural Network Environment

Define a neural network environment using the transition, reward, and is-done functions. To do so, create an rlNeuralNetworkEnvironment object.

generativeEnv = rlNeuralNetworkEnvironment( ...
    obsInfo, ...
    actInfo, ...
    [transitionFcn1,transitionFcn2,transitionFcn3], ...
    rewardFcn, ...
    isdoneFcn);

% Reset model environment.
reset(generativeEnv);

6. Create MBPO Agent

Define an MBPO agent using the base off-policy agent and environment model. To do so, first create an MBPO agent options object.

MBPOAgentOpts = rlMBPOAgentOptions;

Specify options for training the environment model. Train the model for 1 epoch at the beginning of each episode and use 15 mini-batches of size 512.

MBPOAgentOpts.NumEpochForTrainingModel = 1;
MBPOAgentOpts.NumMiniBatches = 15;
MBPOAgentOpts.MiniBatchSize = 512;

Specify the ratio of real and generated experience used to train the base SAC agent. For this example, 20% of samples are from the real experience buffer and 80% of samples are from the model experience buffer.

MBPOAgentOpts.RealSampleRatio = 0.2;
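
To make the ratio concrete, the following arithmetic sketch splits one base-agent mini-batch of 256 samples (the MiniBatchSize set for the SAC agent above) according to RealSampleRatio. The rounding rule is an assumption; the built-in agent might round differently.

```matlab
% Split one base-agent mini-batch between real and generated samples.
miniBatchSize   = 256;   % base agent MiniBatchSize from this example
realSampleRatio = 0.2;   % RealSampleRatio from this example

numReal  = round(realSampleRatio*miniBatchSize);   % rounding rule is an assumption
numModel = miniBatchSize - numReal;

fprintf("%d real + %d generated samples per mini-batch\n",numReal,numModel);
```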

Specify options for generating samples using the environment model.

  • Generate 20000 trajectories at the beginning of each epoch.

  • Use a piecewise roll-out horizon schedule, which increases the horizon length gradually.

  • Increase the horizon length every 100 epochs.

  • Use an initial horizon length of 1.

  • Use a maximum horizon length of 2.

MBPOAgentOpts.ModelRolloutOptions.NumRollout = 20000;
MBPOAgentOpts.ModelRolloutOptions.HorizonUpdateSchedule = "piecewise";
MBPOAgentOpts.ModelRolloutOptions.HorizonUpdateFrequency = 100;
MBPOAgentOpts.ModelRolloutOptions.Horizon = 1;
MBPOAgentOpts.ModelRolloutOptions.HorizonMax = 2;
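
The effect of these settings can be sketched as a step function of the epoch count. This sketch assumes that the piecewise schedule raises the horizon by one after every HorizonUpdateFrequency epochs and caps it at HorizonMax. The helper name horizonAt and the zero-based epoch counting are hypothetical.

```matlab
% Sketch of a piecewise horizon schedule using the values from this example.
horizon0   = 1;     % ModelRolloutOptions.Horizon
horizonMax = 2;     % ModelRolloutOptions.HorizonMax
updateFreq = 100;   % ModelRolloutOptions.HorizonUpdateFrequency

% Hypothetical helper: horizon after a given number of completed epochs,
% assuming an increment of one per update and a cap at horizonMax.
horizonAt = @(epoch) min(horizonMax, horizon0 + floor(epoch/updateFreq));

epochs = [0 50 99 100 250 500];
disp([epochs; arrayfun(horizonAt,epochs)])   % first row: epoch, second row: horizon
```

Short horizons early in training limit how far errors in a poorly trained model can compound. The horizon grows only after the model has had time to improve.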

Specify the size of the model's experience buffer.

MBPOAgentOpts.ModelExperienceBufferLength = 120000;

Specify optimizer options for training the transition models. Use the same optimizer options for all three transition models.

transitionOptimizerOptions1 = rlOptimizerOptions(...
    LearnRate=5e-4,...
    GradientThreshold=1.0);
transitionOptimizerOptions2 = rlOptimizerOptions(...
    LearnRate=5e-4,...
    GradientThreshold=1.0);
transitionOptimizerOptions3 = rlOptimizerOptions(...
    LearnRate=5e-4,...
    GradientThreshold=1.0);
MBPOAgentOpts.TransitionOptimizerOptions = ...
    [transitionOptimizerOptions1,...
    transitionOptimizerOptions2,...
    transitionOptimizerOptions3];

Specify optimizer options for training the reward model. If you use a custom ground-truth reward function, the agent ignores these options.

rewardOptimizerOptions = rlOptimizerOptions(...
    LearnRate=5e-4,...
    GradientThreshold=1.0);
MBPOAgentOpts.RewardOptimizerOptions = rewardOptimizerOptions;

Specify optimizer options for training the is-done model. If you use a custom ground-truth is-done function, the agent ignores these options.

isdoneOptimizerOptions = rlOptimizerOptions(...
    LearnRate=5e-4,...
    GradientThreshold=1.0);
MBPOAgentOpts.IsDoneOptimizerOptions = isdoneOptimizerOptions;

Create the MBPO agent, specifying the base agent, environment model, and options.

agent = rlMBPOAgent(baseagent,generativeEnv,MBPOAgentOpts);

Train Agent

To train the agent, first specify the training options. For this example, use the following options.

  • Run the training for at most 500 episodes, with each episode lasting at most 500 time steps.

  • Display the training progress in the Reinforcement Learning Training Monitor dialog box (set the Plots option) and disable the command line display (set the Verbose option to false).

  • Save the agent when the episode reward is greater than or equal to 470.

  • Stop training when an evaluation statistic reaches 470. At this point, the agent can balance the pendulum in the upright position.

For more information, see rlTrainingOptions.

trainOpts = rlTrainingOptions(...
    MaxEpisodes=500, ...
    MaxStepsPerEpisode=500, ...
    Verbose=false, ...
    Plots="training-progress",...
    StopTrainingCriteria="EvaluationStatistic",...
    StopTrainingValue=470,...
    ScoreAveragingWindowLength=5,...
    SaveAgentCriteria="EpisodeReward",...
    SaveAgentValue=470); 

Create an evaluator object to evaluate the agent using 5 evaluation episodes every 25 training episodes.

evl = rlEvaluator( ...
    NumEpisodes=5, ...
    EvaluationFrequency=25, ...
    RandomSeeds=101:105);

You can visualize the cart-pole system by using the plot function during training or simulation.

plot(env)

Train the agent using the train function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining    
    % Train the agent.    
    trainingStats = train(agent,env,trainOpts,Evaluator=evl);
else
    % Load the pretrained agent for the example.
    load("MATLABCartpoleMBPO.mat","agent");
end

Simulate MBPO Agent

To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim. Exploration during validation is not necessary in this example. Therefore, to use deterministic actions during the simulation, set the UseExplorationPolicy property of the agent to false.

rng(1)

% Disable exploration during sim
agent.UseExplorationPolicy = false; 

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);
totalReward_MBPO = sum(experience.Reward)
totalReward_MBPO = 
497.5506

Instead of simulating the MBPO agent, you can simulate the base agent. If you use the same random seed, you get the same result as simulating the MBPO agent.

rng(1)
experience = sim(env,agent.BaseAgent,simOptions);

totalReward_SAC = sum(experience.Reward)
totalReward_SAC = 
497.5506

Evaluate Learned Environment Model

To validate the trained environment transition models, you can check whether they are able to correctly predict the next observations. Similarly, you can validate the performance of the reward and is-done functions. To make a prediction based on the environment model, use step.

Collect data for learned model evaluation.

rng(1)

% Enable exploration during sim to create 
% diverse data for model evaluation
agent.UseExplorationPolicy = true; 

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

For this example, evaluate the performance of the first transition model.

agent.EnvModel.TransitionModelNum = 1;

For each simulation step, extract the actual next observation, reward, and is-done value, and predict the same quantities using the environment model.

numSteps = length(experience.Reward.Data);
nextObsPrediction = zeros(4,1,numSteps);
rewardPrediction = zeros(1,numSteps);
isdonePrediction = zeros(1,numSteps);
nextObsGroundTruth = zeros(4,1,numSteps);
rewardGroundTruth = zeros(1,numSteps);
isdoneGroundTruth = zeros(1,numSteps);
for stepCt = 1:numSteps
    % Extract the actual next observation, reward, and is-done value.
    nextObsGroundTruth(:,:,stepCt) = ...
        experience.Observation.CartPoleStates.Data(:,:,stepCt+1);
    rewardGroundTruth(:, stepCt) = experience.Reward.Data(stepCt);
    isdoneGroundTruth(:, stepCt) = experience.IsDone.Data(stepCt);

    % Predict the next observation, reward, and is-done value 
    % using the environment model.
    obs = experience.Observation.CartPoleStates.Data(:,:,stepCt);
    agent.EnvModel.Observation = {obs};
    action = experience.Action.CartPoleAction.Data(:,:,stepCt);
    [nextObs,reward,isdone] = step(agent.EnvModel,{action});

    nextObsPrediction(:,:,stepCt) = nextObs{1};
    rewardPrediction(:,stepCt) = reward;
    isdonePrediction(:,stepCt) = isdone;
end

Plot the ground truth and prediction of each dimension of the observations.

figure
for obsDimensionIndex = 1:4
    subplot(2,2,obsDimensionIndex)
    plot(reshape(nextObsGroundTruth(obsDimensionIndex,:,:),1,numSteps))
    hold on
    plot(reshape(nextObsPrediction(obsDimensionIndex,:,:),1,numSteps))
    hold off
    xlabel("Step")
    ylabel("Observation")   
    if obsDimensionIndex == 4
        legend("GroundTruth","Prediction","Location","northeast")
    end
end

Figure: Four panels, one per observation dimension, each plotting Observation against Step for the ground truth and the prediction.
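
In addition to the visual comparison, you can summarize the prediction quality with a single number per observation dimension. The following sketch defines a hypothetical helper, rmsePerDim, that computes the root mean squared error for arrays laid out like nextObsGroundTruth and nextObsPrediction. It is demonstrated on synthetic stand-in data, not on results from this example.

```matlab
% Hypothetical helper: RMSE per observation dimension for 4-by-1-by-N arrays.
rmsePerDim = @(groundTruth,prediction) ...
    sqrt(mean((squeeze(groundTruth) - squeeze(prediction)).^2,2));

% Synthetic stand-in data with the same layout as nextObsGroundTruth and
% nextObsPrediction (four observation dimensions, three steps).
gt   = zeros(4,1,3);
pred = gt + 0.1;    % constant offset of 0.1 in every dimension

disp(rmsePerDim(gt,pred))
```

To apply the helper to the data collected above, call rmsePerDim(nextObsGroundTruth,nextObsPrediction). The result is a 4-by-1 vector with one error value per observation dimension.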

References

[1] Janner, Michael, Justin Fu, Marvin Zhang, and Sergey Levine. “When to Trust Your Model: Model-Based Policy Optimization.” In Proceedings of the 33rd International Conference on Neural Information Processing Systems, 12519–30. 1122. Red Hook, NY, USA: Curran Associates Inc., 2019.
