Tips on Importing Models from TensorFlow, PyTorch, and ONNX
This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.
Import Functions of Deep Learning Toolbox
This table lists the Deep Learning Toolbox import tools. Use these tools to import networks from TensorFlow, PyTorch, and ONNX.
Tip
You can import TensorFlow and PyTorch networks using the Deep Network Designer app. On import, the app shows an import report with details about any issues that require attention.
You must have the relevant support package to use the import apps and functions. If the support package is not installed, then the software provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.
Detecting Issues
When you import networks from external platforms, you might need to take action before the network is ready to use. If you import a network by using the Deep Network Designer app, the app generates an import report after import. The report lists issues that require attention, such as placeholder layers that you need to complete. Deep Network Designer supports the import of TensorFlow and PyTorch models.
Autogenerated Custom Layers
The Deep Network Designer app and the importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions can automatically generate custom layers, or custom layers with placeholder functions, when you import TensorFlow, PyTorch, or ONNX layers that the software cannot convert into equivalent built-in MATLAB functions or layers.
These tools import an external platform layer into MATLAB by trying these steps in order:
1. The software imports the external layer as a built-in MATLAB layer.
2. The software imports the external layer as a built-in MATLAB function (for TensorFlow and PyTorch only).
3. The software imports the external layer as a custom layer.
4. The software imports the external layer as a custom layer with a placeholder function.
For more information about custom layer generation, see the Algorithms section of each function: Algorithms (TensorFlow), Algorithms (PyTorch), and Algorithms (ONNX).
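The ordered fallback can be illustrated with a short Python sketch. This is not how Deep Learning Toolbox is implemented: the strategy names, the `import_layer` helper, and the sets of "supported" layer types are all hypothetical. The sketch only shows that each external layer ends up in the first category that can represent it, with a custom layer containing a placeholder function as the last resort.

```python
from typing import Callable, Optional

# Hypothetical converter: returns a description of the imported layer,
# or None if this strategy cannot represent the external layer.
Converter = Callable[[str], Optional[str]]


def make_converter(kind: str, supported: set) -> Converter:
    """Build a toy converter that succeeds only for layer types in `supported`."""
    def convert(layer_type: str) -> Optional[str]:
        return f"{kind}({layer_type})" if layer_type in supported else None
    return convert


def import_layer(layer_type: str, strategies: list) -> str:
    """Try each strategy in order and return the first successful result."""
    for convert in strategies:
        result = convert(layer_type)
        if result is not None:
            return result
    # Last resort: a custom layer whose function is a placeholder you must complete.
    return f"placeholder_custom_layer({layer_type})"


# Hypothetical support sets, for illustration only.
STRATEGIES = [
    make_converter("builtin_layer", {"Conv2d", "ReLU", "Linear"}),
    make_converter("builtin_function", {"flatten", "softmax"}),
    make_converter("custom_layer", {"GELU"}),
]

if __name__ == "__main__":
    for name in ["Conv2d", "flatten", "GELU", "MyExoticOp"]:
        print(name, "->", import_layer(name, STRATEGIES))
```

Ordering the strategies from most to least integrated means a layer is only turned into a placeholder when nothing earlier in the list can handle it.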
Input Dimension Ordering
The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares input dimension ordering between platforms for different input types.
Input Type | MATLAB | TensorFlow | PyTorch | ONNX |
---|---|---|---|---|
Features | CN | NC | NC | NC |
2-D image | HWCN | NHWC | NCHW | NCHW |
3-D image | HWDCN | NHWDC | NCDHW | NCHWD |
Vector sequence | CSN | NSC | SNC | NSC |
2-D image sequence | HWCSN | NSWHC | NCSHW | NSCHW |
3-D image sequence | HWDCSN | NSWHDC | NCSDHW | NSCHWD |
Dimension labels in the table:
N — Number of observations
C — Number of features or channels
H — Height of images
W — Width of images
D — Depth of images
S — Sequence length
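Moving data between platforms amounts to reordering axes. As a sketch, assuming you describe each layout with the letters from the table (for example "NCHW" for a PyTorch 2-D image batch and "HWCN" for the MATLAB equivalent), this Python snippet computes the axis permutation and applies it to a shape. The helper names are hypothetical. The permutation is 0-based, as `numpy.transpose` expects; add 1 to each entry to use it with the MATLAB `permute` function. It describes only the logical reordering and does not account for row-major versus column-major storage when raw array buffers are copied between Python and MATLAB.

```python
def layout_permutation(src: str, dst: str) -> list:
    """Return perm such that output axis i is input axis perm[i] (0-based).

    `src` and `dst` are layout strings built from the table letters,
    for example "NCHW" (PyTorch 2-D image) and "HWCN" (MATLAB 2-D image).
    """
    if sorted(src) != sorted(dst) or len(set(src)) != len(src):
        raise ValueError(f"{src!r} and {dst!r} must use the same distinct letters")
    return [src.index(axis) for axis in dst]


def permute_shape(shape: tuple, src: str, dst: str) -> tuple:
    """Reorder a shape tuple from layout `src` to layout `dst`."""
    if len(shape) != len(src):
        raise ValueError("shape and layout string must have the same length")
    return tuple(shape[i] for i in layout_permutation(src, dst))


perm = layout_permutation("NCHW", "HWCN")
print(perm)                     # [2, 3, 1, 0]  (0-based, numpy.transpose style)
print([p + 1 for p in perm])    # [3, 4, 2, 1]  (1-based, MATLAB permute style)
print(permute_shape((8, 3, 224, 224), "NCHW", "HWCN"))  # (224, 224, 3, 8)
```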
Data Formats for Prediction with dlnetwork
The Deep Network Designer app and the importNetworkFromTensorFlow function import a TensorFlow network as an initialized dlnetwork object. For an example, see Import TensorFlow Network and Classify Image. If the network does not have a fixed input size, the software imports the model as an uninitialized dlnetwork object without an input layer. For an example of how to initialize this network, see Import and Initialize TensorFlow Network.
The Deep Network Designer app and the importNetworkFromPyTorch function import a PyTorch network as an uninitialized or initialized dlnetwork object. If the imported network is uninitialized, do one of the following before you use it:
- In the Deep Network Designer app, add an input layer. The app initializes the network when you export it. Alternatively, add an input layer to the imported network and initialize the network by using the addInputLayer function. For an example, see Import Network from PyTorch and Add Input Layer.
- Define a dlarray object with the appropriate data format and use the initialize function to initialize the network. For an example, see Import Network from PyTorch and Initialize.
To import a PyTorch network as an initialized dlnetwork object, use the PyTorchInputSizes name-value argument. For an example, see Import Network from PyTorch using PyTorchInputSizes.
The importNetworkFromONNX function imports an ONNX network as an initialized dlnetwork object. For an example, see Import ONNX Network and Classify Image.
To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network and Classify Image. Use this table to choose the right data format for each input type and layer.
Input Type | Input Layer ** | Input Format * |
---|---|---|
Features | featureInputLayer | CB |
2-D image | imageInputLayer | SSCB |
3-D image | image3dInputLayer | SSSCB |
Vector sequence | sequenceInputLayer | CBT |
2-D image sequence | sequenceInputLayer | SSCBT |
3-D image sequence | sequenceInputLayer | SSSCBT |
* In Deep Learning Toolbox, each character of a data format must be one of these labels:
- S: Spatial
- C: Channel
- B: Batch observations
- T: Time or sequence
- U: Unspecified
** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.
For more information on data formats, see dlarray.
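The two tables are related. Each dimension letter in a MATLAB layout maps to a dlarray label, and dlarray keeps labeled dimensions in the order S, C, B, T, U. That is why a 2-D image sequence stored as HWCSN corresponds to the format SSCBT rather than SSCTB. The Python sketch below, with hypothetical helper names, derives the Input Format column from the MATLAB column of the dimension-ordering table. Note the clash of letters: S means sequence length in the ordering table but spatial in a dlarray format, so the mapping sends the table's S to T.

```python
# Mapping from the letters in the dimension-ordering table to dlarray labels.
# H, W, and D are spatial (S); C is channel (C); N is batch (B);
# the table's S (sequence length) becomes the dlarray time label (T).
TABLE_TO_DLARRAY = {"H": "S", "W": "S", "D": "S", "C": "C", "N": "B", "S": "T"}

# Order in which dlarray arranges labeled dimensions.
CANONICAL_ORDER = "SCBTU"


def dlarray_labels(matlab_layout: str) -> str:
    """Labels describing data laid out as `matlab_layout`, e.g. "HWCSN" -> "SSCTB"."""
    return "".join(TABLE_TO_DLARRAY[axis] for axis in matlab_layout)


def canonical_format(labels: str) -> str:
    """Sort labels into S, C, B, T, U order (sorted() is stable, so repeats stay together)."""
    return "".join(sorted(labels, key=CANONICAL_ORDER.index))


for layout in ["CN", "HWCN", "HWDCN", "CSN", "HWCSN", "HWDCSN"]:
    labels = dlarray_labels(layout)
    print(f"{layout:7s} labels={labels:7s} format={canonical_format(labels)}")
```

The printed formats match the Input Format column: CB, SSCB, SSSCB, CBT, SSCBT, and SSSCBT.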
Input Data Preprocessing
To use a pretrained network for prediction or transfer learning on new images, you must preprocess your images in the same way as the images that were used to train the imported model. The most common preprocessing steps are resizing images, subtracting image average values, and converting the images between BGR and RGB channel order.
For more information about preprocessing images for training and prediction, see Preprocess Images for Deep Learning.
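The preprocessing steps can be sketched in plain Python on an image stored as nested lists of shape H x W x 3. This is an illustration under stated assumptions, not a replacement for the MATLAB preprocessing functions: the helper names are hypothetical, nearest-neighbor resizing is used only to keep the code short (the original model may have been trained with a different interpolation method), and the mean values must be the ones used to train the imported model, given in the channel order the model expects.

```python
def resize_nearest(image, out_h, out_w):
    """Nearest-neighbor resize of an H x W x C image (nested lists) to out_h x out_w."""
    in_h, in_w = len(image), len(image[0])
    return [
        [list(image[(r * in_h) // out_h][(c * in_w) // out_w]) for c in range(out_w)]
        for r in range(out_h)
    ]


def swap_bgr_rgb(image):
    """Reverse the channel order of every pixel; converts BGR to RGB and vice versa."""
    return [[list(reversed(pixel)) for pixel in row] for row in image]


def subtract_channel_mean(image, mean):
    """Subtract one mean value per channel from every pixel."""
    return [[[value - m for value, m in zip(pixel, mean)] for pixel in row] for row in image]


def preprocess(image, out_h, out_w, mean, swap_channels=True):
    """Resize, optionally reverse channel order, then subtract the channel mean.

    `mean` must be given in the channel order the model expects (after any swap).
    """
    resized = resize_nearest(image, out_h, out_w)
    if swap_channels:
        resized = swap_bgr_rgb(resized)
    return subtract_channel_mean(resized, mean)


img = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
print(preprocess(img, 1, 1, mean=[1, 1, 1]))  # [[[2, 1, 0]]]
```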
See Also
Deep Network Designer | importNetworkFromONNX | importNetworkFromPyTorch | importNetworkFromTensorFlow | dlarray
Related Topics
- Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX
- Pretrained Deep Neural Networks