
Train Agent to Play Turn-Based Game

This example shows you how to train a deep Q-network (DQN) reinforcement learning agent to play a turn-based multi-player game.

Fix Random Seed Generator to Improve Reproducibility

The example code may use random numbers at several stages, such as initializing the agent, creating the actor and critic, resetting the environment during simulations, generating observations (for stochastic environments), generating exploration actions, and sampling mini-batches of experiences for learning. Fixing the random number stream preserves the sequence of random numbers every time you run the code and improves the reproducibility of results. You will fix the random number stream at several points in the example. Fix the random number stream with seed 0 and the Mersenne twister algorithm. For more information on random number generation, see rng.

previousRngState = rng(0,"twister")
previousRngState = struct with fields:
     Type: 'twister'
     Seed: 0
    State: [625x1 uint32]

The output previousRngState is a structure that contains information about the previous state of the stream. You will restore the state at the end of the example.
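The following sketch, which uses only base MATLAB, shows why seeding helps: reseeding with the same seed and algorithm restarts the stream, so the same draws repeat, and the saved structure restores the earlier settings. The variable names a, b, prevState, and sameDraws are illustrative.

```matlab
% Seed the generator with 0 using the Mersenne twister and keep the old settings.
prevState = rng(0, "twister");
a = rand(1, 3);

% Reseeding restarts the same stream, so the draws repeat exactly.
rng(0, "twister");
b = rand(1, 3);
sameDraws = isequal(a, b);   % true

% Restore the generator settings that were active before this snippet.
rng(prevState);
```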

Overview

Environments in which agent execution is controlled by a turn-based logic are known as turn-based environments. Turns are analogous to environment steps, and one or more agents may take actions at a given turn. For example, in a two-player game such as chess, each turn requires one player (or agent) to take action.

In this example, you create a turn-based multi-agent environment using MATLAB® functions and train an agent to play against a policy that takes random actions.

The environment in this example is a simple turn-based two-player game. For this environment:

  • The game contains a 3-by-3 grid with nine cells.

  • Two adversarial players (agents) play the game. At each turn, a player marks a location in the grid. Player one has red squares and player two has blue circles.

  • The observation of each agent is a vector obtained from the 3-by-3 grid by flattening it in column-major fashion. The vector contains zeros in unmarked locations, ones in locations marked by the agent, and negative ones in locations marked by the opponent agent.

  • The action of an agent is the grid location to be marked (an integer between one and nine, inclusive).

  • The game is won if an agent has marked three consecutive cells (a horizontal row, a vertical column, or either diagonal).
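As an illustration of the observation and win rules above, the following base MATLAB sketch encodes a hypothetical board from player one's point of view and checks for a win. It assumes the cell values -1 (player one) and 1 (player two) that the local functions at the end of this example use; the board and variable names are illustrative only.

```matlab
% Board state: -1 = player one (red square), 1 = player two (blue circle), 0 = empty.
state = [-1  1  0;
          0 -1  0;
          1  0 -1];
selfVal = -1;       % player one's cell value
opponentVal = 1;    % player two's cell value

% Observation for player one: own cells 1, opponent cells -1, empty 0,
% flattened column by column with (:).
obs = double(state == selfVal) - double(state == opponentVal);
obs = obs(:);       % [1; 0; -1; -1; 1; 0; 0; 0; 1]

% Win test: a full row, full column, or either diagonal owned by player one.
owns = (state == selfVal);
isWin = any(all(owns, 2)) || any(all(owns, 1)) || ...
        all(diag(owns)) || all(diag(fliplr(owns)));   % true (main diagonal)
```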

The reward function for the active agent is described below.

  • A positive reward of 10 if the agent wins the game in the current turn.

  • A positive reward of 1 if the agent has marked two consecutive grid cells with the third cell in line unmarked, setting up a possible win in the next turn.

  • A positive reward of 10 if the agent marks a cell that prevents the opponent from winning in the next turn.

  • A negative reward of -20 if the agent fails to mark a cell that enables the opponent to win in the next turn. This penalty replaces any other reward earned in the turn.

  • A negative reward of -20 if the agent takes an illegal move, such as marking an occupied cell. The simulation is also terminated in this case.

The inactive agent always receives a reward of 0.
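The penalty for leaving the opponent a winning move depends on detecting a line that holds two opponent marks and one empty cell. The following sketch shows one vectorized way to perform that check. The lines table of column-major cell indices and the threatens function are hypothetical helpers, not part of this example's local functions, which check rows, columns, and diagonals separately.

```matlab
% Hypothetical helper: linear (column-major) indices of the 8 winning lines.
lines = [1 4 7; 2 5 8; 3 6 9;    % rows
         1 2 3; 4 5 6; 7 8 9;    % columns
         1 5 9; 3 5 7];          % diagonals
state = zeros(3, 3);
state([1 4]) = 1;     % opponent (value 1) holds cells (1,1) and (1,2)
state(5) = -1;        % agent (value -1) holds the center cell
vals = state(lines);  % 8-by-3: the cell values along each line

% A player threatens to win next turn if some line holds two of its marks
% and one empty cell.
threatens = @(who) any(sum(vals == who, 2) == 2 & any(vals == 0, 2));
opponentThreat = threatens(1);   % true: row 1 is [1 1 0]
agentThreat = threatens(-1);     % false: the agent has only one mark
```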

Create Environment Object

First, create the observation specification for each agent. The observation is a vector of nine elements.

oinfo = rlNumericSpec([9,1]);

Next, create the action specification for each agent, which is a set of discrete cell indices.

ainfo = rlFiniteSetSpec({1;2;3;4;5;6;7;8;9});

For a multi-agent environment, the observation and action specifications are cell arrays with one element per agent. Create the arrays.

obsInfo = {oinfo, oinfo};
actInfo = {ainfo, ainfo};

Create a turn-based environment using rlTurnBasedFunctionEnv. The function takes as input arguments the observation and action specifications of the agents and function handles for the step and reset operations.

Create the environment object. The functions stepGame and resetGame are provided at the end of this example.

env = rlTurnBasedFunctionEnv(obsInfo, actInfo, @stepGame, @resetGame)
env = 
  rlTurnBasedFunctionEnv with properties:

     StepFcn: @stepGame
    ResetFcn: @resetGame
        Info: [1x1 struct]

Add a listener for changes to the Info property of the environment. When the Info property is updated during simulation, the function hPlotGame (provided at the end of this example) plots the updated environment. Visualization is turned off here to improve training performance. To turn it on, set doPlot to true.

doPlot = false;
if doPlot
    addlistener(env,"Info","PostSet",@(~,~) hPlotGame(env));
end

Create Agents

Two adversarial agents play the game.

  • An agent generating random allowed actions controls the blue circles in the game. This agent is not trained.

  • A deep Q-network (DQN) agent controls the red squares in the game. You train this agent to win against the random policy.

Create a custom agent object using the AgentWithRandomActions class. For more information, see Create Custom Reinforcement Learning Agents.

randomAgent = AgentWithRandomActions(oinfo,ainfo);

View the implementation of AgentWithRandomActions.

type("AgentWithRandomActions.m")
classdef AgentWithRandomActions < rl.agent.CustomAgent
    % AgentWithRandomActions models a custom agent with random actions.

    % Copyright 2023 The MathWorks, Inc.

    properties (Access = private)
        SampleTime_
    end

    methods
        function this = AgentWithRandomActions(obsinfo, actinfo)
            setObservationInfo_(this, obsinfo);
            setActionInfo_(this,actinfo);
            this.SampleTime = 1;
        end
    end

    methods (Access = protected)
        function learnImpl(~,~)
            % no op because the agent does not learn
        end

        function action = getActionWithExplorationImpl(this, obs)
            action = getActionImpl(this, obs);
        end
        
        function action = getActionImpl(~, obs)
            % Generate random actions.
            x = obs{1};
            % legal moves are positions which are unmarked (value is 0)
            legalmoves = find(x==0);
            % choose a random legal move
            randidx = randperm(numel(legalmoves),1);
            action = legalmoves(randidx);
        end

        function ts = getSampleTime_(this)
            ts = this.SampleTime_;
        end

        function this = setSampleTime_(this,ts)
            this.SampleTime_ = ts;
        end
    end
end
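The getActionImpl method above selects uniformly among the unmarked cells. The following base MATLAB sketch reproduces that selection outside the class for a hypothetical observation and tallies many draws to confirm that only legal cells are chosen. The observation and the variable names are illustrative.

```matlab
obs = [1; -1; 0; 0; 1; -1; 0; 0; 0];   % hypothetical observation (5 empty cells)
legalMoves = find(obs == 0);           % unmarked cells: 3, 4, 7, 8, 9

% Draw many random legal moves, as getActionImpl does once per turn.
nDraws = 1000;
picks = zeros(nDraws, 1);
for k = 1:nDraws
    picks(k) = legalMoves(randperm(numel(legalMoves), 1));
end
counts = accumarray(picks, 1, [9 1]);  % how often each cell was chosen
```

Occupied cells (1, 2, 5, and 6) never appear in counts, and each of the five legal cells is chosen roughly 200 times.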

Next, specify options for training the DQN agent. The agent learns using the double DQN algorithm (UseDoubleDQN is set to true) with a mini-batch size of 64, a target smoothing factor of 0.01, and a discount factor of 0.99, which favors long-term rewards. For more information, see rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions( ...
    UseDoubleDQN=true, ...
    MiniBatchSize=64, ...
    TargetSmoothFactor=1e-2, ...
    DiscountFactor=0.99);
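For intuition about UseDoubleDQN and TargetSmoothFactor, the following sketch computes a double DQN learning target and a soft target update with made-up numbers. It follows the standard textbook formulation (the online critic selects the next action, the target critic evaluates it). It is not taken from the toolbox source, whose internal implementation may differ in detail, and all Q-values and parameter vectors are hypothetical.

```matlab
discount = 0.99;   % DiscountFactor
tau = 1e-2;        % TargetSmoothFactor

% Hypothetical Q-values for the next observation over three actions.
qOnline = [0.2 0.8 0.5];   % from the online (learning) critic
qTarget = [0.3 0.6 0.9];   % from the target critic
r = 1;                     % reward for the transition
isDone = false;            % next observation is not terminal

% Double DQN: the online critic picks the action, the target critic scores it.
[~, aStar] = max(qOnline);                       % aStar = 2
y = r + discount * (~isDone) * qTarget(aStar);   % 1 + 0.99*0.6 = 1.594

% Soft target update with smoothing factor tau (hypothetical parameters).
thetaOnline = [1 2];
thetaTarget = [0 0];
thetaTarget = tau * thetaOnline + (1 - tau) * thetaTarget;   % [0.01 0.02]
```

With plain DQN, the target would use max(qTarget) = 0.9 instead, giving 1 + 0.99*0.9 = 1.891. Decoupling action selection from action evaluation is what reduces the overestimation bias of the max operator.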

Specify options for the critic optimization algorithm. The critic learns using the Adam algorithm with a learning rate of 1e-3.

agentOpts.CriticOptimizerOptions.Algorithm = "adam";
agentOpts.CriticOptimizerOptions.LearnRate = 1e-3;

Specify exploration options for training. The agent uses an epsilon-greedy exploration strategy with an initial Epsilon value of 0.9, which decays exponentially at a rate of 1e-4 per training step until it reaches the minimum value of 0.01.

agentOpts.EpsilonGreedyExploration.Epsilon = 0.9;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4;
agentOpts.EpsilonGreedyExploration.EpsilonMin = 0.01;
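To estimate how long exploration lasts, the following sketch assumes the multiplicative update Epsilon = Epsilon*(1 - EpsilonDecay), applied once per training step while Epsilon is above EpsilonMin, as described in the rlDQNAgentOptions documentation. It counts the steps needed to go from 0.9 to 0.01 both in closed form and by iteration.

```matlab
epsilon0 = 0.9;     % Epsilon
decayRate = 1e-4;   % EpsilonDecay
epsilonMin = 0.01;  % EpsilonMin

% Closed form: smallest n with epsilon0*(1 - decayRate)^n <= epsilonMin.
nClosed = ceil(log(epsilonMin / epsilon0) / log(1 - decayRate));

% Same count by applying the assumed per-step update directly.
epsilon = epsilon0;
nLoop = 0;
while epsilon > epsilonMin
    epsilon = epsilon * (1 - decayRate);
    nLoop = nLoop + 1;
end
```

Under this assumption, Epsilon reaches its floor after about 45,000 steps. If Epsilon is updated only on the DQN agent's own steps, and that agent moves at most five times per game, this corresponds to roughly 9,000 or more episodes.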

Set the experience buffer length to 1e6 so that the agent can store, and sample from, a large set of experiences.

agentOpts.ExperienceBufferLength = 1e6;

Create initialization options for the agent. The agent's critic uses a default neural network with 64 units in each hidden layer.

initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);

Create a default DQN agent object. When you create the agent, the parameters of the critic network are initialized with random values. Fix the random number stream so that the agent is always initialized with the same parameter values. For more information, see rlDQNAgent.

rng(0,"twister");
agent = rlDQNAgent(oinfo, ainfo, initOpts, agentOpts);

View the critic network of the agent.

critic = getCritic(agent);
criticNet = getModel(critic);
plot(criticNet);

Figure: layer graph of the critic network.

View a summary of the critic network.

summary(criticNet);
   Initialized: true

   Number of learnables: 5.3k

   Inputs:
      1   'input_1'   9 features

Training

In this example, only the agent controlling the red squares is trained.

To evaluate the simulation behavior of the agent during training, configure an evaluator object that periodically computes an evaluation score. In this example, the evaluator object runs 25 evaluation simulations every 100 training episodes with different random seeds (0 to 24) and computes the minimum episode reward as the evaluation score. The evaluation score can be used as a criterion to save agents during training. Although the evaluation simulations use different random seeds, they do not affect the random stream of the training process.

For more information on evaluator objects, see rlEvaluator.

evaluator = rlEvaluator( ...
    EvaluationFrequency=100, ...
    EvaluationStatisticType="MinEpisodeReward", ...
    NumEpisodes=25, ...
    RandomSeeds=0:24);
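The MinEpisodeReward statistic takes the smallest total reward across the evaluation episodes, so a single poor episode determines the score. The following sketch applies that statistic, together with the save threshold of 9 used later in this example, to hypothetical evaluation rewards. The reward values and variable names are illustrative.

```matlab
% Hypothetical rewards from 25 evaluation episodes (one per seed 0 to 24).
evalRewards = 10 + mod(0:24, 3);          % values 10, 11, 12, 10, 11, ...
score = min(evalRewards);                 % MinEpisodeReward statistic: 10
saveAgent = score >= 9;                   % true: every episode scored at least 9

% A single poor episode (hypothetical value) drags the minimum down.
evalRewards(7) = -20;
scoreWithLoss = min(evalRewards);         % -20
saveAgentWithLoss = scoreWithLoss >= 9;   % false
```

Using the minimum rather than the mean is a conservative choice: an agent is saved only if it performs well against every evaluation seed.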

For this training session:

  • Run the training for 5000 episodes with no early stopping (set StopTrainingCriteria to "none").

  • Specify AgentGroups as a cell array of agent indices, omitting the index of the agent that is not trained. The indices must correspond to the order of the agents passed to the train function.

  • Compute the average episodic rewards using a moving window of size 20.

  • Save the agents from the episodes where the evaluation score is 9 or higher. This score is close to the reward received (10) when the game is won and may indicate that the policy has learned sufficiently.

  • Turn off the training progress plot (set Plots to "none") and display the training progress in the command window (set Verbose to true).

  • View the training progress plot after the training is finished using the show command.

trainOpts = rlMultiAgentTrainingOptions( ...
    AgentGroups={1}, ...
    MaxEpisodes=5000, ...
    ScoreAveragingWindowLength=20, ...
    StopTrainingCriteria="none", ...
    SaveAgentCriteria="EvaluationStatistic", ...
    SaveAgentValue=9, ...
    Plots="none",...
    Verbose=true);

For more information see rlMultiAgentTrainingOptions.
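The ScoreAveragingWindowLength option smooths the per-episode rewards reported during training. Assuming a trailing window over the most recent 20 episodes (shorter at the start of training), the following sketch computes the same kind of average with movmean for hypothetical rewards.

```matlab
episodeRewards = 1:40;   % hypothetical per-episode rewards
windowLength = 20;       % ScoreAveragingWindowLength

% Trailing window: the current episode and the previous windowLength-1 episodes.
% Near the start, movmean shrinks the window to the episodes available.
avgReward = movmean(episodeRewards, [windowLength-1 0]);
% avgReward(10) is mean(1:10) = 5.5, avgReward(40) is mean(21:40) = 30.5
```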

Fix the random stream for reproducibility.

rng(0,"twister");

Train the agents using the train function. Training can take several hours to complete depending on the available computational power. To save time, load the MAT-file twoPlayerGameAgent.mat, which contains a set of pretrained agents. To train the agents yourself, set doTraining to true.

doTraining = false;
if doTraining
    % train
    trainResults = train([agent,randomAgent], env, ...
        trainOpts, Evaluator=evaluator);
    % show the training progress visualization
    show(trainResults);
else
    load("twoPlayerGameAgent.mat");
end

A snapshot of the training progress is shown in the following figure. You may see different results from your training process.

Simulation

Fix the random stream for reproducibility.

rng(0,"twister");

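Add a listener so that the function hPlotGame visualizes the environment during simulation. If you set doPlot to true earlier, the listener already exists and you can skip this step.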

addlistener(env,"Info","PostSet",@(~,~) hPlotGame(env));

Simulate the trained agent with the environment. For more information on agent simulation, see rlSimulationOptions and sim.

simOptions = rlSimulationOptions();
experience = sim(env,[agent,randomAgent],simOptions);

Figure: Simple Two Player Game, with one panel per turn (Turn 1 through Turn 9) showing the marked cells after that turn. The panels for Turns 8 and 9 contain no marks.

The visualization shows the board after each turn. In this simulation, the trained agent (red squares) wins the game against the random agent.

Restore the random number stream using the information stored in previousRngState.

rng(previousRngState);

Local Functions

Custom reset function for the two-player game. The reset function initializes the environment, assigns the first turn to agent 1 (the DQN agent), and returns the initial observations of the agents along with the info variable that passes information between simulation steps.

function [initialObs, info] = resetGame()
% Reset the two-player game.

% The state is a 3x3 matrix with 0s in unmarked cells, 
% -1s in cells marked with squares and 
% 1s in cells marked with circles.
info.State = zeros(3,3);

% Agent 1 (the DQN agent) always takes the first turn.
% To choose the first agent randomly, use randperm(2,1) instead.
info.ActiveAgentIndex = 1;

% The current environment step count.
info.StepCount = 0;

% Flag to keep track of invalid action.
info.IsInvalidAction = false;

% Get the active agent for the current turn.
selfIdx     = info.ActiveAgentIndex;
opponentIdx = mod(selfIdx,2) + 1;
playerVals  = [-1,1];
selfVal     = playerVals(selfIdx);
opponentVal = playerVals(opponentIdx);

% Initial observation is the state of the environment 
% as seen by each agent:
% self cells 1, opponent cells -1, unmarked 0
selfObs = ...
    double(info.State==selfVal) - double(info.State==opponentVal);
opponentObs = ...
    -double(info.State==selfVal) + double(info.State==opponentVal);

initialObs{selfIdx} = selfObs(:);
initialObs{opponentIdx} = opponentObs(:);
end
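The following base MATLAB sketch isolates two conventions that resetGame and stepGame rely on: mod(idx,2) + 1 alternates the active agent between 1 and 2, and the two agents' observations of the same board are negatives of each other. The board and the nextIdx function are hypothetical.

```matlab
% Turn alternation: 1 -> 2 -> 1 -> ...
nextIdx = @(idx) mod(idx, 2) + 1;
order = zeros(1, 5);
idx = 1;                       % agent 1 moves first, as in resetGame
for t = 1:5
    order(t) = idx;
    idx = nextIdx(idx);
end
% order is [1 2 1 2 1]

% The two agents' observations of the same board are negatives of each other.
state = [-1 0 1; 0 1 0; 0 0 -1];
playerVals = [-1, 1];
obs1 = double(state == playerVals(1)) - double(state == playerVals(2));
obs2 = double(state == playerVals(2)) - double(state == playerVals(1));
```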

Custom step function for the two-player game. This function steps the environment dynamics to the next state.

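function [nextobs, reward, isdone, info] = stepGame(action, info)
% Step the two-player game: apply the active agent's move,
% compute the rewards, and pass the turn to the other agent.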

% Get the active agent for the current turn.
selfIdx = info.ActiveAgentIndex;

% Compute the next agent index.
opponentIdx = mod(selfIdx,2) + 1;

% Player cell values:
% Player 1 (red square) is identified by the value -1.
% Player 2 (blue circle) is identified by the value 1.
playerVals = [-1,1];

% Get current player value.
selfVal  = playerVals(selfIdx);
opponentVal = playerVals(opponentIdx);

% Index of cell for the current player.
% action{1} is the move of the current player, 
% that is, the index (between 1 and 9) 
% of new grid location to be marked.
playerCell = action{1};

% Advance environment to the next state.
state = info.State;
if state(playerCell) == 0
    % Move to next state.
    state(playerCell) = selfVal;
    % Compute reward and terminal condition.
    [rwd, isdone, iswin] = hComputeRewardAndIsDone( ...
        state, playerCell, selfVal, opponentVal);
else
    % For an illegal action,
    % terminate with a large penalty.
    rwd = -20;
    isdone = true;
    iswin = false;
    info.IsInvalidAction = true;
end

% The active agent receives the reward rwd, 
% and the other agent receives the reward 0.
reward = [0,0];
reward(selfIdx) = rwd;

% Optional: To make this a zero sum game, 
% uncomment the following line.
% reward(opponentIdx)  = -rwd;

% Next observation.
% self cells 1
% opponent cells -1
% unmarked 0
selfObs = ...
    double(state==selfVal) - double(state==opponentVal);
opponentObs = ...
    -double(state==selfVal) + double(state==opponentVal);

nextobs{selfIdx} = selfObs(:);
nextobs{opponentIdx} = opponentObs(:);

% Update the info structure with:
%   1. The state for the next step.
%   2. The agent turn for the next step.
%   3. The step count.
%   4. Whether the active agent won in this step.
info.State = state;
info.ActiveAgentIndex = opponentIdx;
info.StepCount = info.StepCount + 1;
info.IsWin = iswin;
end
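To see the turn logic of stepGame end to end without the toolbox, the following sketch plays one game between two random players, using the same cell values, turn alternation, and terminal conditions (a completed line or a full board). It omits rewards and illegal moves, and the lines table of winning index triples is a hypothetical helper.

```matlab
% Cell values: player 1 (red square) = -1, player 2 (blue circle) = 1.
playerVals = [-1, 1];
% Hypothetical helper: column-major indices of the 8 winning lines.
lines = [1 4 7; 2 5 8; 3 6 9; 1 2 3; 4 5 6; 7 8 9; 1 5 9; 3 5 7];

state = zeros(3, 3);
activeIdx = 1;        % agent 1 moves first, as in resetGame
stepCount = 0;
winner = 0;           % 0 means no winner (draw)
isdone = false;

while ~isdone
    selfVal = playerVals(activeIdx);
    % Random legal move: any unmarked cell.
    legalMoves = find(state == 0);
    cellIdx = legalMoves(randperm(numel(legalMoves), 1));
    state(cellIdx) = selfVal;
    stepCount = stepCount + 1;
    % Terminal conditions: a completed line or a full board.
    if any(all(state(lines) == selfVal, 2))
        isdone = true;
        winner = activeIdx;
    elseif ~any(state(:) == 0)
        isdone = true;
    end
    % Pass the turn to the other agent.
    activeIdx = mod(activeIdx, 2) + 1;
end
```

A game always lasts between 5 turns (the earliest possible win for player 1) and 9 turns (a full board).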

Helper function to compute reward and terminal conditions.

function [reward, isover, iswin] = hComputeRewardAndIsDone( ...
    state, index, player, opponent)
% Compute the reward and terminal condition after the active agent
% has marked the cell at linear index `index` in `state`.
%
% Reward for active agent:
%
% +10 if the agent wins the game in the current turn.
% 
% +1 if the agent has marked two consecutive grid cells 
% with the third cell in line unmarked.
% 
% +10 if the agent marks a cell that prevents
% the opponent from winning in the next turn.
% 
% -20 if the agent fails to mark a cell 
% that enables the opponent to win in the next turn.
% This penalty replaces any other reward for the turn.
% 
% 0 otherwise.
%
% Illegal moves (marking an occupied cell) are handled in stepGame,
% which assigns a reward of -20 and terminates the episode.

sz = size(state);
reward = 0;
isover = false;
iswin = false;

% Current row and column of active agent.
[r,c] = ind2sub(sz, index);

    function nestedComputeReward_(arr)
        % arr is a row or column or a diagonal.
        if all(arr==player)
            % Row/column/diagonal complete, game over.
            reward = reward + 10;
            isover = true;
            iswin = true;
        elseif nnz(arr==player)==2 && any(arr==0)
            % Player marks adjacent cells,
            % possible win next turn.
            reward = reward + 1;
        elseif nnz(arr==opponent)==2
            % Player blocks opponent win.
            reward = reward + 10;
        end
    end

% Horizontal cells.
hcells = state(r,:);
nestedComputeReward_(hcells);

% Vertical cells.
vcells = state(:,c);
nestedComputeReward_(vcells);

% Check only the diagonals that pass through the marked cell:
% r==c is the main diagonal,
% r+c==4 is the other diagonal.
% The center cell (2,2) lies on both.
if r==c
    % Main diagonal cells.
    mdcells = [state(1,1) state(2,2) state(3,3)];
    nestedComputeReward_(mdcells);
end
if r+c==4
    % Other diagonal cells.
    odcells = [state(3,1) state(2,2) state(1,3)];
    nestedComputeReward_(odcells);
end

% If all cells are marked the game is over.
if ~any(state==0)
    isover = true;
end

% penalty if opponent is winning next turn
if ~isover
    ishcellWin = any(cellfun(@(x) nnz(x==opponent)==2 && any(x==0),...
        {state(1,:),state(2,:),state(3,:)}));
    isvcellWin = any(cellfun(@(x) nnz(x==opponent)==2 && any(x==0),...
        {state(:,1),state(:,2),state(:,3)}));
    isdiag1Win = any(cellfun(@(x) nnz(x==opponent)==2 && any(x==0),...
        {[state(1,1),state(2,2),state(3,3)]}));
    isdiag2Win = any(cellfun(@(x) nnz(x==opponent)==2 && any(x==0),...
        {[state(3,1),state(2,2),state(1,3)]}));
    if ishcellWin || isvcellWin || isdiag1Win || isdiag2Win
        reward = -20;
    end
end
end
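The following sketch restates the per-line reward rules of hComputeRewardAndIsDone (win, set-up, block) for a hypothetical board in which the agent has just marked cell (1,1). It scores only the lines that pass through the marked cell and omits the opponent-threat penalty. The lines table and the lineReward function are hypothetical helpers.

```matlab
% Hypothetical helper: column-major indices of the 8 winning lines.
lines = [1 4 7; 2 5 8; 3 6 9; 1 2 3; 4 5 6; 7 8 9; 1 5 9; 3 5 7];
player = -1;
opponent = 1;

% Board after the agent (value -1) marks cell 1, that is, (1,1).
state = [-1  1  1;
         -1  0  0;
          0  0  0];
index = 1;

% Only lines that contain the marked cell are scored.
throughCell = lines(any(lines == index, 2), :);   % row 1, column 1, main diagonal
vals = state(throughCell);

% Per-line reward. The three cases are mutually exclusive because each
% scored line already contains the agent's new mark.
lineReward = @(v) 10 * all(v == player) ...
    + 1 * (nnz(v == player) == 2 && any(v == 0)) ...
    + 10 * (nnz(v == opponent) == 2);
total = 0;
for k = 1:size(vals, 1)
    total = total + lineReward(vals(k, :));
end
% Row 1 [-1 1 1] blocks (+10), column 1 [-1 -1 0] sets up (+1),
% and the main diagonal [-1 0 0] scores 0, so total is 11.
```

Selecting lines by membership means the center cell is scored on four lines and a corner on three.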

Helper function for visualizing the environment.

function hPlotGame(env)
% Plot the Simple Two Player Game environment.

persistent f ax hrect
state  = env.Info.State;
stepCt = env.Info.StepCount;
isdone = env.Info.IsInvalidAction;
if isempty(f) || ~isvalid(f)
    f = figure(...
        Toolbar="auto", ...
        NumberTitle="off", ...
        Name="Simple Two Player Game",...
        MenuBar="none", ...
        Visible="on");
    f.Position(3:4) = [1000,250];

    t = tiledlayout( ...
        f, ...
        "horizontal", ...
        TileSpacing="compact", ...
        Padding="compact");

    for ct = 1:9
        hax = nexttile(t);
        axis(hax,"equal");
        title(hax,"Turn "+ct);
        hax.XLim = [-15 15];
        hax.YLim = [-15 15];
        hax.XTick = [];
        hax.YTick = [];
        hax.Box = "on";
        hold(hax,"on");
        xline(hax,-5);
        xline(hax,5);
        yline(hax,-5);
        yline(hax,5);
        ax{ct} = hax;
    end
end
if stepCt<=1
    delete(hrect);
    hrect = [];
end
if ~isdone
    for ct = 1:9
        if ismember(ct,[1,2,3])
            x = -10;
        elseif ismember(ct,[4,5,6])
            x = 0;
        else
            x = 10;
        end
        if ismember(ct,[1,4,7])
            y = 10;
        elseif ismember(ct,[2,5,8])
            y = 0;
        else
            y = -10;
        end
        if state(ct)==-1
            hr = rectangle( ...
                ax{stepCt}, ...
                Position=[x-2.5,y-2.5,5,5], ...
                FaceColor="r", ...
                EdgeColor="none");
            hrect = [hrect; hr];
        elseif state(ct)==1
            hr = rectangle( ...
                ax{stepCt}, ...
                Position=[x-2.5,y-2.5,5,5], ...
                Curvature=[1 1], ...
                FaceColor="b", ...
                EdgeColor="none");
            hrect = [hrect; hr];
        end
    end
end

drawnow();
end
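The ismember branches in hPlotGame map each linear cell index to plot coordinates, with columns along x and rows along y. The following sketch shows an equivalent, more compact mapping that uses ind2sub; the variable names are illustrative.

```matlab
% Cell centers used by hPlotGame: columns map to x, rows map to y.
[r, c] = ind2sub([3 3], 1:9);
x = 10 * (c - 2);    % column 1, 2, 3 -> x = -10, 0, 10
y = -10 * (r - 2);   % row 1, 2, 3    -> y = 10, 0, -10
```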
