
Customize Output During Deep Learning Network Training

This example shows how to define an output function that runs at every iteration during training of deep learning neural networks. If you specify output functions by using the 'OutputFcn' name-value pair argument of trainingOptions, then trainNetwork calls these functions once before the start of training, after each training iteration, and once after training has finished. Each time an output function is called, trainNetwork passes a structure containing information such as the current iteration number, loss, and accuracy. You can use output functions to display or plot progress information, or to stop training. To stop training early, make your output function return true. If any output function returns true, then training finishes and trainNetwork returns the latest network.
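This calling convention can be sketched with a minimal output function. The function name printIterationInfo is illustrative (not part of the toolbox); the info fields it reads (State, Epoch, Iteration, TrainingLoss) are among those that trainNetwork passes.

```matlab
function stop = printIterationInfo(info)
% Minimal sketch of an output function: print progress at each iteration
% and never request early stopping.
stop = false;

if info.State == "iteration"
    fprintf("Epoch %d, iteration %d, training loss %.4f\n", ...
        info.Epoch, info.Iteration, info.TrainingLoss);
end

end
```

You would pass it to training in the same way as the stopping function below, for example 'OutputFcn',@printIterationInfo.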

To stop training when the loss on the validation set stops decreasing, simply specify validation data and a validation patience using the 'ValidationData' and 'ValidationPatience' name-value pair arguments of trainingOptions, respectively. The validation patience is the number of times that the loss on the validation set can be larger than or equal to the previously smallest loss before network training stops. You can add additional stopping criteria using output functions. This example shows how to create an output function that stops training when the classification accuracy on the validation data stops improving. The output function is defined at the end of the script.
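For comparison, the built-in validation-loss stopping criterion needs no custom code. A minimal sketch (the patience value 5 is illustrative):

```matlab
% Stop automatically after 5 consecutive validations in which the
% validation loss is not smaller than the smallest loss seen so far.
options = trainingOptions('sgdm', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationPatience',5);
```

The output function defined in this example mirrors this behavior but tracks validation accuracy instead of validation loss.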

Load the training data, which contains 5000 images of digits. Set aside 1000 of the images for network validation.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Construct a network to classify the digit image data.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Specify options for network training. To validate the network at regular intervals during training, specify validation data. Choose the 'ValidationFrequency' value so that the network is validated once per epoch.

To stop training when the classification accuracy on the validation set stops improving, specify stopIfAccuracyNotImproving as an output function. The second input argument of stopIfAccuracyNotImproving is the number of times that the accuracy on the validation set can be smaller than or equal to the previously highest accuracy before network training stops. Choose any large value for the maximum number of epochs to train. Training should not reach the final epoch because training stops automatically.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

Train the network. Training stops when the validation accuracy stops increasing.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       10.94% |       13.20% |       2.9520 |       2.5400 |          0.0100 |
|       1 |          31 |       00:00:11 |       71.09% |       75.40% |       0.8453 |       0.8356 |          0.0100 |
|       2 |          62 |       00:00:25 |       91.41% |       89.20% |       0.3514 |       0.4304 |          0.0100 |
|       3 |          93 |       00:00:42 |       96.88% |       94.20% |       0.1887 |       0.2572 |          0.0100 |
|       4 |         124 |       00:01:01 |       99.22% |       96.20% |       0.1189 |       0.1927 |          0.0100 |
|       5 |         155 |       00:01:23 |      100.00% |       96.80% |       0.0880 |       0.1566 |          0.0100 |
|       6 |         186 |       00:01:41 |      100.00% |       97.10% |       0.0614 |       0.1226 |          0.0100 |
|       7 |         217 |       00:01:56 |       99.22% |       97.90% |       0.0566 |       0.1017 |          0.0100 |
|       8 |         248 |       00:02:08 |       99.22% |       98.20% |       0.0476 |       0.0863 |          0.0100 |
|       9 |         279 |       00:02:25 |      100.00% |       98.60% |       0.0334 |       0.0740 |          0.0100 |
|      10 |         310 |       00:02:37 |      100.00% |       98.80% |       0.0267 |       0.0645 |          0.0100 |
|      11 |         341 |       00:02:47 |      100.00% |       98.80% |       0.0226 |       0.0567 |          0.0100 |
|      12 |         372 |       00:02:58 |      100.00% |       99.20% |       0.0195 |       0.0503 |          0.0100 |
|      13 |         403 |       00:03:05 |      100.00% |       99.30% |       0.0171 |       0.0453 |          0.0100 |
|      14 |         434 |       00:03:13 |      100.00% |       99.40% |       0.0154 |       0.0417 |          0.0100 |
|      15 |         465 |       00:03:20 |      100.00% |       99.50% |       0.0142 |       0.0391 |          0.0100 |
|      16 |         496 |       00:03:27 |      100.00% |       99.50% |       0.0131 |       0.0371 |          0.0100 |
|      17 |         527 |       00:03:38 |      100.00% |       99.50% |       0.0122 |       0.0355 |          0.0100 |
|      18 |         558 |       00:03:45 |      100.00% |       99.50% |       0.0114 |       0.0343 |          0.0100 |
|======================================================================================================================|

Define Output Function

Define the output function stopIfAccuracyNotImproving(info,N), which stops network training if the best classification accuracy on the validation data does not improve for N network validations in a row. This criterion is similar to the built-in stopping criterion using the validation loss, except that it applies to the classification accuracy instead of the loss.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end

end
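Note that bestValAccuracy and valLag are persistent variables, so they retain their values between calls to the function. The "start" branch above resets them at the beginning of each training run; if you ever need to reset them manually between runs, you can clear the function itself:

```matlab
% Clearing a function also clears its persistent variables.
clear stopIfAccuracyNotImproving
```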

See Also


Related Topics