Main Content

将分类网络转换为回归网络

此示例说明如何将经过训练的分类网络转换为回归网络。

预训练的图像分类网络已经对超过一百万个图像进行了训练,可以将图像分为 1000 个对象类别,例如键盘、咖啡杯、铅笔和多种动物。这些网络已基于大量图像学习了丰富的特征表示。网络以图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。

深度学习应用中常常用到迁移学习。您可以采用预训练的网络,基于它学习新任务。此示例说明如何加载预训练的分类网络,以及如何重新训练该网络以用于回归任务。

在此示例中,我们会加载一个预训练的用于分类的卷积神经网络架构,然后替换用于分类的层并重新训练网络,以预测手写数字的旋转角度。

加载预训练网络

从支持文件 digitsClassificationConvolutionNet.mat 中加载预训练网络。此文件包含对手写数字进行分类的分类网络。

load digitsClassificationConvolutionNet
layers = net.Layers
layers = 
  13x1 Layer array with layers:

     1   'imageinput'    Image Input                  28x28x1 images
     2   'conv_1'        2-D Convolution              10 3x3x1 convolutions with stride [2  2] and padding [0  0  0  0]
     3   'batchnorm_1'   Batch Normalization          Batch normalization with 10 channels
     4   'relu_1'        ReLU                         ReLU
     5   'conv_2'        2-D Convolution              20 3x3x10 convolutions with stride [2  2] and padding [0  0  0  0]
     6   'batchnorm_2'   Batch Normalization          Batch normalization with 20 channels
     7   'relu_2'        ReLU                         ReLU
     8   'conv_3'        2-D Convolution              40 3x3x20 convolutions with stride [2  2] and padding [0  0  0  0]
     9   'batchnorm_3'   Batch Normalization          Batch normalization with 40 channels
    10   'relu_3'        ReLU                         ReLU
    11   'gap'           2-D Global Average Pooling   2-D global average pooling
    12   'fc'            Fully Connected              10 fully connected layer
    13   'softmax'       Softmax                      softmax

加载数据

数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。

从支持文件 DigitsDataTrain.matDigitsDataTest.mat 中将训练图像和测试图像作为四维数组加载。变量 anglesTrainanglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。

load DigitsDataTrain
load DigitsDataTest

使用 imshow 显示 20 个随机训练图像。

numTrainImages = numel(anglesTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

替换最终层

网络的卷积层会提取最后一个可学习层用来对输入图像进行分类的图像特征。层 'fc' 包含有关如何将网络提取到类概率中的特征进行组合的信息。要重新训练一个预训练网络以用于回归任务,请用适合此任务的新层替换此层和其后的 softmax 层。

将最终全连接层替换为大小为 1(响应数)的全连接层。

numResponses = 1;
layer = fullyConnectedLayer(numResponses,Name="fc");

net = replaceLayer(net,"fc",layer)
net = 
  dlnetwork with properties:

         Layers: [13x1 nnet.cnn.layer.Layer]
    Connections: [12x2 table]
     Learnables: [14x3 table]
          State: [6x3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 0

  View summary with summary.

删除 softmax 层。

net = removeLayers(net,"softmax");

调整层学习率因子

现在,网络已准备好可以基于新数据进行重新训练。(可选)在指定训练选项时,可以通过提高新全连接层的学习率和降低全局学习率来减慢网络中较浅层的权重的训练。

使用 setLearnRateFactor 函数按因子提高全连接层参数的学习率。

net = setLearnRateFactor(net,"fc","Weights",10);
net = setLearnRateFactor(net,"fc","Bias",10);

指定训练选项

指定训练选项。在选项中进行选择需要经验分析。要通过运行试验探索不同训练选项配置,您可以使用Experiment Manager

  • 指定降低后的学习率为 0.0001。

  • 在绘图中显示训练进度。

  • 禁用详尽输出。

options = trainingOptions("sgdm",...
    InitialLearnRate=0.001, ...
    Plots="training-progress",...
    Verbose=false);

训练神经网络

使用 trainnet 函数训练神经网络。对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)。否则,该函数使用 CPU。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

net = trainnet(XTrain,anglesTrain,net,"mse",options);

测试网络

基于测试数据评估准确度来测试网络性能。

使用 predict 预测验证图像的旋转角度。

YTest = predict(net,XTest);

在散点图中可视化预测。绘制预测值对真实值的图。

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"r--")

另请参阅

| |

相关主题