将分类网络转换为回归网络
此示例说明如何将经过训练的分类网络转换为回归网络。
预训练的图像分类网络已经对超过一百万个图像进行了训练,可以将图像分为 1000 个对象类别,例如键盘、咖啡杯、铅笔和多种动物。这些网络已基于大量图像学习了丰富的特征表示。网络以图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。
深度学习应用中常常用到迁移学习。您可以采用预训练的网络,基于它学习新任务。此示例说明如何加载预训练的分类网络,以及如何重新训练该网络以用于回归任务。
在此示例中,我们会加载一个预训练的用于分类的卷积神经网络架构,然后替换用于分类的层并重新训练网络,以预测手写数字的旋转角度。您还可以选择使用 imrotate
(Image Processing Toolbox™),利用预测值来校正图像旋转。
加载预训练网络
从支持文件 digitsNet.mat
中加载预训练网络。此文件包含对手写数字进行分类的分类网络。
load digitsNet
layers = net.Layers
layers = 15x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' 2-D Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'maxpool_1' 2-D Max Pooling 2x2 max pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' 2-D Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'maxpool_2' 2-D Max Pooling 2x2 max pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' 2-D Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'fc' Fully Connected 10 fully connected layer 14 'softmax' Softmax softmax 15 'classoutput' Classification Output crossentropyex with '0' and 9 other classes
加载数据
数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
使用 digitTrain4DArrayData
和 digitTest4DArrayData
以四维数组的形式加载训练图像和验证图像。输出 YTrain
和 YValidation
是以度为单位的旋转角度。训练数据集和验证数据集各包含 5000 个图像。
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
使用 imshow
显示 20 个随机训练图像。
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
替换最终层
网络的卷积层会提取最后一个可学习层和最终分类层用来对输入图像进行分类的图像特征。digitsNet
中的 'fc'
和 'classoutput'
这两个层包含有关如何将网络提取的特征合并成类概率、损失值和预测标签的信息。要重新训练一个预训练网络以用于回归任务,需要将这两个层替换为适用于该任务的新层。
将最终全连接层(softmax 层)和分类输出层替换为大小为 1(响应数)的全连接层和回归层。
numResponses = 1; layers = [ layers(1:12) fullyConnectedLayer(numResponses) regressionLayer];
冻结初始层
现在,网络已准备好可以基于新数据进行重新训练。您也可以选择将较浅网络层的学习率设置为零,来“冻结”这些层的权重。在训练过程中,trainNetwork
不会更新已冻结层的参数。由于不需要计算已冻结层的梯度,因此冻结多个初始层的权重可以显著加快网络训练速度。如果新数据集很小,冻结较浅的网络层还可以防止这些层对新数据集过拟合。
使用辅助函数 freezeWeights
将前 12 个层的学习率设置为零。
layers(1:12) = freezeWeights(layers(1:12));
训练网络
创建网络训练选项。将初始学习率设置为 0.001。通过指定验证数据,监控训练过程中的网络准确度。打开训练进度图,关闭命令行窗口输出。
options = trainingOptions('sgdm',... 'InitialLearnRate',0.001, ... 'ValidationData',{XValidation,YValidation},... 'Plots','training-progress',... 'Verbose',false);
使用 trainNetwork
创建网络。如果存在兼容的 GPU,此命令会使用 GPU。使用 GPU 需要 Parallel Computing Toolbox™ 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)。否则,trainNetwork
将使用 CPU。
net = trainNetwork(XTrain,YTrain,layers,options);
测试网络
基于验证数据评估准确度来测试网络性能。
使用 predict
预测验证图像的旋转角度。
YPred = predict(net,XValidation);
通过计算以下值来评估模型性能:
在可接受误差界限内的预测值的百分比
预测旋转角度和实际旋转角度的均方根误差 (RMSE)
计算预测旋转角度和实际旋转角度之间的预测误差。
predictionError = YValidation - YPred;
计算在实际角度的可接受误差界限内的预测值的数量。将阈值设置为 10 度。计算此阈值范围内的预测值的百分比。
thr = 10; numCorrect = sum(abs(predictionError) < thr); numImagesValidation = numel(YValidation); accuracy = numCorrect/numImagesValidation
accuracy = 0.7532
使用均方根误差 (RMSE) 来衡量预测旋转角度和实际旋转角度之间的差异。
rmse = sqrt(mean(predictionError.^2))
rmse = single
9.0270
校正数字旋转
您可以使用 Image Processing Toolbox 中的函数来摆正数字并将它们显示在一起。使用 imrotate
(Image Processing Toolbox) 根据预测的旋转角度旋转 49 个样本数字。
idx = randperm(numImagesValidation,49); for i = 1:numel(idx) I = XValidation(:,:,:,idx(i)); Y = YPred(idx(i)); XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop'); end
显示原始数字以及校正旋转后的数字。使用 montage
(Image Processing Toolbox) 将数字一起显示在一个图像上。
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(XValidationCorrected) title('Corrected')
另请参阅
regressionLayer
| classificationLayer