Main Content

图像分类快速入门

此示例说明如何使用深度网络设计器创建简单的卷积神经网络来进行深度学习分类。卷积神经网络是深度学习的基本工具,尤其适用于图像识别。

加载图像数据

将数字样本数据加载为图像数据存储。要访问此数据,请以实时脚本形式打开此示例。imageDatastore 函数根据文件夹名称自动对图像加标签。数据集有 10 个类,数据集中每个图像的像素数为 28×28×1。

unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

classNames = categories(imds.Labels);

将数据划分为训练、验证和测试数据集。将 70% 的图像用于训练,15% 的图像用于验证,15% 的图像用于测试。指定 "randomized" 以将每个类中指定比例的文件分配给新数据集。splitEachLabel 函数将图像数据存储拆分为三个新数据存储。

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,"randomized");

定义网络架构

要构建网络,请使用深度网络设计器。

deepNetworkDesigner

要创建一个空白网络,请在空白网络上暂停,然后点击新建

设计器窗格中,定义卷积神经网络架构。从网络层库中拖动层并连接它们。要快速搜索层,请使用网络层库窗格中的过滤层搜索框。要编辑层的属性,请点击该层,然后在属性窗格中编辑值。

按顺序拖动这些层并按顺序连接它们。首先,将 imageInputLayer 拖到画布上,并将 InputSize 设置为 28,28,1

接下来,将这些层拖到画布上并按顺序连接它们:

  • convolution2dLayer

  • batchNormalizationLayer

  • reluLayer

然后,连接 fullyConnectedLayer,并将 OutputSize 设置为数据中的类数,在此示例中为 10。

最后,添加 softmaxLayer

要检查网络是否准备好进行训练,请点击分析。深度学习网络分析器报告零错误或警告,因此,网络已准备就绪,可以开始进行训练。要导出网络,请点击导出。该 App 将网络保存为变量 net_1

指定训练选项

指定训练选项。在选项中进行选择需要经验分析。要通过运行试验探索不同训练选项配置,您可以使用Experiment Manager

options = trainingOptions("sgdm", ...
    MaxEpochs=4, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=30, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

训练神经网络

使用 trainnet 函数训练神经网络。由于目的是分类,因此使用交叉熵损失。

net = trainnet(imdsTrain,net_1,"crossentropy",options);

测试神经网络

要测试神经网络,请对验证数据进行分类,并计算分类准确度。

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。默认情况下,minibatchpredict 函数使用 GPU(如果有)。

scores = minibatchpredict(net,imdsValidation);
YValidation = scores2label(scores,classNames);

计算分类准确度。准确度是正确预测的标签的百分比。

TValidation = imdsValidation.Labels;
accuracy = mean(YValidation == TValidation)
accuracy = 0.9780

可视化一些预测值。

numValidationObservations = numel(imdsValidation.Files);
idx = randi(numValidationObservations,9,1);

figure
tiledlayout("flow")
for i = 1:9
    nexttile
    img = readimage(imdsValidation,idx(i));
    imshow(img)
    title("Predicted Class: " + string(YValidation(idx(i))))
end

在深度学习的后续步骤中,您可以尝试使用预训练网络和迁移学习。有关示例,请参阅迁移学习快速入门

另请参阅

|

相关主题