Main Content

创建简单的图像分类网络

此示例说明如何创建和训练简单的卷积神经网络来进行深度学习分类。卷积神经网络是深度学习的基本工具,尤其适用于图像识别。

该示例演示如何:

  • 加载图像数据。

  • 定义网络架构。

  • 指定训练选项。

  • 训练网络。

  • 预测新数据的标签并计算分类准确度。

有关如何以交互方式创建和训练简单图像分类网络的示例,请参阅图像分类快速入门

加载数据

解压缩数字样本数据并创建一个图像数据存储。imageDatastore 函数根据文件夹名称自动对图像加标签。

unzip("DigitsData.zip")
imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

将数据划分为训练数据集和验证数据集,以使训练集中的每个类别包含 750 个图像,并且验证集包含对应每个标签的其余图像。splitEachLabel 将图像数据存储拆分为两个新的数据存储以用于训练和验证。

numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,"randomized");

查看类名称。

classNames = categories(imdsTrain.Labels)
classNames = 10x1 cell
    {'0'}
    {'1'}
    {'2'}
    {'3'}
    {'4'}
    {'5'}
    {'6'}
    {'7'}
    {'8'}
    {'9'}

定义网络架构

定义卷积神经网络架构。指定网络输入层中图像的大小和全连接层中类的数量。每个图像为 28×28×1 像素,有 10 个类。

inputSize = [28 28 1];
numClasses = 10;

layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

有关深度学习层的详细信息,请参阅深度学习层列表

指定训练选项

指定训练选项。在选项中进行选择需要经验分析。要通过运行试验探索不同训练选项配置,您可以使用Experiment Manager

options = trainingOptions("sgdm", ...
    MaxEpochs=4, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=30, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

训练神经网络

使用 trainnet 函数训练神经网络。对于分类,使用交叉熵损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)。否则,该函数使用 CPU。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

net = trainnet(imdsTrain,layers,"crossentropy",options);

测试神经网络

要测试神经网络,请对验证数据进行分类,并计算分类准确度。

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。默认情况下,minibatchpredict 函数使用 GPU(如果有)。

scores = minibatchpredict(net,imdsValidation);
YValidation = scores2label(scores,classNames);

计算分类准确度。准确度是正确预测的标签的百分比。

TValidation = imdsValidation.Labels;
accuracy = mean(YValidation == TValidation)
accuracy = 0.9896

在深度学习的后续步骤中,您可以尝试将预训练网络用于其他任务。通过迁移学习或特征提取解决新的图像数据分类问题。有关示例,请参阅使用迁移学习更快地开始深度学习使用从预训练网络中提取的特征训练分类器。要了解有关预训练网络的详细信息,请参阅预训练的深度神经网络

另请参阅

| |

相关主题