
Train AC Agent to Balance Cart-Pole System Using Parallel Computing

This example shows how to train an actor-critic (AC) agent to balance a cart-pole system modeled in MATLAB® by using asynchronous parallel training. For an example that shows how to train the agent without using parallel training, see Train AC Agent to Balance Cart-Pole System.

Actor Parallel Training

When you use parallel computing with AC agents, each worker generates experiences from its copy of the agent and the environment. After every N steps, the worker computes gradients from the experiences and sends the computed gradients back to the client agent (the agent associated with the MATLAB® process that starts the training). The client averages the gradients, updates the network parameters, and sends the updated parameters back to the workers so they can continue simulating the agent with the new parameters.
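The averaging and update step that the client performs can be sketched in a few lines of base MATLAB. This is a conceptual illustration only, not how Reinforcement Learning Toolbox implements it internally. The names theta, workerGrads, and learnRate are hypothetical, the gradient values are invented, and a plain gradient-descent step stands in for whatever optimizer you configure with rlOptimizerOptions.

```matlab
% Conceptual sketch of one gradient-averaging update on the client.
% theta, workerGrads, and learnRate are hypothetical; plain gradient
% descent is assumed in place of the configured optimizer.
theta = [0.5; -0.2];            % current client parameters (2 weights)
workerGrads = [ 0.2  0.4  0.0;  % gradient of weight 1 from workers 1-3
               -0.1  0.1  0.3]; % gradient of weight 2 from workers 1-3
learnRate = 0.1;

% Average the gradients received from all workers (mean across columns).
avgGrad = mean(workerGrads,2);

% Take one descent step with the averaged gradient.
theta = theta - learnRate*avgGrad;

% The client would now send theta back to every worker.
disp(theta.')
```

Because every worker receives the same updated parameters, all workers simulate the same policy in the next round.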

This type of parallel training is also known as gradient-based parallelization. In principle, it allows a speedup that is nearly linear in the number of workers. However, this option requires synchronous training (that is, the ParallelizationOptions.Mode property of the rlTrainingOptions object that you pass to the train function must be set to "sync"). This means that workers must pause execution until all workers are finished, and as a result the training only advances as fast as the slowest worker allows.
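To see why synchronous training advances only as fast as the slowest worker, consider the wall-clock time of one update round. The per-worker times below are invented for illustration.

```matlab
% Hypothetical time (seconds) each of four workers needs to simulate
% its N steps and compute gradients in one round.
workerTimes = [1.0 1.2 0.9 2.5];

% In synchronous mode the client waits for every worker, so one round
% takes as long as the slowest worker.
syncRoundTime = max(workerTimes);

% Fraction of the round that the fastest worker spends waiting.
idleFraction = 1 - min(workerTimes)/syncRoundTime;

fprintf("Round time: %.1f s, fastest worker idle for %.0f%% of it\n", ...
    syncRoundTime, 100*idleFraction)
```

Here one slow worker stretches every round to 2.5 s. This is the cost that asynchronous schemes try to avoid.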

For more information about synchronous versus asynchronous parallelization, see Train Agents Using Parallel Computing and GPUs.

Create Cart-Pole MATLAB Environment Interface

Create a predefined environment interface for the cart-pole system. For more information on this environment, see Load Predefined Control System Environments.

env = rlPredefinedEnv("CartPole-Discrete");
env.PenaltyForFalling = -10;

Obtain the observation and action information from the environment interface.

obsInfo = getObservationInfo(env);
numObservations = obsInfo.Dimension(1);
actInfo = getActionInfo(env);

Fix the random generator seed for reproducibility.

rng(0)

Create AC Agent

Actor-critic agents use a parametrized value function approximator to estimate the value of the policy. A value-function critic takes the current observation as input and returns a single scalar as output (the estimated discounted cumulative long-term reward for following the policy from the state corresponding to the current observation).
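The quantity the critic estimates, the discounted cumulative reward, can be computed directly for a finished sequence of rewards. This sketch makes the critic's target concrete. The reward sequence and the discount factor of 0.99 are assumptions chosen for illustration.

```matlab
% Discounted return G(t) = r(t) + gamma*G(t+1), computed backward in time.
% The reward sequence and gamma are illustrative assumptions.
rewards = [1 1 1 1];   % hypothetical rewards for four steps
gamma = 0.99;          % assumed discount factor

G = zeros(size(rewards));
runningReturn = 0;
for t = numel(rewards):-1:1
    runningReturn = rewards(t) + gamma*runningReturn;
    G(t) = runningReturn;
end

disp(G)   % G(1) is the discounted return from the first step
```

The backward recursion gives the same value as the explicit sum of gamma^k times each reward, but it computes the return for every step in one pass.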

To model the parametrized value function within the critic, use a neural network with one input layer (which receives the content of the observation channel, as specified by obsInfo) and one output layer (which returns the scalar value). Note that prod(obsInfo.Dimension) returns the total number of elements in the observation, regardless of whether the observation is a column vector, row vector, or matrix.
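You can check the note about prod(obsInfo.Dimension) with plain dimension vectors. The vectors below are hypothetical stand-ins for obsInfo.Dimension, each describing an observation with four elements. The cart-pole observation in this example is a 4-by-1 column vector.

```matlab
% Hypothetical observation dimension vectors, each holding 4 elements.
colDims = [4 1];   % 4-by-1 column vector (the cart-pole layout)
rowDims = [1 4];   % 1-by-4 row vector
matDims = [2 2];   % 2-by-2 matrix

% prod gives the number of elements, which is the input size the
% network needs, for every layout.
inputSizes = [prod(colDims) prod(rowDims) prod(matDims)];
disp(inputSizes)
```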

Define the network as an array of layer objects.

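criticNetwork = [
    featureInputLayer(prod(obsInfo.Dimension))  % one input per observation element
    fullyConnectedLayer(32)                     % hidden layer (width of 32 assumed)
    reluLayer
    fullyConnectedLayer(1)                      % scalar value estimate
    ];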

Create the value function approximator object using criticNetwork and the environment observation specification.

critic = rlValueFunction(criticNetwork,obsInfo);

Actor-critic agents use a parametrized stochastic policy, which for discrete action spaces is implemented by a discrete categorical actor. This actor takes an observation as input and returns as output a random action sampled (among the finite number of possible actions) from a categorical probability distribution.
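Sampling from a categorical distribution, as the actor does when it selects an action, can be sketched with inverse-CDF sampling in base MATLAB. This is not toolbox code. The action values [-10 10] are an assumption about the cart-pole force set, and the probabilities stand in for a hypothetical actor output. Drawing many samples shows the empirical action frequencies approaching the probabilities.

```matlab
% Inverse-CDF sampling from a categorical distribution (conceptual sketch).
actionSet = [-10 10];   % assumed force values for the two discrete actions
probs = [0.3 0.7];      % hypothetical action probabilities from the actor

rng(0)                  % fix the seed so the sketch is repeatable
n = 1e5;                % number of actions to sample
cdf = cumsum(probs);
cdf(end) = 1;           % guard against round-off in the last bin
u = rand(n,1);          % uniform draws in (0,1)

% For each draw, count the CDF values below it; +1 gives the action index.
idx = sum(u > cdf, 2) + 1;
actions = actionSet(idx);

% Empirical frequency of the +10 action (close to 0.7 for large n).
freq10 = mean(actions == 10);
fprintf("Fraction of +10 actions: %.3f\n", freq10)
```

Because the action is sampled rather than chosen greedily, the agent keeps exploring both actions while their probabilities are not close to 0 or 1.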

To model the parametrized policy within the actor, use a neural network with one input layer (which receives the content of the environment observation channel, as specified by obsInfo) and one output layer. The output layer must return a vector containing one probability for each possible action, as specified by actInfo. Note that numel(actInfo.Elements) returns the number of possible actions in the discrete action space.
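A common way to make the output layer return a valid probability vector is a softmax, which normalizes the raw outputs of the last fully connected layer. The sketch below reproduces that computation by hand for two actions. The logit values are hypothetical.

```matlab
% Softmax over hypothetical raw outputs (logits) for two actions.
logits = [1.2; -0.3];

% Subtract the maximum before exponentiating for numerical stability;
% this does not change the result.
expScores = exp(logits - max(logits));
probs = expScores / sum(expScores);

disp(probs.')   % nonnegative and sums to 1
```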

Define the network as an array of layer objects.

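actorNetwork = [
    featureInputLayer(prod(obsInfo.Dimension))   % one input per observation element
    fullyConnectedLayer(32)                      % hidden layer (width of 32 assumed)
    reluLayer
    fullyConnectedLayer(numel(actInfo.Elements)) % one output per possible action
    softmaxLayer                                 % convert outputs to probabilities
    ];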

Create the actor approximator object using actorNetwork and the environment action and observation specifications.

actor = rlDiscreteCategoricalActor(actorNetwork,obsInfo,actInfo);

For more information on creating approximator objects such as actors and critics, see Create Policies and Value Functions.

Specify options for the critic and actor using rlOptimizerOptions.

criticOpts = rlOptimizerOptions(LearnRate=1e-2,GradientThreshold=1);
actorOpts = rlOptimizerOptions(LearnRate=1e-2,GradientThreshold=1);
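GradientThreshold=1 limits the size of each gradient before the update. The sketch below assumes L2-norm clipping, which rescales the gradient whenever its norm exceeds the threshold. We understand this to be the default GradientThresholdMethod of rlOptimizerOptions, but check that reference page for your release. The gradient vector is hypothetical.

```matlab
% L2-norm gradient clipping with threshold 1 (assumed clipping method).
g = [3; 4];           % hypothetical gradient with L2 norm 5
gradThreshold = 1;

gNorm = norm(g);
if gNorm > gradThreshold
    % Rescale so the norm equals the threshold; the direction is unchanged.
    g = g*(gradThreshold/gNorm);
end

disp(g.')   % clipped gradient
```

Clipping keeps a single unusually large gradient from moving the parameters too far in one update, which can stabilize training at a relatively high learning rate such as 1e-2.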

Specify the AC agent options using rlACAgentOptions, including the optimizer options for the actor and critic.

agentOpts = rlACAgentOptions(...
    ActorOptimizerOptions=actorOpts,...
    CriticOptimizerOptions=criticOpts);

Create the agent using the specified actor, critic, and agent options. For more information, see rlACAgent.

agent = rlACAgent(actor,critic,agentOpts);

Parallel Training Options

To train the agent, first specify the training options. For this example, use the following options.

  • Run the training for at most 1000 episodes, with each episode lasting at most 500 time steps.

  • Display the training progress in the Episode Manager dialog box (set the Plots option) and disable the command line display (set the Verbose option).

  • Stop training when the agent receives an average cumulative reward of at least 500 over 10 consecutive episodes. At this point, the agent can balance the pendulum in the upright position.
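The stopping rule in the last bullet is a running average over a window of episodes. This sketch evaluates it on invented episode rewards, assuming the criterion fires when the 10-episode average reaches the stop value. It only checks once a full window is available, and the toolbox's handling of the first few episodes may differ.

```matlab
% Running-average stopping rule (conceptual sketch, not toolbox code).
episodeRewards = [120 300 450 500 500 500 500 500 500 500 500 500 500]; % hypothetical
windowLength = 10;    % 10 consecutive episodes
stopValue = 500;      % required average cumulative reward
stopEpisode = NaN;    % stays NaN if the criterion is never met

% Evaluate the average only once a full window of episodes exists.
for k = windowLength:numel(episodeRewards)
    avgReward = mean(episodeRewards(k-windowLength+1:k));
    if avgReward >= stopValue
        stopEpisode = k;
        break
    end
end

fprintf("Training would stop after episode %d\n", stopEpisode)
```

Even though the agent first reaches 500 in episode 4, the three weaker early episodes keep the average below 500 until they leave the window.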

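trainOpts = rlTrainingOptions(...
    MaxEpisodes=1000,...
    MaxStepsPerEpisode=500,...
    Verbose=false,...
    Plots="training-progress",...
    StopTrainingCriteria="AverageReward",...
    StopTrainingValue=500,...
    ScoreAveragingWindowLength=10);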

You can visualize the cart-pole system during training or simulation using the plot function.

plot(env)

Figure: Cart Pole Visualizer

To train the agent using parallel computing, specify the following training options.

  • Set the UseParallel option to true.

  • Train the agent in parallel asynchronously by setting the ParallelizationOptions.Mode option to "async".

trainOpts.UseParallel = true;
trainOpts.ParallelizationOptions.Mode = "async";

For more information, see rlTrainingOptions.

Train Agent

Train the agent using the train function. Training the agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true. Because of randomness in asynchronous parallel training, your results can differ from the following training plot, which shows the result of training with six workers.

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load(pretrainedAgentFile,"agent");  % pretrainedAgentFile: hypothetical MAT-file name
end

Simulate AC Agent

You can visualize the cart-pole system with the plot function during simulation.

plot(env)

To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim.

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

Figure: Cart Pole Visualizer

totalReward = sum(experience.Reward)
totalReward = 500
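A total reward of 500 in a 500-step simulation is consistent with the pole staying upright for the entire episode. The sketch below assumes a reward model in which every step earns +1 while the pole is up and the step on which it falls earns PenaltyForFalling (set to -10 above). The variable names and this reward model are assumptions for illustration, not values read from the environment object.

```matlab
% Assumed reward model: +1 per upright step, penalty on the falling step.
rewardPerStep = 1;        % hypothetical name for the per-step reward
penaltyForFalling = -10;  % matches env.PenaltyForFalling in this example
maxSteps = 500;           % matches MaxSteps in the simulation options

% Episode in which the pole never falls.
fullEpisodeReward = rewardPerStep*maxSteps;

% Episode in which the pole falls on step 120 (hypothetical).
fallStep = 120;
rewards = [rewardPerStep*ones(1,fallStep-1) penaltyForFalling];
earlyFallReward = sum(rewards);

fprintf("No fall: %d, fall at step %d: %d\n", ...
    fullEpisodeReward, fallStep, earlyFallReward)
```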


