本页对应的英文页面已更新,但尚未翻译。 若要查看最新内容,请点击此处访问英文页面。

Deep Network Designer 快速入门

此示例说明如何微调预训练的 GoogLeNet 网络以对新的图像集合进行分类。此过程称为迁移学习,通常比训练新网络更快更容易,因为您可以使用较少数量的训练图像将已学习的特征应用于新任务。要以交互方式准备用于迁移学习的网络,请使用 Deep Network Designer。

加载预训练网络

加载预训练的 GoogLeNet 网络。如果需要下载网络,请使用下载链接。

net = googlenet;

将网络导入 Deep Network Designer

打开 Deep Network Designer。

deepNetworkDesigner

点击导入,然后从工作区选择网络。Deep Network Designer 将显示整个网络的缩小视图。浏览网络图。要使用鼠标放大,请使用 Ctrl + 滚轮。

编辑迁移学习网络

要对预训练网络进行重新训练以对新图像进行分类,请将最终层替换为适合新数据集的新层。您必须更改类的数量以匹配您的数据。

将新的 fullyConnectedLayer网络层库拖到画布上。将 OutputSize 编辑为新数据中的类数,此示例中为 5。

编辑学习率,以使新层中的学习速度快于迁移层的学习速度。将 WeightLearnRateFactorBiasLearnRateFactor 设置为 10。删除最后一个全连接层,改为连接新层。

替换输出层。滚动到网络层库的末尾,将一个新的 classificationLayer 拖到画布上。删除原来的 output 层,改为连接新层。

检查网络

要确保编辑后的网络已准备好训练,请点击分析,并确保 Deep Learning Network Analyzer 报告零错误。

导出训练网络

返回到 Deep Network Designer,然后点击导出。Deep Network Designer 将网络导出到名为 lgraph_1 的新变量,该变量包含已编辑的网络层。您现在可以将层变量提供给 trainNetwork 函数。您还可以生成 MATLAB® 代码,以在 MATLAB 工作区中重新创建网络架构并将其以 layerGraph 对象或 Layer 数组的形式返回。

加载数据并训练网络

解压缩新图像并加载这些图像作为图像数据存储。将数据分为 70% 用作训练数据,30% 用作验证数据。

unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

调整图像大小,以匹配预训练网络输入大小。

augimdsTrain = augmentedImageDatastore([224 224],imdsTrain);
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);

指定训练选项。

  • 指定小批量大小,即每次迭代中使用多少个图像。

  • 指定少量轮数。一轮训练是对整个训练数据集的一个完整训练周期。对于迁移学习,所需的训练轮数相对较少。每轮训练都会打乱数据。

  • InitialLearnRate 设置为较小的值以减慢迁移层中的学习速度。

  • 指定验证数据和较小的验证频率。

  • 打开训练图,以在训练时监控进度。

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',6, ...
    'Verbose',false, ...
    'Plots','training-progress');

要训练网络,请将从 App 导出的层 lgraph_1、训练图像和选项提供给 trainNetwork 函数。默认情况下,如果有 GPU 可用,trainNetwork 就会使用 GPU(需要 Parallel Computing Toolbox™)。否则,将使用 CPU。由于数据集很小,因此训练很快。

netTransfer = trainNetwork(augimdsTrain,lgraph_1,options);

测试经过训练的网络

使用经过微调的网络对验证图像进行分类,并计算分类准确度。

[YPred,probs] = classify(netTransfer,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 1

显示四个示例验证图像及预测的标签以及预测的概率。

idx = randperm(numel(augimdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

要了解详细信息并尝试其他预训练网络,请参阅 Deep Network Designer

另请参阅

相关主题