Tips on Importing Models from TensorFlow, PyTorch, and ONNX

This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.

Import Functions of Deep Learning Toolbox

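The Deep Learning Toolbox import tools are the Deep Network Designer app and the importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions. Use these tools to import networks from TensorFlow, PyTorch, and ONNX.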

Tip

You can import TensorFlow and PyTorch networks using the Deep Network Designer app. On import, the app shows an import report with details about any issues that require attention.

You must have the relevant support package to use the import apps and functions. If the support package is not installed, then the software provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.
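
As a minimal sketch of how the three import functions are called, the snippet below imports one model from each platform. The folder and file names are hypothetical placeholders, and the code assumes the corresponding support packages are installed and the model files exist.

```matlab
% Hypothetical model locations; replace with your own files.
tfModelFolder = "myTensorFlowModel";   % assumed: folder in TensorFlow SavedModel format
torchModelFile = "myTracedModel.pt";   % assumed: traced PyTorch model file
onnxModelFile = "myModel.onnx";        % assumed: ONNX model file

% Each function returns a dlnetwork object.
netFromTF = importNetworkFromTensorFlow(tfModelFolder);
netFromTorch = importNetworkFromPyTorch(torchModelFile);
netFromONNX = importNetworkFromONNX(onnxModelFile);
```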

Detecting Issues

When you import a network from an external platform, you might need to take action before the network is ready to use. If you import a network using the Deep Network Designer app, the app generates an import report after import. This report shows issues, such as placeholder layers in your network that you need to complete. Deep Network Designer supports the import of TensorFlow and PyTorch models.

Deep Network Designer import report. The report highlights that there are two warnings for the import.

Autogenerated Custom Layers

  • The Deep Network Designer app and the importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions can automatically generate custom layers, or custom layers with placeholder functions, when you import TensorFlow, PyTorch, or ONNX layers that the software cannot convert into equivalent built-in MATLAB functions or layers.

  • The Deep Network Designer app and the importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions import an external platform layer into MATLAB by trying these steps in order:

    1. The software imports the external layer as a built-in MATLAB layer.

    2. The software imports the external layer as a built-in MATLAB function (for TensorFlow and PyTorch only).

    3. The software imports the external layer as a custom layer.

    4. The software imports the external layer as a custom layer with a placeholder function.

For more information about custom layer generation, see the Algorithms section of each function: Algorithms (TensorFlow), Algorithms (PyTorch), and Algorithms (ONNX).

Input Dimension Ordering

The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares input dimension ordering between platforms for different input types.

                      Dimension Ordering
Input Type            MATLAB    TensorFlow    PyTorch    ONNX
Features              CN        NC            NC         NC
2-D image             HWCN      NHWC          NCHW       NCHW
3-D image             HWDCN     NHWDC         NCDHW      NCHWD
Vector sequence       CSN       NSC           SNC        NSC
2-D image sequence    HWCSN     NSWHC         NCSHW      NSCHW
3-D image sequence    HWDCSN    NSWHDC        NCSDHW     NSCHWD

Variable names in the table:

  • N — Number of observations

  • C — Number of features or channels

  • H — Height of images

  • W — Width of images

  • D — Depth of images

  • S — Sequence length
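
As an illustration of the 2-D image row, the sketch below rearranges arrays into the MATLAB HWCN ordering with permute. The arrays are random stand-ins with made-up sizes, and the code assumes the data is already in MATLAB with its dimensions in the external platform's order.

```matlab
% Made-up sizes for illustration.
N = 2; C = 3; H = 4; W = 5;

% Stand-in for a batch in PyTorch or ONNX order (NCHW).
XTorch = rand(N,C,H,W);
% New dim 1 is old dim 3 (H), then W, then C, then N.
XFromTorch = permute(XTorch,[3 4 2 1]);
disp(size(XFromTorch))   % 4 5 3 2

% Stand-in for a batch in TensorFlow order (NHWC).
XTF = rand(N,H,W,C);
% New dim 1 is old dim 2 (H), then W, then C, then N.
XFromTF = permute(XTF,[2 3 4 1]);
disp(size(XFromTF))      % 4 5 3 2
```

The same pattern applies to the other rows: list the source dimensions in the order that the MATLAB column requires.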

Data Formats for Prediction with dlnetwork

The Deep Network Designer app and the importNetworkFromTensorFlow function import a TensorFlow network as an initialized dlnetwork object. For an example, see Import TensorFlow Network and Classify Image. If the network does not have a fixed input size, the software imports the model as an uninitialized dlnetwork object without an input layer. For an example of how to initialize this network, see Import and Initialize TensorFlow Network.

The Deep Network Designer app and the importNetworkFromPyTorch function import a PyTorch network as an uninitialized or initialized dlnetwork object. An uninitialized network must be initialized before you can use it.

To import a PyTorch network as an initialized dlnetwork object, use the PyTorchInputSizes name-value argument. For an example, see Import Network from PyTorch using PyTorchInputSizes.
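
A minimal sketch of the PyTorchInputSizes approach follows. The model file name and the 3-by-224-by-224 image size are assumptions for illustration. The size vector uses the PyTorch NCHW ordering, with NaN standing for the unknown number of observations.

```matlab
% Hypothetical traced PyTorch model file; replace with your own.
modelFile = "myTracedModel.pt";

% Input size in PyTorch order (N, C, H, W). NaN marks the unknown batch size.
% The 3-by-224-by-224 image size is an assumed example value.
net = importNetworkFromPyTorch(modelFile,PyTorchInputSizes=[NaN 3 224 224]);

% The Initialized property reports whether the network is ready for prediction.
disp(net.Initialized)
```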

The importNetworkFromONNX function imports an ONNX network as an initialized dlnetwork object. For an example, see Import ONNX Network and Classify Image.

To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network and Classify Image. Use this table to choose the right data format for each input type and layer.

Input Type            Input Layer **        Input Format *
Features              featureInputLayer     CB
2-D image             imageInputLayer       SSCB
3-D image             image3dInputLayer     SSSCB
Vector sequence       sequenceInputLayer    CBT
2-D image sequence    sequenceInputLayer    SSCBT
3-D image sequence    sequenceInputLayer    SSSCBT

* In Deep Learning Toolbox, each dimension label in a data format must be one of these labels:

  • S — Spatial

  • C — Channel

  • B — Batch observations

  • T — Time or sequence

  • U — Unspecified

** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.

For more information on data formats, see dlarray.
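
The sketch below shows one way to attach the formats from the table to input data. The arrays are random stand-ins with assumed sizes, and the predict call is commented out because net would be a hypothetical imported dlnetwork object.

```matlab
% Two 224-by-224 RGB images (assumed sizes): height, width, channel, batch.
XImage = rand(224,224,3,2,"single");
dlXImage = dlarray(XImage,"SSCB");

% Vector sequences (assumed sizes): 10 channels, 4 observations, 20 time steps.
XSeq = rand(10,4,20,"single");
dlXSeq = dlarray(XSeq,"CBT");

% dims returns the format attached to each dlarray.
disp(dims(dlXImage))   % SSCB
disp(dims(dlXSeq))     % CBT

% With an initialized imported network (hypothetical variable net):
% scores = predict(net,dlXImage);
```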

Input Data Preprocessing

To use a pretrained network for prediction or transfer learning on new images, you must preprocess your images in the same way as the images that were used to train the imported model. The most common preprocessing steps are resizing images, subtracting image average values, and converting the images from BGR format to RGB format.

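  • To resize images, use imresize and specify the target number of rows and columns. For example, imresize(image,[227 227]) resizes every channel of the image to 227-by-227.

  • To convert images between RGB and BGR format, use flip to reverse the channel dimension. For example, flip(image,3).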

For more information about preprocessing images for training and prediction, see Preprocess Images for Deep Learning.
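
The three common steps can be combined as in the sketch below. The input is a random stand-in image, the 227-by-227 target size is an example, and the per-channel mean values are illustrative assumptions. Use the size, mean values, and channel order that the original model was trained with.

```matlab
% Random stand-in for an RGB image; replace with imread on your own file.
img = randi([0 255],300,400,3,"uint8");

% Resize to the input size the model expects (assumed 227-by-227 here).
img = imresize(img,[227 227]);

% Convert to single before arithmetic so that subtraction can go negative.
img = single(img);

% Subtract per-channel average values (illustrative numbers, not from the source).
meanRGB = reshape(single([123.68 116.78 103.94]),1,1,3);
img = img - meanRGB;

% Reverse the channel dimension to switch between RGB and BGR order.
img = flip(img,3);

disp(size(img))   % 227 227 3
```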
