Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

预训练的深度神经网络

您可以采用预训练的图像分类网络,它已学会从自然图像中提取功能强大且包含丰富信息的特征,并以此作为学习新任务的起点。大多数预训练网络是基于 ImageNet 数据库 [1] 的子集进行训练的,该数据库用于 ImageNet Large-Scale Visual Recognition Challenge (ILSVRC) [2] 中。这些网络已经对超过一百万个图像进行了训练,可以将图像分为 1000 个对象类别,例如键盘、咖啡杯、铅笔和多种动物。通常来说,使用预训练网络进行迁移学习比从头开始训练网络更快更容易。

您可以将之前训练过的网络用于以下任务:

目的说明
分类

将预训练网络直接应用于分类问题。要对新图像进行分类,请使用 classify。有关如何使用预训练网络进行分类的示例,请参阅使用 GoogLeNet 对图像进行分类

特征提取

通过使用层激活作为特征,使用预训练网络作为特征提取器。您可以使用这些激活作为特征来训练另一个机器学习模型,例如支持向量机 (SVM)。有关详细信息,请参阅特征提取。有关示例,请参阅使用预训练网络提取图像特征

迁移学习

从基于大型数据集训练的网络中提取层,并基于新数据集进行微调。有关详细信息,请参阅迁移学习。有关简单的示例,请参阅迁移学习快速入门。要尝试更多预训练网络,请参阅训练深度学习网络以对新图像进行分类

比较预训练网络

当选择适用于您的问题的网络时,预训练网络具有不同的重要特征。最重要的特征是网络的准确度、速度和规模。选择网络时通常需要在这些特征之间进行权衡。使用下图比较 ImageNet 验证准确度和使用网络进行预测所需的时间。

提示

要开始迁移学习,请尝试选择一个更快的网络,例如 SqueezeNet 或 GoogLeNet。然后,您可以快速迭代并尝试不同设置,如数据预处理步骤和训练选项。一旦您感觉到哪些设置运行良好,请尝试更准确的网络,例如 Inception-v3 或 ResNet,看看这是否能改进您的结果。

Comparison of the accuracy and relative prediction time of the pretrained networks. As the accuracy of the pretrained networks increases, so does the relative prediction time.

注意

上图仅显示不同网络的相对速度。准确的预测和训练迭代时间取决于您使用的硬件和小批量大小。

理想的网络具有高准确度并且速度很快。该图显示的是使用现代 GPU (NVIDIA® Tesla® P100) 和大小为 128 的小批量时分类准确度对预测时间的结果。预测时间是相对于最快的网络来测量的。每个标记的面积与网络在磁盘上的大小成正比。

ImageNet 验证集上的分类准确度是衡量在 ImageNet 上训练的网络准确度的最常见方法。如果您的网络在 ImageNet 上准确,则当您使用迁移学习或特征提取将网络应用于其他自然图像数据集时,您的网络通常也是准确的。这种泛化之所以可行,是因为网络已学会从自然图像中提取强大的信息特征,这些特征可以泛化到其他类似的数据集。但是,在 ImageNet 上的高准确度并不能始终直接迁移到其他任务,因此最好尝试多个网络。

如果您要使用受限制的硬件执行预测或通过 Internet 分发网络,则还要考虑网络在磁盘上和内存中的大小。

网络准确度

可以使用多种方法来计算基于 ImageNet 验证集的分类准确度,不同数据源使用不同的方法。有时使用包含多个模型的集合,有时使用多次裁剪对每个图像进行多次计算。有时会引用 top-5 准确度,而不是标准 (top-1) 准确度。由于这些差异,通常无法直接比较不同数据源的准确度。Deep Learning Toolbox™ 中预训练网络的准确度是使用单一模型和单一中心图像裁剪的标准 (top-1) 准确度。

加载预训练网络

要加载 SqueezeNet 网络,请在命令行中键入 squeezenet

net = squeezenet;

对于其他网络,请使用 googlenet 等函数来获取链接,以便从附加功能资源管理器下载预训练网络。

下表列出了基于 ImageNet 训练的可用预训练网络以及这些网络的一些属性。网络深度定义为从输入层到输出层的路径中顺序卷积层或全连接层的最大数量。所有网络的输入均为 RGB 图像。

网络深度大小参数(单位为百万)图像输入大小
squeezenet18

5.2 MB

1.24

227×227

googlenet22

27 MB

7.0

224×224

inceptionv348

89 MB

23.9

299×299

densenet201201

77 MB

20.0

224×224

mobilenetv253

13 MB

3.5

224×224

resnet1818

44 MB

11.7

224×224

resnet5050

96 MB

25.6

224×224

resnet101101

167 MB

44.6

224×224

xception71

85 MB

22.9299×299
inceptionresnetv2164

209 MB

55.9

299×299

shufflenet505.4 MB1.4224×224
nasnetmobile*20 MB 5.3224×224
nasnetlarge*332 MB88.9331×331
darknet191978 MB20.8256×256
darknet5353155 MB41.6256×256
efficientnetb08220 MB5.3

224×224

alexnet8

227 MB

61.0

227×227

vgg1616

515 MB

138

224×224

vgg1919

535 MB

144

224×224

*NASNet-Mobile 和 NASNet-Large 网络不是由模块的线性序列构成的。

基于 Places365 训练的 GoogLeNet

标准 GoogLeNet 网络是基于 ImageNet 数据集进行训练的,但您也可以加载基于 Places365 数据集训练的网络 [3] [4]。基于 Places365 训练的网络将图像分为 365 个不同位置类别,例如田野、公园、跑道和大厅。要加载基于 Places365 数据集训练的预训练 GoogLeNet 网络,请使用 googlenet('Weights','places365')。在执行迁移学习以执行新任务时,最常见的方法是使用基于 ImageNet 预训练的网络。如果新任务类似于场景分类,则使用基于 Places365 训练的网络可以提供更高的准确度。

可视化预训练网络

您可以使用深度网络设计器加载和可视化预训练网络。

deepNetworkDesigner(squeezenet)

Deep Network Designer displaying a pretrained SqueezeNet network

要查看和编辑层属性,请选择一个层。有关层属性的信息,请点击层名称旁边的帮助图标。

Cross channel normalization layer selected in Deep Network Designer. The PROPERTIES pane shows the properties of the layer.

通过点击 New,在深度网络设计器中浏览其他预训练网络。

Deep Network Designer start page showing available pretrained networks

如果需要下载一个网络,请在所需的网络上暂停,然后点击安装以打开附加功能资源管理器。

特征提取

特征提取可以简单快捷地利用深度学习的强大功能,而无需投入时间和精力来训练完整网络。由于它只需遍历一次训练图像,因此如果您没有 GPU,特征提取会特别有用。您使用预训练网络提取学习到的图像特征,然后使用这些特征来训练分类器,例如使用 fitcsvm (Statistics and Machine Learning Toolbox) 的支持向量机。

当您的新数据集很小时,请尝试使用特征提取。由于您仅基于提取的特征来训练简单的分类器,因此训练速度很快。由于几乎没有数据可供学习,因此微调网络的更深层也不太可能提高准确度。

  • 如果您的数据与原始数据非常相似,则在网络的更深层提取的更具体的特征可能对新任务有用。

  • 如果您的数据与原始数据相差很大,则在网络的更深层提取的特征可能对您的任务用处不大。请尝试基于从较浅网络层提取的更一般特征来训练最终的分类器。如果新数据集很大,则您也可以尝试从头开始训练网络。

ResNet 通常是合适的特征提取器。有关如何使用预训练网络进行特征提取的示例,请参阅使用预训练网络提取图像特征

迁移学习

您可以通过基于新数据集对网络进行训练来微调网络中的更深层,并以该预训练网络为起点。通过迁移学习来微调网络通常比构建和训练新网络更快更容易。网络已学习到一系列丰富的图像特征,但当您微调网络时,它可以学习特定于您的新数据集的特征。如果您有超大型数据集,则迁移学习可能不会比从头开始训练更快。

提示

微调网络通常能达到最高的准确度。对于非常小的数据集(每个类不到 20 个图像),请尝试使用特征提取。

与简单的特征提取相比,微调网络会更慢,需要完成的工作更多,但由于网络可以学习提取不同的特征集,最终的网络通常更准确。只要新数据集不是特别小,微调通常比特征提取效果更好,因为微调时网络有数据可供学习新特征。有关如何执行迁移学习的示例,请参阅使用深度网络设计器进行迁移学习训练深度学习网络以对新图像进行分类

Transfer learning workflow

导入和导出网络

您可以从 TensorFlow®-Keras、Caffe 和 ONNX™(开放式神经网络交换)模型格式导入网络和网络架构。您还可以将经过训练的网络导出为 ONNX 模型格式。

从 Keras 导入

使用 importKerasNetwork 从 TensorFlow-Keras 导入预训练网络。您可以从同一个 HDF5 (.h5) 文件或单独的 HDF5 和 JSON (.json) 文件导入网络和权重。有关详细信息,请参阅 importKerasNetwork

使用 importKerasLayers 从 TensorFlow-Keras 导入网络架构。您可以导入网络架构,使用或不使用权重均可。您可以从同一个 HDF5 (.h5) 文件或单独的 HDF5 和 JSON (.json) 文件导入网络架构和权重。有关详细信息,请参阅 importKerasLayers

从 Caffe 导入

使用 importCaffeNetwork 函数从 Caffe 中导入预训练网络。Caffe Model Zoo 中提供许多可用的预训练网络 [5]。下载所需的 .prototxt.caffemodel 文件,并使用 importCaffeNetwork 将预训练网络导入 MATLAB®。有关详细信息,请参阅 importCaffeNetwork

您可以导入 Caffe 网络的网络架构。下载所需的 .prototxt 文件,并使用 importCaffeLayers 将网络层导入 MATLAB。有关详细信息,请参阅 importCaffeLayers

ONNX 中导入和导出

通过使用 ONNX 作为中间格式,您可以与支持 ONNX 模型导出或导入的其他深度学习框架进行互操作,这些框架包括 TensorFlow、PyTorch、Caffe2、Microsoft® Cognitive Toolkit (CNTK)、Core ML 和 Apache MXNet™。

使用 exportONNXNetwork 函数将经过训练的 Deep Learning Toolbox 网络导出为 ONNX 模型格式。然后,您可以将 ONNX 模型导入支持 ONXX 模型导入的其他深度学习框架中。

使用 importONNXNetwork 从 ONNX 导入预训练网络,并使用 importONNXLayers 导入带或不带权重的网络架构。

音频应用的预训练网络

通过将 Deep Learning Toolbox 与 Audio Toolbox™ 结合使用,将预训练网络用于音频和语音处理应用。

Audio Toolbox 提供预训练的 VGGish 和 YAMNet 网络。使用 vggish (Audio Toolbox)yamnet (Audio Toolbox) 函数直接与预训练网络交互。classifySound (Audio Toolbox) 函数为 YAMNet 执行所需的预处理和后处理,以便您可以定位声音并将其划分到 521 个类别中的一个。您可以使用 yamnetGraph (Audio Toolbox) 函数探查 YAMNet 本体。vggishFeatures (Audio Toolbox) 函数为 VGGish 执行必要的预处理和后处理,以便您可以提取特征嵌入,以输入到机器学习和深度学习系统。有关在音频应用中使用深度学习的详细信息,请参阅Introduction to Deep Learning for Audio Applications (Audio Toolbox)

使用 VGGish 和 YAMNet 执行迁移学习和特征提取。有关示例,请参阅Transfer Learning with Pretrained Audio Networks (Audio Toolbox)

GitHub 上的预训练模型

要找到最新预训练模型和示例,请参阅 MATLAB Deep Learning (GitHub)

例如:

参考

[1] ImageNet. http://www.image-net.org

[2] Russakovsky, O., Deng, J., Su, H., et al. “ImageNet Large Scale Visual Recognition Challenge.” International Journal of Computer Vision (IJCV). Vol 115, Issue 3, 2015, pp. 211–252

[3] Zhou, Bolei, Aditya Khosla, Agata Lapedriza, Antonio Torralba, and Aude Oliva. "Places: An image database for deep scene understanding." arXiv preprint arXiv:1610.02055 (2016).

[4] Places. http://places2.csail.mit.edu/

[5] Caffe Model Zoo. http://caffe.berkeleyvision.org/model_zoo.html

另请参阅

| | | | | | | | | | | | | | | | | | | | | | | | |

相关主题