Customize Output During Deep Learning Network Training
This example shows how to define an output function that runs at every iteration during training of a deep learning neural network. If you specify output functions by using the 'OutputFcn' name-value pair argument of trainingOptions, then trainNetwork calls these functions once before the start of training, after each training iteration, and once after training has finished. Each time it calls the output functions, trainNetwork passes a structure containing information such as the current iteration number, the loss, and the accuracy. You can use output functions to display or plot progress information, or to stop training. To stop training early, make an output function return true. If any output function returns true, then training finishes and trainNetwork returns the latest network.
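As a minimal sketch of the mechanism described above (not part of this example), an output function that only prints progress and never requests early stopping might look like the following. The field names Iteration and TrainingLoss on the info structure are assumed here; the fields State, ValidationLoss, and ValidationAccuracy are used later in this example.

function stop = displayTrainingProgress(info)
% Sketch of a simple output function: print the iteration number and the
% current training loss, and never stop training early.
stop = false;
if info.State == "iteration"
    fprintf("Iteration %d: training loss = %.4f\n", ...
        info.Iteration,info.TrainingLoss);
end
end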
To stop training when the loss on the validation set stops decreasing, simply specify validation data and a validation patience using the 'ValidationData' and 'ValidationPatience' name-value pair arguments of trainingOptions, respectively. The validation patience is the number of times that the loss on the validation set can be larger than or equal to the previously smallest loss before network training stops. You can add further stopping criteria by using output functions. This example shows how to create an output function that stops training when the classification accuracy on the validation data stops improving. The output function is defined at the end of the script.
Load the training data, which contains 5000 images of digits. Set aside 1000 of the images for network validation.
[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];
Construct a network to classify the digit image data.
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
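Optionally (this step is not part of the original example), you can inspect the layer array before training. analyzeNetwork opens an interactive summary of the architecture, including layer activations and learnable parameters.

analyzeNetwork(layers)   % optional: interactive visualization of the architecture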
Specify options for network training. To validate the network at regular intervals during training, specify validation data. Choose the 'ValidationFrequency' value so that the network is validated once per epoch.
To stop training when the classification accuracy on the validation set stops improving, specify stopIfAccuracyNotImproving as an output function. The second input argument of stopIfAccuracyNotImproving is the number of times that the accuracy on the validation set can be smaller than or equal to the previously highest accuracy before network training stops. Specify a large value for the maximum number of epochs to train. Training should not reach the final epoch because it stops automatically.
miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));
Train the network. Training stops when the validation accuracy stops increasing.
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:04 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:10 |       71.88% |       74.90% |       0.8807 |       0.8130 |          0.0100 |
|       2 |          62 |       00:00:15 |       86.72% |       88.00% |       0.3899 |       0.4436 |          0.0100 |
|       3 |          93 |       00:00:22 |       94.53% |       94.00% |       0.2224 |       0.2553 |          0.0100 |
|       4 |         124 |       00:00:30 |       95.31% |       96.80% |       0.1482 |       0.1762 |          0.0100 |
|       5 |         155 |       00:00:35 |       98.44% |       97.60% |       0.1007 |       0.1314 |          0.0100 |
|       6 |         186 |       00:00:41 |       99.22% |       97.80% |       0.0784 |       0.1136 |          0.0100 |
|       7 |         217 |       00:00:48 |      100.00% |       98.10% |       0.0559 |       0.0945 |          0.0100 |
|       8 |         248 |       00:00:53 |      100.00% |       98.00% |       0.0441 |       0.0859 |          0.0100 |
|       9 |         279 |       00:01:01 |      100.00% |       98.00% |       0.0344 |       0.0786 |          0.0100 |
|      10 |         310 |       00:01:08 |      100.00% |       98.50% |       0.0274 |       0.0678 |          0.0100 |
|      11 |         341 |       00:01:14 |      100.00% |       98.50% |       0.0240 |       0.0621 |          0.0100 |
|      12 |         372 |       00:01:20 |      100.00% |       98.70% |       0.0213 |       0.0569 |          0.0100 |
|      13 |         403 |       00:01:25 |      100.00% |       98.80% |       0.0187 |       0.0534 |          0.0100 |
|      14 |         434 |       00:01:30 |      100.00% |       98.80% |       0.0164 |       0.0508 |          0.0100 |
|      15 |         465 |       00:01:37 |      100.00% |       98.90% |       0.0144 |       0.0487 |          0.0100 |
|      16 |         496 |       00:01:42 |      100.00% |       99.00% |       0.0126 |       0.0462 |          0.0100 |
|      17 |         527 |       00:01:49 |      100.00% |       98.90% |       0.0112 |       0.0440 |          0.0100 |
|      18 |         558 |       00:01:54 |      100.00% |       98.90% |       0.0101 |       0.0420 |          0.0100 |
|      19 |         589 |       00:01:59 |      100.00% |       99.10% |       0.0092 |       0.0405 |          0.0100 |
|      20 |         620 |       00:02:04 |      100.00% |       99.00% |       0.0086 |       0.0391 |          0.0100 |
|      21 |         651 |       00:02:09 |      100.00% |       99.00% |       0.0080 |       0.0380 |          0.0100 |
|      22 |         682 |       00:02:15 |      100.00% |       99.00% |       0.0076 |       0.0369 |          0.0100 |
|======================================================================================================================|
Training finished: Stopped by OutputFcn.
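As a quick follow-up check (not part of the original example output), you can evaluate the returned network on the held-out validation images using classify and a simple comparison of the predicted and true labels.

YPred = classify(net,XValidation);           % predict labels for the validation images
valAccuracy = mean(YPred == YValidation)     % fraction of correctly classified images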
Define Output Function
Define the output function stopIfAccuracyNotImproving(info,N), which stops network training if the best classification accuracy on the validation data does not improve for N network validations in a row. This criterion is similar to the built-in stopping criterion based on the validation loss, except that it applies to the classification accuracy instead of the loss.
function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;

elseif ~isempty(info.ValidationLoss)

    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end

    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end

end

end
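A note on the design: the function stores bestValAccuracy and valLag in persistent variables, which keep their values between calls, and resets them whenever info.State is "start" so that each call to trainNetwork begins from a clean state. If you ever want to reset them outside of training (for example, after interrupting a run), clearing the function does so; this is standard MATLAB behavior for persistent variables rather than something specific to this example.

clear stopIfAccuracyNotImproving   % clears the persistent variables bestValAccuracy and valLag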
See Also
trainNetwork | trainingOptions