本页对应的英文页面已更新,但尚未翻译。 若要查看最新内容,请点击此处访问英文页面。

监控深度学习训练进度

在训练深度学习网络时,监控训练进度通常很有用。通过在训练过程中绘制各种指标,您可以了解训练的进度情况。例如,您可以确定网络准确度是否改善以及改善速度,还可以确定网络是否开始过拟合训练数据。

trainingOptions 中将 'training-progress' 指定为 'Plots' 值并开始网络训练时,trainNetwork 会创建一个图窗并在每次迭代时显示训练指标。 每次迭代都是对梯度的一次估计和对网络参数的一次更新。如果在 trainingOptions 中指定验证数据,则每次 trainNetwork 验证网络时,该图窗都会显示验证指标。该图窗绘制以下内容:

  • 训练准确度 - 针对每个小批量的分类准确度。

  • 经过平滑处理的训练准确度 - 经过平滑处理的训练准确度,通过将平滑算法应用于训练准确度来获得。它的噪声低于未平滑的准确度,因此更易于揭示趋势。

  • 验证准确度 - 针对整个验证集的分类准确度(使用 trainingOptions 指定)。

  • 训练损失经过平滑处理的训练损失验证损失 - 分别指每个小批量的损失、其经过平滑处理的版本以及验证集的损失。如果网络的最终层是一个 classificationLayer,则损失函数是交叉熵损失。有关分类和回归问题的损失函数的详细信息,请参阅Output Layers

对于回归网络,该图窗绘制均方根误差 (RMSE) 而不是准确度。

图窗使用交替底色来标记每一轮训练。一轮训练是对整个训练数据集的一次完整遍历。

在训练过程中,您可以通过点击右上角的停止按钮停止训练并返回网络的当前状态。例如,您可能希望在网络准确度达到稳定水平并且准确度明显不再提高时停止训练。点击停止按钮后,可能需要一段时间才能完成训练。训练完成后,trainNetwork 将返回经过训练的网络。

训练结束后,查看结果,其中显示最终验证准确度和训练结束的原因。最终验证指标在绘图中标记为 Final。如果您的网络包含批量归一化层,则最终验证指标通常与训练过程中评估出的验证指标不同。这是因为最终网络中的批量归一化层执行的操作与训练过程中执行的操作不同。

在右侧,查看有关训练时间和设置的信息。要了解有关训练选项的详细信息,请参阅Set Up Parameters and Train Convolutional Neural Network

在训练过程中绘制训练进度

训练网络并在训练过程中绘制训练进度。

加载训练数据,其中包含 5000 个数字图像。留出 1000 个图像用于网络验证。

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

构建网络以对数字图像数据进行分类。

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

指定网络训练的选项。要在训练过程中按固定时间间隔验证网络,请指定验证数据。选择 'ValidationFrequency' 值,以使网络大致在每轮训练都被验证一次。要在训练过程中绘制训练进度,请将 'training-progress' 指定为 'Plots' 值。

options = trainingOptions('sgdm', ...
    'MaxEpochs',8, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(XTrain,YTrain,layers,options);

另请参阅

|

相关主题