Import Neural Network Models Using ONNX
To create function approximators for reinforcement learning, you can import pre-trained deep neural networks or deep neural network layer architectures using the Deep Learning Toolbox™ network import functionality. You can import:
Open Neural Network Exchange (ONNX™) models, which require the Deep Learning Toolbox Converter for ONNX Model Format support package software. For more information, see importNetworkFromONNX.
TensorFlow™-Keras networks, which require the Deep Learning Toolbox Converter for TensorFlow Models support package software. For more information, see importNetworkFromTensorFlow.
After you import a deep neural network, you can create an actor or critic object, such as rlValueFunction or rlDiscreteCategoricalActor.
When you import deep neural network architectures, consider the following.
The dimensions of the input and output layers of the imported network architecture must match the dimensions of the environment action, observation, or reward channels that need to be connected to those layers.
After importing the network architecture, it is best practice to specify the names of the input and output layers to be connected to the corresponding environment action and observation channels.
For more information on the deep neural network architectures supported for reinforcement learning, see Create Policies and Value Functions.
Import Actor and Critic for Image Observation Application
As an example, assume that you have an environment with a 50-by-50 grayscale image observation signal and a continuous action space. To train a policy gradient (PG) agent, you require the following function approximators, both of which must have a single 50-by-50 image input observation layer and a single scalar output value.
Actor — Selects an action based on the current observation.
Critic — Estimates the expected discounted cumulative long-term reward based on the current observation.
Also, assume that you have the following network architectures to import:
A deep neural network architecture for the critic with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (criticNetwork.onnx).
A deep neural network architecture for the actor with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (actorNetwork.onnx).
To import the critic and actor networks, use the importNetworkFromONNX function.
criticNetwork = importNetworkFromONNX("criticNetwork.onnx");
actorNetwork = importNetworkFromONNX("actorNetwork.onnx");
After you import the networks, if you already have an appropriate agent for your environment, you can use getActor and getCritic to extract the actor and critic function approximators from the agent, then setModel to set the imported networks as the approximation models of the actor and critic, and finally setActor and setCritic to set the updated actor and critic back into your agent.
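For example, this sketch follows that workflow, assuming an existing agent variable named agent whose actor and critic are compatible with the imported networks.

% Extract the current actor and critic from the existing agent.
actor = getActor(agent);
critic = getCritic(agent);

% Replace their approximation models with the imported networks.
actor = setModel(actor,actorNetwork);
critic = setModel(critic,criticNetwork);

% Set the updated actor and critic back into the agent.
agent = setActor(agent,actor);
agent = setCritic(agent,critic);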
Alternatively, create new actor and critic function approximators that use the imported networks. To do so, first obtain the observation and action specifications from the environment.
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
Create the critic. PG agents use an rlValueFunction approximator.
critic = rlValueFunction(criticNetwork,obsInfo);
If your critic has more than one input channel (for example, because your environment has more than one output channel, or because you are using a Q-value function critic, which also needs an action input), it is good practice to specify the names of the input layers that need to be connected, in sequential order, with each critic input channel. For an example, see Train DDPG Agent to Swing Up and Balance Pendulum with Image Observation.
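As a sketch, for a hypothetical Q-value critic network qNetwork whose observation and action input layers are named "obsInLyr" and "actInLyr" (both the variable and the layer names are assumptions for illustration), you could map the channels to the layers by name.

% Map each critic input channel to the matching network input layer by name.
critic = rlQValueFunction(qNetwork,obsInfo,actInfo, ...
    ObservationInputNames="obsInLyr",ActionInputNames="actInLyr");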
Create the actor. PG agents use an rlContinuousDeterministicActor approximator.
actor = rlContinuousDeterministicActor(actorNetwork,obsInfo,actInfo);
As with the critic, if your actor has more than one input channel (because your environment has more than one output channel), it is good practice to specify the name of the input layer that needs to be connected with each actor input channel.
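For example, assuming the imported actor network has an observation input layer named "obsInLyr" (a hypothetical name), you could specify it when creating the actor.

% Connect the observation channel to the named input layer of the network.
actor = rlContinuousDeterministicActor(actorNetwork,obsInfo,actInfo, ...
    ObservationInputNames="obsInLyr");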
To verify that your actor and critic work properly, use getAction and getValue to return an action (for the actor) and a value (for the critic) corresponding to a random observation, using the current network weights.
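For example, this sketch evaluates both approximators on a random 50-by-50 observation.

% Generate a random observation consistent with the observation specification.
obs = {rand(obsInfo.Dimension)};

% The actor returns an action; the critic returns the estimated value.
action = getAction(actor,obs)
value = getValue(critic,obs)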
After you have the actor and critic with the imported networks, you can then:
Create an agent using this actor and critic, as in the sketch after this list. For more information, see Reinforcement Learning Agents.
Set the actor and critic in an existing agent using setActor and setCritic, respectively.
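If the actor and critic types match the ones expected by the agent you want to create, you can pass them directly to the agent constructor. A minimal sketch, assuming a PG agent with default agent options:

% Create an agent from the actor and critic (assumes compatible approximator types).
agent = rlPGAgent(actor,critic);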
See Also
Functions
importNetworkFromONNX | importNetworkFromTensorFlow | importCaffeLayers | getActionInfo | getObservationInfo | getAction | getValue | setActor | setCritic | setModel