Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

使用预训练网络提取图像特征

此示例说明如何从预训练的卷积神经网络中提取已学习的图像特征,并使用这些特征来训练图像分类器。特征提取是使用预训练深度网络的表征能力的最简单最快捷的方式。例如,您可以使用 fitcecoc(Statistics and Machine Learning Toolbox™) 基于提取的特征来训练支持向量机 (SVM)。由于特征提取只需要遍历一次数据,因此如果没有 GPU 来加速网络训练,则不妨从特征提取开始。

加载数据

解压缩示例图像并加载这些图像作为图像数据存储。imageDatastore 根据文件夹名称自动标注图像,并将数据存储为 ImageDatastore 对象。通过图像数据存储可以存储大图像数据,包括无法放入内存的数据。将数据拆分,其中 70% 用作训练数据,30% 用作测试数据。

unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsTest] = splitEachLabel(imds,0.7,'randomized');

在这个非常小的数据集中,现在有 55 个训练图像和 20 个验证图像。显示一些示例图像。

numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end

Figure contains 16 axes. Axes 1 contains an object of type image. Axes 2 contains an object of type image. Axes 3 contains an object of type image. Axes 4 contains an object of type image. Axes 5 contains an object of type image. Axes 6 contains an object of type image. Axes 7 contains an object of type image. Axes 8 contains an object of type image. Axes 9 contains an object of type image. Axes 10 contains an object of type image. Axes 11 contains an object of type image. Axes 12 contains an object of type image. Axes 13 contains an object of type image. Axes 14 contains an object of type image. Axes 15 contains an object of type image. Axes 16 contains an object of type image.

加载预训练网络

加载预训练的 ResNet-18 网络。如果未安装 Deep Learning Toolbox Model for ResNet-18 Network 支持包,则软件会提供下载链接。ResNet-18 已基于超过一百万个图像进行训练,可以将图像分为 1000 个对象类别(例如键盘、鼠标、铅笔和多种动物)。因此,该模型已基于大量图像学习了丰富的特征表示。

net = resnet18
net = 
  DAGNetwork with properties:

         Layers: [71x1 nnet.cnn.layer.Layer]
    Connections: [78x2 table]
     InputNames: {'data'}
    OutputNames: {'ClassificationLayer_predictions'}

分析网络架构。第一层(图像输入层)需要大小为 224×224×3 的输入图像,其中 3 是颜色通道数。

inputSize = net.Layers(1).InputSize;
analyzeNetwork(net)

提取图像特征

网络要求输入图像的大小为 224×224×3,但图像数据存储中的图像具有不同大小。要在将训练图像和测试图像输入到网络之前自动调整它们的大小,请创建增强的图像数据存储,指定所需的图像大小,并将这些数据存储用作 activations 的输入参数。

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

网络构造输入图像的分层表示。更深层包含更高级别的特征,这些特征使用较浅层的较低级别特征构建。要获得训练图像和测试图像的特征表示,请对网络末尾的全局池化层 'pool5', 使用 activations。全局池化层汇集所有空间位置的输入特征,总共提供 512 个特征。

layer = 'pool5';
featuresTrain = activations(net,augimdsTrain,layer,'OutputAs','rows');
featuresTest = activations(net,augimdsTest,layer,'OutputAs','rows');

whos featuresTrain
  Name                Size              Bytes  Class     Attributes

  featuresTrain      55x512            112640  single              

从训练数据和测试数据中提取类标签。

YTrain = imdsTrain.Labels;
YTest = imdsTest.Labels;

拟合图像分类器

使用从训练图像中提取的特征作为预测变量,并使用 fitcecoc (Statistics and Machine Learning Toolbox) 拟合多类支持向量机 (SVM)。

classifier = fitcecoc(featuresTrain,YTrain);

对测试图像进行分类

使用经过训练的 SVM 模型和从测试图像中提取的特征对测试图像进行分类。

YPred = predict(classifier,featuresTest);

显示四个示例测试图像及预测的标签。

idx = [1 5 10 15];
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    I = readimage(imdsTest,idx(i));
    label = YPred(idx(i));
    imshow(I)
    title(char(label))
end

Figure contains 4 axes. Axes 1 with title MathWorks Cap contains an object of type image. Axes 2 with title MathWorks Cube contains an object of type image. Axes 3 with title MathWorks Playing Cards contains an object of type image. Axes 4 with title MathWorks Screwdriver contains an object of type image.

计算针对测试集的分类准确度。准确度是网络预测正确的标签的比例。

accuracy = mean(YPred == YTest)
accuracy = 1

基于较浅特征训练分类器

您还可以从网络的较浅层提取特征,并基于这些特征训练分类器。较浅的层通常具有较高的空间分辨率和较大的激活总数,提取的特征也较少、较浅。从 'res3b_relu' 层中提取特征。这是输出 128 个特征的最终层,激活的空间大小为 28×28。

layer = 'res3b_relu';
featuresTrain = activations(net,augimdsTrain,layer);
featuresTest = activations(net,augimdsTest,layer);
whos featuresTrain
  Name                Size                      Bytes  Class     Attributes

  featuresTrain      28x28x128x55            22077440  single              

此示例第一部分中使用的提取特征是从全局池化层的所有空间位置汇集而来的。要在从较浅层中提取特征时获得相同的结果,请手动对所有空间位置的激活区域求平均。要获得 N×C 形式的特征,其中 N 是观测值数目,C 是特征数量,请删除单一维度并转置。

featuresTrain = squeeze(mean(featuresTrain,[1 2]))';
featuresTest = squeeze(mean(featuresTest,[1 2]))';
whos featuresTrain
  Name                Size             Bytes  Class     Attributes

  featuresTrain      55x128            28160  single              

基于较浅特征训练 SVM 分类器。计算测试准确度。

classifier = fitcecoc(featuresTrain,YTrain);
YPred = predict(classifier,featuresTest);
accuracy = mean(YPred == YTest)
accuracy = 0.9500

两个经过训练的 SVM 都具有高准确度。如果使用特征提取时的准确度不够高,则尝试迁移学习。有关示例,请参阅训练深度学习网络以对新图像进行分类。有关预训练网络的列表和比较,请参阅预训练的深度神经网络

另请参阅

(Statistics and Machine Learning Toolbox) |

相关主题