rlNeuralNetworkEnvironment
Description
Use an rlNeuralNetworkEnvironment object to create a reinforcement learning environment that computes state transitions using deep neural networks.
Using an rlNeuralNetworkEnvironment object, you can:
Create an internal environment model for a model-based policy optimization (MBPO) agent.
Create an environment for training other types of reinforcement learning agents. You can identify the state-transition network using experimental or simulated data.
Such environments can compute environment rewards and termination conditions using deep neural networks or custom functions.
Creation
Syntax
env = rlNeuralNetworkEnvironment(ObservationInfo,ActionInfo,transitionFcn,rewardFcn,isDoneFcn)
Description
env = rlNeuralNetworkEnvironment(ObservationInfo,ActionInfo,transitionFcn,rewardFcn,isDoneFcn) creates a model for an environment with the observation and action specifications specified in ObservationInfo and ActionInfo, respectively. This syntax sets the TransitionFcn, RewardFcn, and IsDoneFcn properties.
Input Arguments
ObservationInfo
— Observation specifications
rlNumericSpec object | array of rlNumericSpec objects
This property is read-only.
Observation specifications, specified as an rlNumericSpec object or an array of such objects. Each element in the array defines the properties of an environment observation channel, such as its dimensions, data type, and name.
You can extract the observation specifications from an existing environment or agent using getObservationInfo. You can also construct the specifications manually using rlNumericSpec.
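As a sketch of the two options above, the following builds a single observation channel manually and, for comparison, extracts the specification from the predefined cart-pole environment used in the examples on this page. The 4-by-1 dimension and the channel name "CartPole States" are illustrative assumptions chosen to match that environment.

```matlab
% Manually define one observation channel with four elements
% (size assumed to match the cart-pole states).
obsInfo = rlNumericSpec([4 1]);
obsInfo.Name = "CartPole States";

% Alternatively, extract the specification from an existing environment.
refEnv = rlPredefinedEnv("CartPole-Continuous");
obsInfoFromEnv = getObservationInfo(refEnv);

% Both specifications describe a 4-by-1 observation channel.
disp(obsInfo.Dimension)
disp(obsInfoFromEnv.Dimension)
```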
ActionInfo
— Action specifications
rlFiniteSetSpec object | rlNumericSpec object
Action specifications, specified as an rlFiniteSetSpec or rlNumericSpec object. This object defines the properties of the environment action channel, such as its dimensions, data type, and name.
Note
For the neural network environment, only one action channel is allowed.
You can extract the action specifications from an existing environment or agent using getActionInfo. You can also construct the specification manually using rlFiniteSetSpec or rlNumericSpec.
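A minimal sketch of constructing the single action channel manually, once as a continuous specification and once as a discrete one. The force limits of -10 and 10 and the name "force" are illustrative assumptions, not values taken from this page.

```matlab
% Continuous action: one scalar value bounded to [-10, 10] (assumed limits).
actInfoContinuous = rlNumericSpec([1 1],LowerLimit=-10,UpperLimit=10);
actInfoContinuous.Name = "force";

% Discrete action: the agent selects one of two values.
actInfoDiscrete = rlFiniteSetSpec([-10 10]);

% The neural network environment allows only one action channel, so pass a
% single specification object (not an array of them) as ActionInfo.
disp(isscalar(actInfoContinuous))
```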
Properties
TransitionFcn
— Environment transition function
rlContinuousDeterministicTransitionFunction object | rlContinuousGaussianTransitionFunction object | array of transition objects
Environment transition function, specified as one of the following:
rlContinuousDeterministicTransitionFunction object: Use this option when you expect the environment transitions to be deterministic.
rlContinuousGaussianTransitionFunction object: Use this option when you expect the environment transitions to be stochastic.
Vector of transition objects: Use multiple transition models for an MBPO agent.
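The following is a sketch, not the only way, of building a vector of deterministic transition models for an MBPO agent. It assembles one uninitialized two-input network, initializes a separate copy for each model so that the models start from different random weights, and concatenates the resulting objects. The ensemble size of three and the network width are illustrative assumptions.

```matlab
% Cart-pole specifications, as in the examples on this page.
refEnv = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(refEnv);
actInfo = getActionInfo(refEnv);
nObs = obsInfo.Dimension(1);
nAct = actInfo.Dimension(1);

% Assemble a two-input transition network without initializing it.
commonPath = [
    concatenationLayer(1,2,Name="concat")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(nObs,Name="nextObservation")];
baseNet = dlnetwork();
baseNet = addLayers(baseNet,featureInputLayer(nObs,Name="state"));
baseNet = addLayers(baseNet,featureInputLayer(nAct,Name="action"));
baseNet = addLayers(baseNet,commonPath);
baseNet = connectLayers(baseNet,"state","concat/in1");
baseNet = connectLayers(baseNet,"action","concat/in2");

% Initialize one copy per model so each starts from different random
% weights, then concatenate the transition objects into a row vector.
numModels = 3;   % assumed ensemble size
models = cell(1,numModels);
for k = 1:numModels
    net = initialize(baseNet);
    models{k} = rlContinuousDeterministicTransitionFunction( ...
        net,obsInfo,actInfo, ...
        ObservationInputNames="state", ...
        ActionInputNames="action", ...
        NextObservationOutputNames="nextObservation");
end
transitionFcn = [models{:}];
```

You can then pass transitionFcn as the transition argument of rlNeuralNetworkEnvironment, together with a reward function and an is-done function.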
RewardFcn
— Environment reward function
rlContinuousDeterministicRewardFunction object | rlContinuousGaussianRewardFunction object | function handle
Environment reward function, specified as one of the following:
rlContinuousDeterministicRewardFunction object: Use this option when you do not know a ground-truth reward signal for your environment and you expect the reward signal to be deterministic.
rlContinuousGaussianRewardFunction object: Use this option when you do not know a ground-truth reward signal for your environment and you expect the reward signal to be stochastic.
Function handle: Use this option when you know a ground-truth reward signal for your environment. When you use an rlNeuralNetworkEnvironment object to create an rlMBPOAgent object, the custom reward function must return a batch of rewards given a batch of inputs.
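To illustrate the batching requirement, here is a minimal sketch of a vectorized reward function handle. It assumes, as the cart-pole functions in the examples do, that each cell holds one matrix per channel with one column per batch sample. The reward itself (the negative absolute cart position in the first row of the next observation) is a hypothetical choice for illustration.

```matlab
% Hypothetical ground-truth reward: penalize the cart position stored in
% the first row of the next observation. Indexing with (1,:) keeps one
% reward per column, so a batch of N samples yields a 1-by-N row vector.
rewardFcn = @(obs,action,nextObs) -abs(nextObs{1}(1,:));

% Evaluate a batch of three samples (4 observations x 3 samples).
nextObsBatch = {[0.5 -1.0 0; zeros(3,3)]};
r = rewardFcn({zeros(4,3)},{zeros(1,3)},nextObsBatch);
disp(r)   % returns [-0.5 -1 0]
```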
IsDoneFcn
— Environment is-done function
rlIsDoneFunction object | function handle
Environment is-done function, specified as one of the following:
rlIsDoneFunction object: Use this option when you do not know a ground-truth termination signal for your environment.
Function handle: Use this option when you know a ground-truth termination signal for your environment. When you use an rlNeuralNetworkEnvironment object to create an rlMBPOAgent object, the custom is-done function must return a batch of termination signals given a batch of inputs.
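A minimal sketch of a vectorized is-done function handle, under the same one-column-per-sample assumption. The termination rule (cart position outside [-2.4, 2.4]) mirrors the position check of the cart-pole example but omits the angle check, so treat it as a simplified, hypothetical condition.

```matlab
% Hypothetical ground-truth termination signal: the episode ends when the
% cart position (first row of the next observation) leaves [-2.4, 2.4].
% The comparison is elementwise, giving one logical value per column.
isDoneFcn = @(obs,action,nextObs) abs(nextObs{1}(1,:)) > 2.4;

% Evaluate a batch of three samples.
nextObsBatch = {[0 3 -2.5; zeros(3,3)]};
d = isDoneFcn({zeros(4,3)},{zeros(1,3)},nextObsBatch);
disp(d)   % returns [false true true]
```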
Observation
— Observation values
cell array
Observation values, specified as a cell array with length equal to the number of specification objects in ObservationInfo. The order of the observations in Observation must match the order in ObservationInfo. Also, the dimensions of each element of the cell array must match the dimensions of the corresponding observation specification in ObservationInfo.
To evaluate whether the transition models are well trained, you can manually evaluate the environment for a given observation value using the step function. Specify the observation values before calling step.
When you use this neural network environment object within an MBPO agent, this property is ignored.
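The following sketch checks an observation value against its specifications before it is assigned to the Observation property. The single 4-by-1 channel and the numeric values are illustrative assumptions based on the cart-pole examples.

```matlab
% Observation specification for one 4-by-1 channel (assumed sizes).
obsInfo = rlNumericSpec([4 1]);

% Observation values: one cell per specification, in the same order.
obs = {[0.1; 0; 0.05; 0]};

% Check the number of channels and the dimensions of each channel.
assert(numel(obs) == numel(obsInfo))
for i = 1:numel(obsInfo)
    assert(isequal(size(obs{i}),obsInfo(i).Dimension))
end
% With the checks passing, assign env.Observation = obs before calling step.
```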
TransitionModelNum
— Transition model index
1 (default) | positive integer
Transition model index, specified as a positive integer.
To evaluate whether the transition models are well trained, you can manually evaluate the environment for a given observation value using the step function. To select which transition model in TransitionFcn to evaluate, specify the transition model index before calling step.
When you use this neural network environment object within an MBPO agent, this property is ignored.
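As a sketch of this evaluation workflow, the loop below starts every transition model from the same observation, applies the same action, and compares the predicted next observations. It assumes env is an rlNeuralNetworkEnvironment with trained transition models for a single 4-by-1 observation channel. Passing the action as a one-element cell array is an assumption; the iscell check on the output mirrors the cart-pole custom functions and guards against either output format.

```matlab
% Assumes env exists and has trained transition models (cart-pole sizes).
obs = {[0.1; 0; 0.05; 0]};   % arbitrary illustrative observation
act = {0.5};                 % action; cell-array format is an assumption

numModels = numel(env.TransitionFcn);
predictions = zeros(4,numModels);
for k = 1:numModels
    env.TransitionModelNum = k;   % select the k-th transition model
    env.Observation = obs;        % start each model from the same state
    nextObs = step(env,act);
    if iscell(nextObs)
        nextObs = nextObs{1};
    end
    predictions(:,k) = nextObs;
end

% A large spread across models means the models disagree about this
% transition, which can indicate too little training data nearby.
spread = std(predictions,0,2);
disp(spread)
```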
Object Functions
rlMBPOAgent | Model-based policy optimization (MBPO) reinforcement learning agent
Examples
Create Neural Network Environment
Create an environment interface and extract observation and action specifications. Alternatively, you can create specifications using rlNumericSpec and rlFiniteSetSpec.
env = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
Get the dimension of the observation and action spaces.
nObs = obsInfo.Dimension(1); nAct = actInfo.Dimension(1);
Create a deterministic transition function based on a deep neural network with two input channels (current observations and actions) and one output channel (predicted next observation).
% Create network layers.
statePath = featureInputLayer(nObs, ...
    Normalization="none",Name="state");
actionPath = featureInputLayer(nAct, ...
    Normalization="none",Name="action");
commonPath = [concatenationLayer(1,2,Name="concat")
    fullyConnectedLayer(64,Name="FC1")
    reluLayer(Name="CriticRelu1")
    fullyConnectedLayer(64, Name="FC3")
    reluLayer(Name="CriticCommonRelu2")
    fullyConnectedLayer(nObs,Name="nextObservation")];

% Create dlnetwork object and add layers
transitionNetwork = dlnetwork();
transitionNetwork = addLayers(transitionNetwork,statePath);
transitionNetwork = addLayers(transitionNetwork,actionPath);
transitionNetwork = addLayers(transitionNetwork,commonPath);

% Connect layers
transitionNetwork = connectLayers( ...
    transitionNetwork,"state","concat/in1");
transitionNetwork = connectLayers( ...
    transitionNetwork,"action","concat/in2");

% Plot network
plot(transitionNetwork)
% Initialize dlnetwork object.
transitionNetwork = initialize(transitionNetwork);

% Create transition function object.
transitionFcn = rlContinuousDeterministicTransitionFunction(...
    transitionNetwork,obsInfo,actInfo,...
    ObservationInputNames="state", ...
    ActionInputNames="action", ...
    NextObservationOutputNames="nextObservation");
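Before training, you can sanity-check the sizes of the transition function by evaluating it once. This sketch uses the evaluate function of the approximator object with a random observation and action; the ordering of the input cell array (observation channels first, then the action channel) is an assumption. Because the network is untrained, only the output size is meaningful.

```matlab
% Random observation and action with the sizes from the specifications.
obs = {rand(obsInfo.Dimension)};
act = {rand(actInfo.Dimension)};

% Evaluate the untrained transition function (input order assumed to be
% observation channels followed by the action channel).
predNextObs = evaluate(transitionFcn,[obs act]);
disp(size(predNextObs{1}))
```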
Create a stochastic (Gaussian) reward function approximator object with three input channels (current observations, current actions, and next observations) and two output channels (predicted reward mean and standard deviation).
% Create network layers.
nextStatePath = featureInputLayer( ...
    nObs,Name="nextState");
commonPath = [concatenationLayer(1,3,Name="concat")
    fullyConnectedLayer(32,Name="fc")
    reluLayer(Name="relu1")
    fullyConnectedLayer(32,Name="fc2")];
meanPath = [reluLayer(Name="rewardMeanRelu")
    fullyConnectedLayer(1,Name="rewardMean")];
stdPath = [reluLayer(Name="rewardStdRelu")
    fullyConnectedLayer(1,Name="rewardStdFc")
    softplusLayer(Name="rewardStd")];

% Assemble dlnetwork object. The state and action input layers
% (statePath and actionPath) are reused from the transition network.
rewardNetwork = dlnetwork();
rewardNetwork = addLayers(rewardNetwork,statePath);
rewardNetwork = addLayers(rewardNetwork,actionPath);
rewardNetwork = addLayers(rewardNetwork,nextStatePath);
rewardNetwork = addLayers(rewardNetwork,commonPath);
rewardNetwork = addLayers(rewardNetwork,meanPath);
rewardNetwork = addLayers(rewardNetwork,stdPath);

% Connect layers
rewardNetwork = connectLayers( ...
    rewardNetwork,"nextState","concat/in1");
rewardNetwork = connectLayers( ...
    rewardNetwork,"action","concat/in2");
rewardNetwork = connectLayers( ...
    rewardNetwork,"state","concat/in3");
rewardNetwork = connectLayers( ...
    rewardNetwork,"fc2","rewardMeanRelu");
rewardNetwork = connectLayers( ...
    rewardNetwork,"fc2","rewardStdRelu");

% Plot network
plot(rewardNetwork)
% Initialize dlnetwork object and display the number of parameters
rewardNetwork = initialize(rewardNetwork);
summary(rewardNetwork)
   Initialized: true

   Number of learnables: 1.4k

   Inputs:
      1   'state'       4 features
      2   'action'      1 features
      3   'nextState'   4 features
% Create reward function object.
rewardFcn = rlContinuousGaussianRewardFunction(...
    rewardNetwork,obsInfo,actInfo,...
    ObservationInputNames="state",...
    ActionInputNames="action", ...
    NextObservationInputNames="nextState", ...
    RewardMeanOutputNames="rewardMean", ...
    RewardStandardDeviationOutputNames="rewardStd");
Create an is-done function approximator object with one input channel (next observations) and one output channel (predicted termination signal).
% Create network layers.
isDoneNetwork = [
    featureInputLayer(nObs,Name="nextState");
    fullyConnectedLayer(64,Name="FC1")
    reluLayer(Name="CriticRelu1")
    fullyConnectedLayer(64,Name="FC3")
    reluLayer(Name="CriticCommonRelu2")
    fullyConnectedLayer(2,Name="isdone0")
    softmaxLayer(Name="isdone")
    ];

% Create dlnetwork object.
isDoneNetwork = dlnetwork(isDoneNetwork);

% Initialize network.
isDoneNetwork = initialize(isDoneNetwork);

% Create is-done function approximator object.
isDoneFcn = rlIsDoneFunction(isDoneNetwork, ...
    obsInfo,actInfo, ...
    NextObservationInputNames="nextState");
Create a neural network environment using the transition, reward, and is-done function approximator objects.
env = rlNeuralNetworkEnvironment( ...
    obsInfo,actInfo, ...
    transitionFcn,rewardFcn,isDoneFcn);
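A typical next step is to use this environment as the internal model of an MBPO agent. The sketch below assumes obsInfo, actInfo, and env from the steps above, and uses a default SAC agent as the base agent purely as an illustrative choice.

```matlab
% Default SAC agent as the base agent (illustrative choice).
baseAgent = rlSACAgent(obsInfo,actInfo);

% MBPO agent that uses env as its internal environment model.
mbpoAgent = rlMBPOAgent(baseAgent,env);
```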
Create Neural Network Environment Using Custom Functions
Create an environment interface and extract observation and action specifications. Alternatively, you can create specifications using rlNumericSpec and rlFiniteSetSpec.
env = rlPredefinedEnv("CartPole-Continuous");
obsInfo = getObservationInfo(env);
numObservations = obsInfo.Dimension(1);
actInfo = getActionInfo(env);
numActions = actInfo.Dimension(1);
Create a deterministic transition function approximator based on a deep neural network with two input channels (current observations and actions) and one output channel (predicted next observation).
% Create network layers.
statePath = featureInputLayer(numObservations, Name="obsInLyr");
actionPath = featureInputLayer(numActions, Name="actInLyr");
commonPath = [
    concatenationLayer(1,2,Name="concat")
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(numObservations,Name="nextObsOutLyr")];

% Assemble dlnetwork object and connect layers.
trnsNet = dlnetwork();
trnsNet = addLayers(trnsNet,statePath);
trnsNet = addLayers(trnsNet,actionPath);
trnsNet = addLayers(trnsNet,commonPath);
trnsNet = connectLayers(trnsNet,"obsInLyr","concat/in1");
trnsNet = connectLayers(trnsNet,"actInLyr","concat/in2");

% Plot network.
plot(trnsNet)
% Initialize network and display the number of weights.
trnsNet = initialize(trnsNet);
summary(trnsNet)
   Initialized: true

   Number of learnables: 4.8k

   Inputs:
      1   'obsInLyr'   4 features
      2   'actInLyr'   1 features
% Create transition function approximator object.
transitionFcn = rlContinuousDeterministicTransitionFunction(...
    trnsNet,obsInfo,actInfo,...
    ObservationInputNames="obsInLyr", ...
    ActionInputNames="actInLyr", ...
    NextObservationOutputNames="nextObsOutLyr");
If you know the reward signal for your environment, you can define it using a custom function. Your custom reward function must take the observations, actions, and next observations as cell-array inputs and return the reward value. For batched inputs, where each column is one sample, the function returns one reward per column. For this example, use the following custom reward function, which computes the reward based on the next observation.
type cartPoleRewardFunction.m
function reward = cartPoleRewardFunction(obs,action,nextObs)
% Compute reward value based on the next observation.

    if iscell(nextObs)
        nextObs = nextObs{1};
    end

    % Distance at which to fail the episode
    xThreshold = 2.4;

    % Reward each time step the cart-pole is balanced
    rewardForNotFalling = 1;

    % Penalty when the cart-pole fails to balance
    penaltyForFalling = -50;

    x = nextObs(1,:);
    distReward = 1 - abs(x)/xThreshold;

    isDone = cartPoleIsDoneFunction(obs,action,nextObs);

    reward = zeros(size(isDone));
    reward(logical(isDone)) = penaltyForFalling;
    reward(~logical(isDone)) = ...
        0.5 * rewardForNotFalling + 0.5 * distReward(~logical(isDone));

end
If you know the termination condition for your environment, you can define it using a custom function. Your custom is-done function must take the observations, actions, and next observations as cell-array inputs and return a logical termination signal (one value per column for batched inputs). For this example, use the following custom is-done function, which computes the termination signal based on the next observation.
type cartPoleIsDoneFunction.m
function isDone = cartPoleIsDoneFunction(obs,action,nextObs)
% Compute termination signal based on next observation.

    if iscell(nextObs)
        nextObs = nextObs{1};
    end

    % Angle at which to fail the episode
    thetaThresholdRadians = 12 * pi/180;

    % Distance at which to fail the episode
    xThreshold = 2.4;

    x = nextObs(1,:);
    theta = nextObs(3,:);

    isDone = abs(x) > xThreshold | abs(theta) > thetaThresholdRadians;

end
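Because an MBPO agent calls these functions on batches, it can help to check them on a small batch first. This sketch assumes cartPoleRewardFunction.m and cartPoleIsDoneFunction.m, as listed above, are on the MATLAB path. The two samples are arbitrary: in the first, the cart is centered and the pole upright; in the second, the cart position (3) exceeds xThreshold (2.4).

```matlab
% Batch of two samples, one per column.
obsBatch     = {zeros(4,2)};
actionBatch  = {zeros(1,2)};
nextObsBatch = {[0 3; 0 0; 0 0; 0 0]};

isDone = cartPoleIsDoneFunction(obsBatch,actionBatch,nextObsBatch);
reward = cartPoleRewardFunction(obsBatch,actionBatch,nextObsBatch);

disp(isDone)   % returns [false true]
disp(reward)   % returns [1 -50]
```

The first sample earns 0.5*1 + 0.5*(1 - 0/2.4) = 1, and the second receives the falling penalty of -50.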
Create a neural network environment using the transition function object and the custom reward and is-done functions.
env = rlNeuralNetworkEnvironment(obsInfo,actInfo,transitionFcn,...
@cartPoleRewardFunction,@cartPoleIsDoneFunction);
Version History
Introduced in R2022a