使用 parfeval
训练多个深度学习网络
此示例说明如何使用 parfeval
对深度学习网络的网络架构深度执行参数扫描,并在训练期间检索数据。
深度学习训练通常需要几小时或几天,搜寻良好的架构可能很困难。借助并行计算,您可以加快搜寻良好模型的速度并实现自动化。如果您可以使用具有多个图形处理单元 (GPU) 的计算机,则可以使用本地并行池在数据集的本地副本上完成此示例。如果要使用更多资源,可以将深度学习训练扩展到云。此示例说明如何使用 parfeval
在云集群中对网络架构的深度执行参数扫描。使用 parfeval
可以在后台进行训练而不会阻止 MATLAB,并提供可在结果令人满意时提前停止训练的选项。您可以修改脚本,以对其他任何参数执行参数扫描。此外,此示例还说明如何在计算期间使用 DataQueue
从工作进程获取反馈。
要求
您需要配置集群并将数据上传到云,才能运行此示例。在 MATLAB 中,您可以直接通过 MATLAB 桌面在云中创建集群。在主页选项卡上,在 Parallel 菜单中,选择 Create and Manage Clusters。在 Cluster Profile Manager 中,点击 Create Cloud Cluster。您也可以使用 MathWorks Cloud Center 来创建和访问计算集群。有关详细信息,请参阅云中心快速入门。对于本示例,请确保在 MATLAB 主页选项卡的 Parallel > Select a Default Cluster 中将您的集群设置为默认集群。然后,将您的数据上传到 Amazon S3 存储桶并直接从 MATLAB 中使用它。此示例使用已存储在 Amazon S3 中的 CIFAR-10 数据集的副本。有关说明,请参阅在 AWS 中使用深度学习数据。
从云中加载数据集
使用 imageDatastore
从云中加载训练数据集和测试数据集。将训练数据集拆分为训练数据集和验证数据集两部分,并保留测试数据集以测试基于参数扫描得到的最佳网络。在本示例中,您使用存储在 Amazon S3 中的 CIFAR-10 数据集的副本。为确保工作进程能够访问云中的数据存储,请确保已正确设置 AWS 凭据的环境变量。请参阅在 AWS 中使用深度学习数据。
imds = imageDatastore('s3://cifar10cloud/cifar10/train', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
通过创建 augmentedImageDatastore
对象,用增强的图像数据训练网络。使用随机平移和水平翻转。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ... 'DataAugmentation',imageAugmenter, ... 'OutputSizeMode','randcrop');
同时训练多个网络
定义训练选项。设置小批量大小并根据小批量大小线性缩放初始学习率。设置验证频率,使 trainNetwork
每轮训练都验证一次网络。
miniBatchSize = 128; initialLearnRate = 1e-1 * miniBatchSize/256; validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency', validationFrequency);
指定要对其执行参数扫描的网络架构的深度。使用 parfeval
同时对多个网络执行并行参数扫描训练。在扫描中使用循环来迭代不同的网络架构。在脚本末尾创建辅助函数 createNetworkArchitecture
,它接受输入参数来控制网络的深度并为 CIFAR-10 创建一个架构。使用 parfeval
将由 trainNetwork
执行的计算量分散到集群中的工作进程。parfeval
将返回一个 future 变量,以便在计算完成后存储经过训练的网络和训练信息。
netDepths = 1:4; for idx = 1:numel(netDepths) networksFuture(idx) = parfeval(@trainNetwork,2, ... augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options); end
Starting parallel pool (parpool) using the 'MyCluster' profile ... Connected to the parallel pool (number of workers: 4).
parfeval
不会阻止 MATLAB,这意味着您可以继续执行命令。在本例中,通过对 networksFuture
使用 fetchOutputs
获取经过训练的网络及其训练信息。fetchOutputs
函数会等待直至 future 变量完成执行。
[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);
通过访问 trainingInfo
结构体获得网络的最终验证准确度。
accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4
72.5600 77.2600 79.4000 78.6800
选择准确度最好的网络。根据测试数据集测试网络性能。
[~, I] = max(accuracies); bestNetwork = trainedNetworks(I(1)); YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7840
计算测试数据的混淆矩阵。
figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]); confusionchart(imdsTest.Labels,YPredicted,'RowSummary','row-normalized','ColumnSummary','column-normalized');
在训练过程中发送反馈数据
准备并初始化显示每个工作进程中的训练进度的绘图。使用 animatedLine
可以方便地显示变化的数据。
f = figure; f.Visible = true; for i=1:4 subplot(2,2,i) xlabel('Iteration'); ylabel('Training accuracy'); lines(i) = animatedline; end
使用 DataQueue
将工作进程中的训练进度数据发送给客户端,然后对数据绘图。使用 afterEach
在每次工作进程发送训练进度反馈时更新绘图。参数 opts
包含有关工作进程、训练迭代和训练准确度的信息。
D = parallel.pool.DataQueue; afterEach(D, @(opts) updatePlot(lines, opts{:}));
指定要对其执行参数扫描的网络架构的深度,并使用 parfeval
执行并行参数扫描。通过将脚本作为附加文件添加到当前池中,使工作进程能访问此脚本中的任何辅助函数。在训练选项中定义一个输出函数,用于将工作进程中的训练进度发送到客户端。训练选项依赖于工作进程的索引,这些选项必须包含在 for
循环中。
netDepths = 1:4; addAttachedFiles(gcp,mfilename); for idx = 1:numel(netDepths) miniBatchSize = 128; initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size. validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize); options = trainingOptions('sgdm', ... 'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client. 'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep. 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency', validationFrequency); networksFuture(idx) = parfeval(@trainNetwork,2, ... augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options); end
parfeval
对集群中的工作进程调用 trainNetwork
。计算在后台进行,因此您可以继续在 MATLAB 中工作。如果您要停止某项 parfeval
计算,可以对其对应的 future 变量调用 cancel
。例如,如果您观察到网络性能不佳,您可以取消其 future 变量的执行。如果您执行了此操作,则下一个排队的 future 变量将开始其计算过程。
在本例中,通过对 future 变量调用 fetchOutputs
,获取经过训练的网络及其训练信息。
[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);
获取每个网络的最终验证准确度。
accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4
72.9200 77.4800 76.9200 77.0400
辅助函数
使用一个函数为 CIFAR-10 数据集定义网络架构,并使用输入参数调整网络深度。为了简化代码,使用对输入进行卷积的卷积块。池化层对空间维度进行下采样。
function layers = createNetworkArchitecture(netDepth) imageSize = [32 32 3]; netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block layers = [ imageInputLayer(imageSize) convolutionalBlock(netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(2*netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(4*netWidth,netDepth) averagePooling2dLayer(8) fullyConnectedLayer(10) softmaxLayer classificationLayer ]; end
定义一个函数,以便在网络架构中创建卷积模块。
function layers = convolutionalBlock(numFilters,numConvLayers) layers = [ convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer ]; layers = repmat(layers,numConvLayers,1); end
定义一个函数,以通过 DataQueue
将训练进度发送到客户端。
function sendTrainingProgress(D,idx,info) if info.State == "iteration" send(D,{idx,info.Iteration,info.TrainingAccuracy}); end end
定义一个更新函数,以在工作进程发送中间结果时更新绘图。
function updatePlot(lines,idx,iter,acc) addpoints(lines(idx),iter,acc); drawnow limitrate nocallbacks end
另请参阅
parfeval
(Parallel Computing Toolbox) | afterEach
| trainNetwork
| trainingOptions
| imageDatastore