加载预训练网络以用于代码生成
您可以为经过预训练的卷积神经网络 (CNN) 生成代码。要向代码生成器提供网络,请从经过训练的网络中加载一个 SeriesNetwork
(Deep Learning Toolbox)、DAGNetwork
(Deep Learning Toolbox)、yolov2ObjectDetector
(Computer Vision Toolbox)、ssdObjectDetector
(Computer Vision Toolbox) 或 dlnetwork
(Deep Learning Toolbox) 对象。
使用 coder.loadDeepLearningNetwork
加载网络
您可以使用 coder.loadDeepLearningNetwork
从任何支持代码生成的网络加载网络对象。您可以从 MAT 文件中指定网络。MAT 文件只能包含要加载的网络。
例如,假设您使用 trainNetwork
(Deep Learning Toolbox) 函数创建了一个名为 myNet
的经过训练的网络对象。然后,您可以通过输入 save
来保存工作区。这将创建一个名为 matlab.mat
的文件,其中包含网络对象。要加载网络对象 myNet
,请输入:
net = coder.loadDeepLearningNetwork('matlab.mat');
您还可以通过以下方式来指定网络:提供不接受输入参数并返回预训练的 SeriesNetwork
、DAGNetwork
、yolov2ObjectDetector
或 ssdObjectDetector
对象的函数的名称,例如:
alexnet
(Deep Learning Toolbox)densenet201
(Deep Learning Toolbox)googlenet
(Deep Learning Toolbox)inceptionv3
(Deep Learning Toolbox)mobilenetv2
(Deep Learning Toolbox)resnet18
(Deep Learning Toolbox)resnet50
(Deep Learning Toolbox)resnet101
(Deep Learning Toolbox)squeezenet
(Deep Learning Toolbox)vgg16
(Deep Learning Toolbox)vgg19
(Deep Learning Toolbox)xception
(Deep Learning Toolbox)
例如,通过输入以下命令加载网络对象:
net = coder.loadDeepLearningNetwork('googlenet');
上述列表中的 Deep Learning Toolbox™ 函数要求您安装适用于这些函数的支持包。请参阅预训练的深度神经网络 (Deep Learning Toolbox)。
为代码生成指定网络对象
如果您使用 codegen
或 App 生成代码,请使用 coder.loadDeepLearningNetwork
将网络对象加载到您的入口函数内。例如:
function out = myNet_predict(in) %#codegen persistent mynet; if isempty(mynet) mynet = coder.loadDeepLearningNetwork('matlab.mat'); end out = predict(mynet,in);
对于可用作支持包函数(如 alexnet
、inceptionv3
、googlenet
和 resnet
)的预训练网络,您可以直接指定支持包函数,例如,通过编写 mynet = googlenet
。
接下来,为入口函数生成代码。例如:
cfg = coder.config('mex'); cfg.TargetLang = 'C++'; cfg.DeepLearningConfig = coder.DeepLearningConfig('mkldnn'); codegen -args {ones(224,224,3,'single')} -config cfg myNet_predict
为代码生成指定 dlnetwork
对象
假设在 MAT 文件 mynet.mat
中有一个预训练的 dlnetwork
网络对象。要预测此网络的响应,请在 MATLAB® 中创建一个入口函数,如以下代码所示。
function a = myDLNet_predict(in) dlIn = dlarray(in, 'SSC'); persistent dlnet; if isempty(dlnet) dlnet = coder.loadDeepLearningNetwork('mynet.mat'); end dlA = predict(dlnet, dlIn); a = extractdata(dlA); end
在此示例中,myDLNet_predict
的输入和输出属于更简单的数据类型,并且 dlarray
对象是在该函数中创建的。dlarray
对象的 extractdata
(Deep Learning Toolbox) 方法在 dlarray
dlA
中返回数据作为 myDLNet_predict
的输出。输出 a
与 dlA
中的基础数据类型具有相同的数据类型。这种入口函数设计具有以下优点:
更容易与独立的代码生成工作流(如静态、动态库或可执行文件)集成。
extractdata
函数输出的数据格式在 MATLAB 环境和生成代码中具有相同的顺序 ('SCBTU'
)。可提高 MEX 工作流的性能。
可使用 MATLAB Function 模块简化 Simulink® 工作流,因为 Simulink 不内生支持
dlarray
对象。
接下来,为入口函数生成代码。例如:
cfg = coder.config('lib'); cfg.TargetLang = 'C++'; cfg.DeepLearningConfig = coder.DeepLearningConfig('mkldnn'); codegen -args {ones(224,224,3,'single')} -config cfg myDLNet_predict
另请参阅
函数
codegen
|trainNetwork
(Deep Learning Toolbox) |coder.loadDeepLearningNetwork
对象
SeriesNetwork
(Deep Learning Toolbox) |DAGNetwork
(Deep Learning Toolbox) |yolov2ObjectDetector
(Computer Vision Toolbox) |ssdObjectDetector
(Computer Vision Toolbox) |dlarray
(Deep Learning Toolbox) |dlnetwork
(Deep Learning Toolbox)