
Tips on Importing Models from TensorFlow, PyTorch, and ONNX

This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network or layer graph. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.

Import Functions of Deep Learning Toolbox

This table lists the Deep Learning Toolbox import functions. Use these functions to import networks or layer graphs from TensorFlow, PyTorch, and ONNX.

You must have the relevant support package to run these import functions. If the support package is not installed, each function provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.

Recommended Functions to Import TensorFlow Models

The Deep Learning Toolbox Converter for TensorFlow Models support package offers these functions:

  • importTensorFlowNetwork and importKerasNetwork — Import a TensorFlow model as a network.

  • importTensorFlowLayers and importKerasLayers — Import a TensorFlow model as a layer graph.

Note

The importTensorFlowNetwork and importTensorFlowLayers functions are recommended over the importKerasNetwork and importKerasLayers functions.

This table compares the Deep Learning Toolbox Converter for TensorFlow Models functions. The comparison highlights the reasons that the importTensorFlowNetwork and importTensorFlowLayers functions are recommended over the importKerasNetwork and importKerasLayers functions.

| Features | importTensorFlowNetwork and importTensorFlowLayers | importKerasNetwork and importKerasLayers |
|---|---|---|
| Automatically generates custom layers | Yes | No |
| Supports TensorFlow 2 | Yes | Limited |
| Supports SavedModel format | Yes | No |
| Can import network as dlnetwork (or LayerGraph compatible with dlnetwork) | Yes | No |

For more information on the advantages of migrating from TensorFlow 1 to TensorFlow 2, see Migrate from TensorFlow 1.x to TensorFlow 2. For more information on the TensorFlow versions that the import functions support, see Limitations (importTensorFlowNetwork and importTensorFlowLayers) and Limitations (importKerasNetwork and importKerasLayers).

If your TensorFlow model is saved in the HDF5 format, do not import it with importKerasNetwork. Instead, convert the model to the SavedModel format and import it with the importTensorFlowNetwork function.
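A minimal sketch of this conversion, run from MATLAB through the Python interface (the pyrun function, R2021b and later). It assumes a Python environment with TensorFlow 2.x that MATLAB is configured to use, a Keras version earlier than Keras 3 (TensorFlow 2.15 or earlier, where model.save writes a SavedModel folder when the path has no file extension), a model without custom objects, and the hypothetical names myModel.h5 and myModelSavedModel.

```matlab
% Hypothetical file and folder names; replace them with your own.
h5File = "myModel.h5";
savedModelFolder = "myModelSavedModel";

% Load the HDF5 model in Python and save it in SavedModel format.
% Assumes TensorFlow 2.x with Keras 2: saving to a path without an
% extension produces a SavedModel folder.
pyrun("import tensorflow as tf; tf.keras.models.load_model(src).save(dst)", ...
    src=h5File, dst=savedModelFolder)

% Import the SavedModel folder as a Deep Learning Toolbox network.
net = importTensorFlowNetwork(savedModelFolder, ...
    OutputLayerType="classification");
```

With Keras 3, the SavedModel export path differs, so check the TensorFlow documentation for your installed version.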

Autogenerated Custom Layers

The importTensorFlowNetwork, importTensorFlowLayers, importONNXNetwork, importONNXLayers, and importNetworkFromPyTorch functions save the automatically generated custom layers to a package in the current folder. For more information on the custom layers package, see the PackageName name-value argument of each function.
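As a sketch, the call below chooses the package name explicitly through PackageName. The model folder myModel and the package name myModelLayers are hypothetical; the model is assumed to be in SavedModel format.

```matlab
% Hypothetical SavedModel folder; replace with your own model.
modelFolder = "myModel";

% Import the model and save any autogenerated custom layers to the
% +myModelLayers package folder in the current folder.
net = importTensorFlowNetwork(modelFolder, ...
    OutputLayerType="classification", ...
    PackageName="myModelLayers");

% If the importer generated custom layers, list the package contents.
pkgFolder = fullfile(pwd, "+myModelLayers");
if isfolder(pkgFolder)
    dir(pkgFolder)
end
```

Keep the package folder on the path (for example, in the current folder) whenever you load or use the imported network.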

Placeholder Layers

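The importTensorFlowLayers and importONNXLayers functions insert placeholder layers in place of TensorFlow layers or ONNX operators when both of these conditions apply:

  • The function does not support converting the TensorFlow layer or ONNX operator into a built-in MATLAB layer.

  • The function cannot generate a custom layer for the TensorFlow layer or ONNX operator.

If these conditions apply, the importTensorFlowNetwork and importONNXNetwork functions return an error. These flowcharts describe these workflows.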

[Flowcharts of the import workflows described above]

To find the names and indices of the placeholder layers in the layer graph, use the findPlaceholderLayers function. You can then replace a placeholder layer with a built-in MATLAB layer, a custom layer, or a functionLayer object. For more information about custom layers, see Define Custom Deep Learning Layers. For an example with a functionLayer object, see Replace Unsupported Keras Layer with Function Layer. To replace a layer, use replaceLayer. For an example, see Import ONNX Model as Layer Graph with Placeholder Layers.
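A minimal sketch of this workflow, assuming a hypothetical ONNX file myModel.onnx. The replacement shown is also hypothetical: it assumes the first unsupported operator is an element-wise ReLU with one input and one output, so it uses a functionLayer that computes max(X, 0). Inspect the original operator before choosing a real replacement.

```matlab
% Hypothetical ONNX file; replace with your own model.
lgraph = importONNXLayers("myModel.onnx", ImportWeights=true);

% Find the placeholder layers in the imported layer graph.
placeholders = findPlaceholderLayers(lgraph);
disp(placeholders)

if ~isempty(placeholders)
    % Hypothetical replacement: treat the first unsupported operator as an
    % element-wise ReLU. Verify the actual operator definition first.
    name = placeholders(1).Name;
    newLayer = functionLayer(@(X) max(X, 0), Name=name);
    lgraph = replaceLayer(lgraph, name, newLayer);
end
```

Repeat the replacement for each remaining placeholder layer before you assemble or convert the layer graph into a network.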

The importNetworkFromPyTorch function generates a custom layer with a placeholder function instead of a placeholder layer. For more information, see Placeholder Functions.

Input Dimension Ordering

The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares input dimension ordering between platforms for different input types.

| Input Type | MATLAB | TensorFlow | PyTorch | ONNX |
|---|---|---|---|---|
| Features | CN | NC | NC | NC |
| 2-D image | HWCN | NHWC | NCHW | NCHW |
| 3-D image | HWDCN | NHWDC | NCDHW | NCHWD |
| Vector sequence | CSN | NSC | SNC | NSC |
| 2-D image sequence | HWCSN | NSWHC | NCSHW | NSCHW |
| 3-D image sequence | HWDCSN | NSWHDC | NCSDHW | NSCHWD |

Variable names in the table:

  • N — Number of observations

  • C — Number of features or channels

  • H — Height of images

  • W — Width of images

  • D — Depth of images

  • S — Sequence length
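To move data from one ordering to another, you can reorder dimensions with permute. This sketch uses dummy arrays and assumes each array already holds its data in the source framework's logical order (for example, a shape preserved when the data was saved). Copying raw memory between row-major NumPy arrays and column-major MATLAB arrays is a separate issue that this sketch does not address.

```matlab
% Dummy TensorFlow-style 2-D image batch in NHWC order: N=2, H=4, W=5, C=3.
Xtf = rand(2, 4, 5, 3);

% NHWC -> HWCN (MATLAB order). Dimension k of the result is
% dimension order(k) of the input.
Xml = permute(Xtf, [2 3 4 1]);
size(Xml)   % 4 5 3 2

% Dummy PyTorch/ONNX-style batch in NCHW order: N=2, C=3, H=4, W=5.
Xpt = rand(2, 3, 4, 5);

% NCHW -> HWCN (MATLAB order).
Xml2 = permute(Xpt, [3 4 2 1]);
size(Xml2)  % 4 5 3 2
```

To go back from MATLAB order to the source order, use ipermute with the same permutation vector.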

Data Formats for Prediction with dlnetwork

The importTensorFlowNetwork and importONNXNetwork functions can import a TensorFlow or ONNX model as a DAGNetwork or dlnetwork object. Specify the type of imported network by setting the TargetNetwork name-value argument. For more details, see TargetNetwork for importTensorFlowNetwork and TargetNetwork for importONNXNetwork.
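As a sketch, setting TargetNetwork to "dlnetwork" looks like this for both functions. The folder myTFModel (assumed to contain a SavedModel) and the file myModel.onnx are hypothetical.

```matlab
% Hypothetical model locations; replace with your own.
tfFolder = "myTFModel";      % folder that contains a SavedModel
onnxFile = "myModel.onnx";

% Import each model as a dlnetwork object instead of a DAGNetwork object.
netTF = importTensorFlowNetwork(tfFolder, TargetNetwork="dlnetwork");
netONNX = importONNXNetwork(onnxFile, TargetNetwork="dlnetwork");
```

OutputLayerType is omitted here because dlnetwork objects do not use output layers.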

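The importNetworkFromPyTorch function imports a PyTorch model as an uninitialized dlnetwork object. Before you use the network, initialize it. For example, add an input layer to the network and then initialize it, or initialize it by passing example input data to the initialize function.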
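A minimal sketch of initializing the imported network with example input. It assumes a hypothetical traced PyTorch model file myModel.pt and a model that accepts 224-by-224 RGB images; adjust the size and format of the example input to match your model.

```matlab
% Hypothetical traced PyTorch model file; replace with your own.
net = importNetworkFromPyTorch("myModel.pt");

% Example input for an assumed 224-by-224 RGB image model, formatted as
% "SSCB" (spatial, spatial, channel, batch).
X = dlarray(rand(224, 224, 3, 1, "single"), "SSCB");

% Initialize the learnable parameters and state using the example input.
net = initialize(net, X);
```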

To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network as dlnetwork to Classify Image. Use this table to choose the right data format for each input type and layer.

| Input Type | Input Layer** | Input Format* |
|---|---|---|
| Features | featureInputLayer | CB |
| 2-D image | imageInputLayer | SSCB |
| 3-D image | image3dInputLayer | SSSCB |
| Vector sequence | sequenceInputLayer | CBT |
| 2-D image sequence | sequenceInputLayer | SSCBT |
| 3-D image sequence | sequenceInputLayer | SSSCBT |

* In Deep Learning Toolbox, each character of a data format must be one of these labels:

  • S — Spatial

  • C — Channel

  • B — Batch observations

  • T — Time or sequence

  • U — Unspecified

** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.

For more information on data formats, see dlarray.
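The vector sequence row of the table and the note about input layers can be illustrated with a small sketch. The network below is hypothetical and exists only for illustration: it has no input layer, so dlnetwork takes the input size and format from an example dlarray in "CBT" format.

```matlab
% Small illustrative sequence network with no input layer.
layers = [
    lstmLayer(16, OutputMode="last")
    fullyConnectedLayer(4)
    softmaxLayer];

% Example vector sequence: 3 channels, 1 observation, 20 time steps.
X = dlarray(rand(3, 1, 20, "single"), "CBT");

% Create an initialized dlnetwork; the input size and format come from X.
net = dlnetwork(layers, X);

% Predict; the output has format "CB" (4 classes by 1 observation).
Y = predict(net, X);
dims(Y)
```

For an imported image network, the same pattern applies with an "SSCB" array in place of the "CBT" sequence.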

Input Data Preprocessing

Preprocessing data is a common first step in the deep learning workflow to prepare data in a format that the network can accept. You must preprocess the input data in the same way as the training data.

The input layer of the pretrained deep learning networks available in Deep Learning Toolbox performs some of the input data preprocessing. For example, the input layer of the pretrained mobilenetv2 normalizes the image input data. Display the Normalization property of the network input layer.

net = mobilenetv2;
net.Layers(1).Normalization
ans =

    'zscore'

Networks that you import from TensorFlow or ONNX might not have built-in preprocessing in the input layer. For example, the input layer of the imported MobileNetV2 from TensorFlow does not normalize the input image. Import MobileNetV2 and display the Normalization property of the network input layer.

% "MobileNetV2" is the folder that contains the model in SavedModel format.
net = importTensorFlowNetwork("MobileNetV2", ...
    OutputLayerType="classification");
net.Layers(1).Normalization
ans =

    'none'

Often, open-source repositories provide information about the required input data preprocessing. For example, see tf.keras.applications.mobilenet_v2.preprocess_input and ShuffleNet in ONNX Model Zoo. To learn more about how to preprocess images and other types of data in Deep Learning Toolbox, see Preprocess Images for Deep Learning and Deep Learning Data Preprocessing.
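For the TensorFlow MobileNetV2 model, tf.keras.applications.mobilenet_v2.preprocess_input scales pixel values to the range [-1, 1]. This sketch reproduces that scaling in MATLAB on a dummy 224-by-224 RGB image; the image size is an assumption, and you should confirm the preprocessing for your own model against its source repository.

```matlab
% Dummy 224-by-224 RGB image with uint8 pixel values in [0, 255].
I = uint8(randi([0 255], 224, 224, 3));

% MobileNetV2-style preprocessing: map [0, 255] to [-1, 1].
X = single(I) / 127.5 - 1;

% With net imported from the MobileNetV2 folder as in the earlier example,
% you could then classify the preprocessed image:
% label = classify(net, X);
```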
