Main Content

使用 parfor 训练多个深度学习网络

此示例说明如何使用 parfor 循环对训练选项执行参数扫描。

深度学习训练通常需要几小时或几天,搜寻良好的训练选项可能很困难。借助并行计算,您可以加快搜寻良好模型的速度并实现自动化。如果您可以使用具有多个图形处理单元 (GPU) 的计算机,则可以使用本地 parpool 在数据集的本地副本上完成此示例。如果要使用更多资源,可以将深度学习训练扩展到云。此示例说明如何使用 parfor 循环在云集群中对训练选项 MiniBatchSize 执行参数扫描。您可以修改此脚本,以对其他任何训练选项执行参数扫描。此外,此示例还说明如何在计算期间使用 DataQueue 从工作进程获取反馈。您还可以将脚本作为批处理作业发送给集群,这样您可以继续工作或者关闭 MATLAB,稍后再获取结果。有关详细信息,请参阅将深度学习批处理作业发送到集群

要求

您需要配置集群并将数据上传到云,才能运行此示例。在 MATLAB 中,您可以直接通过 MATLAB 桌面在云中创建集群。在主页选项卡上,在 Parallel 菜单中,选择 Create and Manage Clusters。在 Cluster Profile Manager 中,点击 Create Cloud Cluster。您也可以使用 MathWorks Cloud Center 来创建和访问计算集群。有关详细信息,请参阅 Cloud Center 快速入门。对于本示例,请确保在 MATLAB 主页选项卡的 Parallel > Select a Default Cluster 中将您的集群设置为默认集群。然后,将您的数据上传到 Amazon S3 存储桶并直接从 MATLAB 中使用它。此示例使用已存储在 Amazon S3 中的 CIFAR-10 数据集的副本。有关说明,请参阅在 AWS 中使用深度学习数据

从云中加载数据集

使用 imageDatastore 从云中加载训练数据集和测试数据集。将训练数据集拆分为训练数据集和验证数据集两部分,并保留测试数据集以测试基于参数扫描得到的最佳网络。查看数据集的类名称。在本示例中,您使用存储在 Amazon S3 中的 CIFAR-10 数据集的副本。为确保工作进程能够访问云中的数据存储,请确保已正确设置 AWS 凭据的环境变量。请参阅在 AWS 中使用深度学习数据

imds = imageDatastore("s3://cifar10cloud/cifar10/train", ...
     IncludeSubfolders=true, ...
     LabelSource="foldernames");
 
imdsTest = imageDatastore("s3://cifar10cloud/cifar10/test", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
classNames = categories(imdsTrain.Labels)
classNames = 10×1 cell
    {'airplane'  }
    {'automobile'}
    {'bird'      }
    {'cat'       }
    {'deer'      }
    {'dog'       }
    {'frog'      }
    {'horse'     }
    {'ship'      }
    {'truck'     }

通过创建 augmentedImageDatastore 对象,用增强的图像数据训练网络。使用随机平移和水平翻转。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    DataAugmentation=imageAugmenter, ...
    OutputSizeMode="randcrop");

定义网络架构

为 CIFAR-10 数据集定义一个网络架构。为了简化代码,使用对输入进行卷积的卷积块。池化层对空间维度进行下采样。

imageSize = [32 32 3];
netDepth = 2; % netDepth controls the depth of a convolutional block
netWidth = 16; % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    ];

同时训练多个网络

指定要对其执行参数扫描的小批量大小。为得到的网络和准确度分配变量。

miniBatchSizes = [64 128 256 512];
numMiniBatchSizes = numel(miniBatchSizes);
trainedNetworks = cell(numMiniBatchSizes,1);
accuracies = zeros(numMiniBatchSizes,1);

parfor 循环内基于不同小批量大小对多个网络执行并行参数扫描训练。集群中的工作进程同时训练这些网络,并在训练完成后将训练过的网络和准确度发送回去。如果您要检查训练是否正常工作,请在训练选项中将 Verbose 设置为 true。请注意,各工作进程独立进行计算,因此命令行输出与迭代的顺序不同。

parfor idx = 1:numMiniBatchSizes
    
    miniBatchSize = miniBatchSizes(idx);
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    
    % Define the training options. Set the mini-batch size.
    options = trainingOptions("sgdm", ...
        MiniBatchSize=miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        Verbose=false, ... % Do not send command line output.
        InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
        Metrics="accuracy", ...
        L2Regularization=1e-10, ...
        MaxEpochs=30, ...
        Shuffle="every-epoch", ...
        ValidationData=imdsValidation, ...
        LearnRateSchedule="piecewise", ...
        LearnRateDropFactor=0.1, ...
        LearnRateDropPeriod=25);
    
    % Train the network in a worker in the cluster.
    net = trainnet(augmentedImdsTrain,layers,"crossentropy",options);
    
    % To obtain the accuracy of this network, use the trained network to
    % classify the validation images on the worker and compare the predicted labels to the
    % actual labels.
    scores = minibatchpredict(net,imdsValidation);
    Y = scores2label(scores,classNames);
    accuracies(idx) = sum(Y == imdsValidation.Labels)/numel(imdsValidation.Labels);
    
    % Send the trained network back to the client.
    trainedNetworks{idx} = net;
end
Starting parallel pool (parpool) using the 'MyClusterInTheCloud' profile ...
Connected to parallel pool with 4 workers (PreferredPoolNumWorkers).

parfor 完成后,trainedNetworks 将包含工作进程训练得出的网络。显示经过训练的网络及其准确度。

trainedNetworks
trainedNetworks=4×1 cell array
    {1×1 dlnetwork}
    {1×1 dlnetwork}
    {1×1 dlnetwork}
    {1×1 dlnetwork}

accuracies
accuracies = 4×1

    0.8404
    0.8378
    0.8374
    0.8346

选择准确度最好的网络。根据测试数据集测试网络性能。

[~, I] = max(accuracies);
bestNetwork = trainedNetworks{I(1)};

scores = minibatchpredict(bestNetwork,imdsTest);
Y = scores2label(scores,classNames);
accuracy = sum(Y == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.8374

在训练过程中发送反馈数据

准备并初始化显示每个工作进程中的训练进度的绘图。使用 animatedLine 可以方便地显示变化的数据。

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel('Iteration');
    ylabel('Training accuracy');
    lines(i) = animatedline;
end

使用 DataQueue 将工作进程中的训练进度数据发送给客户端,然后对数据绘图。使用 afterEach 在每次工作进程发送训练进度反馈时更新绘图。参数 opts 包含有关工作进程、训练迭代和训练准确度的信息。

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines, opts{:}));

在 parfor 循环内基于不同小批量大小对多个网络执行并行参数扫描训练。请注意,在训练选项中使用 OutputFcn 可在每次迭代时将训练进度发送到客户端。

parfor idx = 1:numMiniBatchSizes
    
    miniBatchSize = miniBatchSizes(idx);
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the miniBatchSize.
    
    % Define the training options. Set an output function to send data back
    % to the client each iteration.
    options = trainingOptions("sgdm", ...
        MiniBatchSize=miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        Verbose=false, ... % Do not send command line output.
        InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
        OutputFcn=@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        Metrics="accuracy", ...
        L2Regularization=1e-10, ...
        MaxEpochs=30, ...
        Shuffle="every-epoch", ...
        ValidationData=imdsValidation, ...
        LearnRateSchedule="piecewise", ...
        LearnRateDropFactor=0.1, ...
        LearnRateDropPeriod=25);
    
    % Train the network in a worker in the cluster. The workers send
    % training progress information during training to the client.
    net = trainnet(augmentedImdsTrain,layers,"crossentropy",options);
    
    % To obtain the accuracy of this network, use the trained network to
    % classify the validation images on the worker and compare the predicted labels to the
    % actual labels.
    scores = minibatchpredict(net,imdsValidation);
    Y = scores2label(scores,classNames);
    accuracies(idx) = sum(Y == imdsValidation.Labels)/numel(imdsValidation.Labels);
    
    % Send the trained network back to the client.
    trainedNetworks{idx} = net;
end

parfor 完成后,trainedNetworks 将包含工作进程训练得出的网络。显示经过训练的网络及其准确度。

trainedNetworks
trainedNetworks=4×1 cell array
    {1×1 dlnetwork}
    {1×1 dlnetwork}
    {1×1 dlnetwork}
    {1×1 dlnetwork}

accuracies
accuracies = 4×1

    0.8388
    0.8288
    0.8326
    0.8226

选择准确度最好的网络。根据测试数据集测试网络性能。

[~, I] = max(accuracies);
bestNetwork = trainedNetworks{I(1)};

scores = minibatchpredict(bestNetwork,imdsTest);
Y = scores2label(scores,classNames);
accuracy = sum(Y == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.8352

辅助函数

定义一个函数,以便在网络架构中创建卷积模块。

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end

定义一个函数,以通过 DataQueue 将训练进度发送到客户端。

function stop = sendTrainingProgress(D,idx,info)
if info.State == "iteration" && ~isempty(info.TrainingAccuracy)
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
stop = false;
end

定义一个更新函数,以在工作进程发送中间结果时更新绘图。

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end

另请参阅

| | | (Parallel Computing Toolbox) |

相关示例

详细信息