
Import Neural Network Models Using ONNX

To create function approximators for reinforcement learning, you can import pre-trained deep neural networks or deep neural network layer architectures using the Deep Learning Toolbox™ network import functionality. You can import:

  • Open Neural Network Exchange (ONNX™) models, which require the Deep Learning Toolbox Converter for ONNX Model Format support package software. For more information, see importNetworkFromONNX.

  • TensorFlow™-Keras networks, which require Deep Learning Toolbox Converter for TensorFlow Models support package software. For more information, see importNetworkFromTensorFlow.

After you import a deep neural network, you can create an actor or critic object, such as rlValueFunction or rlDiscreteCategoricalActor.

When you import deep neural network architectures, consider the following.

  • The dimensions of the input and output layers of the imported network architecture must match the dimensions of the environment action, observation, or reward channels that need to be connected to those layers.

  • After importing the network architecture, it is best practice to specify the names of the input and output layers to be connected to the corresponding environment action and observation channels.

For more information on the deep neural network architectures supported for reinforcement learning, see Create Policies and Value Functions.

Import Actor and Critic for Image Observation Application

As an example, assume that you have an environment with a 50-by-50 grayscale image observation signal and a continuous action space. To train a policy gradient (PG) agent, you require the following function approximators, both of which must have a single 50-by-50 image input layer. The critic must also have a single scalar output.

  • Actor — Selects an action based on the current observation.

  • Critic — Estimates the expected discounted cumulative long-term reward based on the current observation.

Also, assume that you have the following network architectures to import:

  • A deep neural network architecture for the critic with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (criticNetwork.onnx).

  • A deep neural network architecture for the actor with a 50-by-50 image input layer, which is saved in the ONNX format (actorNetwork.onnx). Its output layers must match the type of actor you create from it.

To import the critic and actor networks, use the importNetworkFromONNX function.

criticNetwork = importNetworkFromONNX("criticNetwork.onnx");
actorNetwork = importNetworkFromONNX("actorNetwork.onnx");
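Before connecting the imported networks to environment channels, it can help to check the names that the importer assigned to their input and output layers, because these are the names you pass when you create actors and critics. The following is a minimal sketch, assuming both ONNX files are on the MATLAB path. importNetworkFromONNX returns a dlnetwork object, whose InputNames and OutputNames properties list those layers.

```matlab
% Import the networks (assumes both ONNX files are on the MATLAB path)
criticNetwork = importNetworkFromONNX("criticNetwork.onnx");
actorNetwork  = importNetworkFromONNX("actorNetwork.onnx");

% List the layer names that environment channels must connect to
disp(criticNetwork.InputNames)
disp(criticNetwork.OutputNames)
disp(actorNetwork.InputNames)
disp(actorNetwork.OutputNames)

% Optionally, open an interactive report showing the activation size
% of every layer, to confirm the input and output dimensions
analyzeNetwork(criticNetwork)
```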

After you import the networks, if you already have an appropriate agent for your environment, you can reuse it: use getActor and getCritic to extract the actor and critic function approximators from the agent, use setModel to set the imported networks as the approximation models of the actor and critic, and then use setActor and setCritic to store the updated actor and critic back in the agent.
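A minimal sketch of that workflow follows. It assumes that agent is an existing PG agent with a baseline critic, created elsewhere, and that the imported networks have input and output layers compatible with the models currently used by the agent's actor and critic.

```matlab
% Assumes: agent is an existing PG agent (with a baseline critic) whose
% actor and critic models have the same input and output layers as the
% imported networks.
criticNetwork = importNetworkFromONNX("criticNetwork.onnx");
actorNetwork  = importNetworkFromONNX("actorNetwork.onnx");

% Extract the function approximators from the agent
actor  = getActor(agent);
critic = getCritic(agent);

% Use the imported networks as the approximation models
actor  = setModel(actor,actorNetwork);
critic = setModel(critic,criticNetwork);

% Store the updated approximators back in the agent
agent = setActor(agent,actor);
agent = setCritic(agent,critic);
```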

Alternatively, create new actor and critic function approximators that use the imported networks. To do so, first obtain the observation and action specifications from the environment.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
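If you do not yet have an environment object, you can instead create the specifications directly. The sketch below assumes the environment described above: a 50-by-50 single-channel image observation and a continuous scalar action. The channel names and value ranges are illustrative assumptions, not values from any particular environment.

```matlab
% Observation: 50-by-50 grayscale image (one channel), pixels in [0,1]
obsInfo = rlNumericSpec([50 50 1],LowerLimit=0,UpperLimit=1);
obsInfo.Name = "image";

% Action: continuous scalar; the limits are illustrative assumptions
actInfo = rlNumericSpec([1 1],LowerLimit=-2,UpperLimit=2);
actInfo.Name = "action";
```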

Create the critic. PG agents use an rlValueFunction approximator as a baseline critic.

critic = rlValueFunction(criticNetwork,obsInfo);

If your critic has more than one input channel (for example, because your environment has more than one observation channel, or because you are using a Q-value function critic, which also takes the action as input), it is good practice to specify, in order, the names of the network input layers to connect to each critic input channel. For an example, see Train DDPG Agent to Swing Up and Balance Pendulum with Image Observation.
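For a Q-value function critic, for example, you pass these names through the ObservationInputNames and ActionInputNames arguments of rlQValueFunction. In this sketch, the file qNetwork.onnx and the layer names "obsIn" and "actIn" are hypothetical; replace them with your own file and with the names listed in the InputNames property of the imported network.

```matlab
% Specifications for an image observation and a scalar action
obsInfo = rlNumericSpec([50 50 1]);
actInfo = rlNumericSpec([1 1]);

% Hypothetical two-input network: one image input, one action input
qNetwork = importNetworkFromONNX("qNetwork.onnx");

% Connect each critic input channel to a named network input layer
critic = rlQValueFunction(qNetwork,obsInfo,actInfo, ...
    ObservationInputNames="obsIn", ...
    ActionInputNames="actIn");
```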

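Create the actor. PG agents use a stochastic actor, so for a continuous action space, use an rlContinuousGaussianActor approximator. A Gaussian actor network returns both the mean and the standard deviation of the action, so the imported actor network must have two output layers (the standard deviation output must be nonnegative), and you must specify their names. In the following code, actionMean and actionStd are placeholder names; replace them with the output layer names of your imported network.

actor = rlContinuousGaussianActor(actorNetwork,obsInfo,actInfo, ...
    ActionMeanOutputNames="actionMean", ...
    ActionStandardDeviationOutputNames="actionStd");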

As with the critic, if your actor has more than one input channel (because your environment has more than one observation channel), it is good practice to specify the names of the network input layers to connect to each actor input channel.
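As a sketch, suppose the environment also returned a second observation channel, a hypothetical two-element vector of sensor readings, in addition to the image. Because PG agents use a stochastic actor, the sketch uses rlContinuousGaussianActor and passes the input layer names in the same order as the channels in obsInfo. The file multiInputActor.onnx and all layer names are placeholders.

```matlab
% Two observation channels: an image and a hypothetical 2-element vector
obsInfo = [rlNumericSpec([50 50 1]) rlNumericSpec([2 1])];
actInfo = rlNumericSpec([1 1]);

% Hypothetical imported actor network with two input layers and two
% output layers (action mean and action standard deviation)
actorNetwork = importNetworkFromONNX("multiInputActor.onnx");

% Input names follow the order of the channels in obsInfo
actor = rlContinuousGaussianActor(actorNetwork,obsInfo,actInfo, ...
    ObservationInputNames=["imageIn","sensorIn"], ...
    ActionMeanOutputNames="actionMean", ...
    ActionStandardDeviationOutputNames="actionStd");
```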

To verify that your actor and critic work properly, use getAction and getValue to return an action (for the actor) and the value (for the critic) corresponding to a random observation, using the current network weights.
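A minimal check might look like the following. It assumes actor, critic, and obsInfo from the previous steps, with a single observation channel. Both functions take the observation as a cell array with one element per observation channel, so the random observation is wrapped in braces.

```matlab
% Random observation with the dimensions of the (single) observation channel
obs = {rand(obsInfo.Dimension)};

% Action returned by the actor for this observation (in a cell array)
act = getAction(actor,obs);
disp(act{1})

% State value estimated by the critic for the same observation
val = getValue(critic,obs);
disp(val)
```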

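After you have the actor and critic with the imported networks, you can then:

  • Create an agent using this actor and critic.

  • Set the actor and critic in an existing agent using setActor and setCritic, respectively.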
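For instance, a PG agent that uses the imported networks could be created as follows. This is a sketch assuming actor and critic from the previous steps; the option values are illustrative, not recommendations.

```matlab
% Illustrative agent options (values are assumptions, not tuned settings)
agentOpts = rlPGAgentOptions(UseBaseline=true,DiscountFactor=0.99);

% The critic serves as the baseline for the policy gradient update
agent = rlPGAgent(actor,critic,agentOpts);
```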
