Main Content

使用 parfeval 训练多个深度学习网络

此示例说明如何使用 parfeval 对深度学习网络的网络架构深度执行参数扫描,并在训练期间检索数据。

深度学习训练通常需要几小时或几天,搜寻良好的架构可能很困难。借助并行计算,您可以加快搜寻良好模型的速度并实现自动化。如果您可以使用具有多个图形处理单元 (GPU) 的计算机,则可以使用本地并行池在数据集的本地副本上完成此示例。如果要使用更多资源,可以将深度学习训练扩展到云。此示例说明如何使用 parfeval 在云集群中对网络架构的深度执行参数扫描。使用 parfeval 可以在后台进行训练而不会阻止 MATLAB,并提供可在结果令人满意时提前停止训练的选项。您可以修改脚本,以对其他任何参数执行参数扫描。此外,此示例还说明如何在计算期间使用 DataQueue 从工作进程获取反馈。

要求

您需要配置集群并将数据上传到云,才能运行此示例。在 MATLAB 中,您可以直接通过 MATLAB 桌面在云中创建集群。在主页选项卡上,在 Parallel 菜单中,选择 Create and Manage Clusters。在 Cluster Profile Manager 中,点击 Create Cloud Cluster。您也可以使用 MathWorks Cloud Center 来创建和访问计算集群。有关详细信息,请参阅 Cloud Center 快速入门。对于本示例,请确保在 MATLAB 主页选项卡的 Parallel > Select a Default Cluster 中将您的集群设置为默认集群。然后,将您的数据上传到 Amazon S3 存储桶并直接从 MATLAB 中使用它。此示例使用已存储在 Amazon S3 中的 CIFAR-10 数据集的副本。有关说明,请参阅在 AWS 中使用深度学习数据

从云中加载数据集

使用 imageDatastore 从云中加载训练数据集和测试数据集。将训练数据集拆分为训练数据集和验证数据集两部分,并保留测试数据集以测试基于参数扫描得到的最佳网络。在本示例中,您使用存储在 Amazon S3 中的 CIFAR-10 数据集的副本。为确保工作进程能够访问云中的数据存储,请确保已正确设置 AWS 凭据的环境变量。请参阅在 AWS 中使用深度学习数据

imds = imageDatastore("s3://cifar10cloud/cifar10/train", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

imdsTest = imageDatastore("s3://cifar10cloud/cifar10/test", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

通过创建 augmentedImageDatastore 对象,用增强的图像数据训练网络。使用随机平移和水平翻转。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augmentedImdsTrain=augmentedImageDatastore(imageSize,imdsTrain, ...
    DataAugmentation=imageAugmenter, ...
    OutputSizeMode="randcrop");

同时训练多个网络

指定训练选项。设置小批量大小并根据小批量大小线性缩放初始学习率。设置验证频率,使 trainnet 每轮训练都验证一次网络。

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ... % Set the mini-batch size
    Verbose=false, ... % Do not send command line output.
    Metrics="accuracy", ...
    InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
    L2Regularization=1e-10, ...
    MaxEpochs=30, ...
    Shuffle="every-epoch", ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=validationFrequency);

指定要对其执行参数扫描的网络架构的深度。使用 parfeval 同时对多个网络执行并行参数扫描训练。在扫描中使用循环来迭代不同的网络架构。在脚本末尾创建辅助函数 createNetworkArchitecture,它接受输入参量来控制网络的深度并为 CIFAR-10 创建一个架构。使用 parfeval 将由 trainnet 执行的计算量分散到集群中的工作进程。parfeval 将返回一个 future 变量,以便在计算完成后存储经过训练的网络和训练信息。

默认情况下,trainnet 函数使用 GPU(如果有)。在 GPU 上进行训练需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU 计算要求 (Parallel Computing Toolbox)。否则,trainnet 函数使用 CPU。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

netDepths = 1:4;
numExperiments = numel(netDepths);
for idx = 1:numExperiments
    networksFuture(idx) = parfeval(@trainnet,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),"crossentropy",options);
end
Starting parallel pool (parpool) using the 'MyCluster' profile ...
Connected to parallel pool with 4 workers (PreferredPoolNumWorkers).

parfeval 不会阻止 MATLAB,这意味着您可以继续执行命令。在本例中,通过对 networksFuture 使用 fetchOutputs 获取经过训练的网络及其训练信息。fetchOutputs 函数会等待直至 future 变量完成执行。

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

通过访问 trainingInfo 结构体获得网络的最终验证准确度。

for idx = 1:numExperiments
    validationHistory = trainingInfo(idx).ValidationHistory;
    accuracies(idx) = validationHistory.Accuracy(end);
end

accuracies
accuracies = 1×4

   70.7200   78.8200   76.1000   78.0200

选择准确度最好的网络。

[~, I] = max(accuracies);
bestNetwork = trainedNetworks(I(1));

根据测试数据集测试网络性能。要使用多个观测值进行预测,请使用 minibatchpredict 函数。要将预测分数转换为标签,请使用 scores2label 函数。minibatchpredict 函数自动使用 GPU(如果有)。

classNames = categories(imdsTest.Labels);
scores = minibatchpredict(bestNetwork,imdsTest);
Y = scores2label(scores,classNames);
accuracy = sum(Y == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7798

计算测试数据的混淆矩阵。

figure
confusionchart(imdsTest.Labels,Y,RowSummary="row-normalized",ColumnSummary="column-normalized");

在训练过程中发送反馈数据

准备并初始化显示每个工作进程中的训练进度的绘图。使用 animatedLine 可以方便地显示变化的数据。

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel("Iteration");
    ylabel("Training accuracy");
    lines(i) = animatedline;
end

使用 DataQueue 将工作进程中的训练进度数据发送给客户端,然后对数据绘图。使用 afterEach 在每次工作进程发送训练进度反馈时更新绘图。参数 opts 包含有关工作进程、训练迭代和训练准确度的信息。

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines,opts{:}));

指定要对其执行参数扫描的网络架构的深度,并使用 parfeval 执行并行参数扫描。通过将脚本作为附加文件添加到当前池中,使工作进程能访问此脚本中的任何辅助函数。在训练选项中定义一个输出函数,用于将工作进程中的训练进度发送到客户端。训练选项依赖于工作进程的索引,这些选项必须包含在 for 循环中。

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)
    
    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
    
    options = trainingOptions("sgdm", ...
        OutputFcn=@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        MiniBatchSize=miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        Verbose=false, ... % Do not send command line output.
        InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
        Metrics="accuracy", ...
        L2Regularization=1e-10, ...
        MaxEpochs=30, ...
        Shuffle="every-epoch", ...
        ValidationData=imdsValidation, ...
        ValidationFrequency=validationFrequency);
    
    networksFuture(idx) = parfeval(@trainnet,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),"crossentropy",options);
end

parfeval 对集群中的工作进程调用 trainnet。计算在后台进行,因此您可以继续在 MATLAB 中工作。如果您要停止某项 parfeval 计算,可以对其对应的 future 变量调用 cancel。例如,如果您观察到网络性能不佳,您可以取消其 future 变量的执行。如果您执行了此操作,则下一个排队的 future 变量将开始其计算过程。

在本例中,通过对 future 变量调用 fetchOutputs,获取经过训练的网络及其训练信息。

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

获取每个网络的最终验证准确度。

for idx = 1:numExperiments
    validationHistory = trainingInfo(idx).ValidationHistory;
    accuracies(idx) = validationHistory.Accuracy(end);
end

accuracies
accuracies = 1×4

   71.4600   78.3600   74.4000   79.3800

辅助函数

使用一个函数为 CIFAR-10 数据集定义网络架构,并使用输入参量调整网络深度。为了简化代码,使用对输入进行卷积的卷积块。池化层对空间维度进行下采样。

function layers = createNetworkArchitecture(netDepth)
imageSize = [32 32 3];
netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    ];
end

定义一个函数,以便在网络架构中创建卷积模块。

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end

定义一个函数,以通过 DataQueue 将训练进度发送到客户端。

function stop = sendTrainingProgress(D,idx,info)
if info.State == "iteration" && ~isempty(info.TrainingAccuracy)
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
stop = false;
end

定义一个更新函数,以在工作进程发送中间结果时更新绘图。

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end

另请参阅

(Parallel Computing Toolbox) | | | | |

相关主题