Main Content

本页翻译不是最新的。点击此处可查看最新英文版本。

将分类网络转换为回归网络

此示例说明如何将经过训练的分类网络转换为回归网络。

预训练的图像分类网络已经对超过一百万个图像进行了训练,可以将图像分为 1000 个对象类别,例如键盘、咖啡杯、铅笔和多种动物。这些网络已基于大量图像学习了丰富的特征表示。网络以图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。

深度学习应用中常常用到迁移学习。您可以采用预训练的网络,基于它学习新任务。此示例说明如何加载预训练的分类网络,以及如何重新训练该网络以用于回归任务。

在此示例中,我们会加载一个预训练的用于分类的卷积神经网络架构,然后替换用于分类的层并重新训练网络,以预测手写数字的旋转角度。您还可以选择使用 imrotate (Image Processing Toolbox™),利用预测值来校正图像旋转。

加载预训练网络

从支持文件 digitsNet.mat 中加载预训练网络。此文件包含对手写数字进行分类的分类网络。

load digitsNet
layers = net.Layers
layers = 
  15x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

加载数据

数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。

使用 digitTrain4DArrayDatadigitTest4DArrayData 以四维数组的形式加载训练图像和验证图像。输出 YTrain YValidation 是以度为单位的旋转角度。训练数据集和验证数据集各包含 5000 个图像。

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

使用 imshow 显示 20 个随机训练图像。

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

替换最终层

网络的卷积层会提取最后一个可学习层和最终分类层用来对输入图像进行分类的图像特征。digitsNet 中的 'fc''classoutput' 这两个层包含有关如何将网络提取的特征合并成类概率、损失值和预测标签的信息。要重新训练一个预训练网络以用于回归任务,需要将这两个层替换为适用于该任务的新层。

将最终全连接层(softmax 层)和分类输出层替换为大小为 1(响应数)的全连接层和回归层。

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

冻结初始层

现在,网络已准备好可以基于新数据进行重新训练。您也可以选择将较浅网络层的学习率设置为零,来“冻结”这些层的权重。在训练过程中,trainNetwork 不会更新已冻结层的参数。由于不需要计算已冻结层的梯度,因此冻结多个初始层的权重可以显著加快网络训练速度。如果新数据集很小,冻结较浅的网络层还可以防止这些层对新数据集过拟合。

使用辅助函数 freezeWeights 将前 12 个层的学习率设置为零。

layers(1:12) = freezeWeights(layers(1:12));

训练网络

创建网络训练选项。将初始学习率设置为 0.001。通过指定验证数据,监控训练过程中的网络准确度。打开训练进度图,关闭命令行窗口输出。

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

使用 trainNetwork 创建网络。如果存在兼容的 GPU,此命令会使用 GPU。使用 GPU 需要 Parallel Computing Toolbox™ 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)。否则,trainNetwork 将使用 CPU。

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (08-Feb-2022 03:30:36) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 10 objects of type patch, text, line. Axes object 2 contains 10 objects of type patch, text, line.

测试网络

基于验证数据评估准确度来测试网络性能。

使用 predict 预测验证图像的旋转角度。

YPred = predict(net,XValidation);

通过计算以下值来评估模型性能:

  1. 在可接受误差界限内的预测值的百分比

  2. 预测旋转角度和实际旋转角度的均方根误差 (RMSE)

计算预测旋转角度和实际旋转角度之间的预测误差。

predictionError = YValidation - YPred;

计算在实际角度的可接受误差界限内的预测值的数量。将阈值设置为 10 度。计算此阈值范围内的预测值的百分比。

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7532

使用均方根误差 (RMSE) 来衡量预测旋转角度和实际旋转角度之间的差异。

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    9.0270

校正数字旋转

您可以使用 Image Processing Toolbox 中的函数来摆正数字并将它们显示在一起。使用 imrotate (Image Processing Toolbox) 根据预测的旋转角度旋转 49 个样本数字。

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

显示原始数字以及校正旋转后的数字。使用 montage (Image Processing Toolbox) 将数字一起显示在一个图像上。

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

Figure contains 2 axes objects. Axes object 1 with title Original contains an object of type image. Axes object 2 with title Corrected contains an object of type image.

另请参阅

|

相关主题