Main Content

并行训练深度学习网络

此示例说明如何在本地计算机上运行多个深度学习试验。使用此示例作为模板,您可以修改网络层和训练选项,以满足您的具体应用需要。无论您有一个还是多个 GPU,都可以使用这种方法。如果您只有一个 GPU,网络会在后台逐个进行训练。本示例中的方法使您能够在进行深度学习试验时继续使用 MATLAB®。

您也可以使用 Experiment Manager 以交互方式并行训练多个深度网络。有关详细信息,请参阅Run Experiments in Parallel

准备数据集

在运行该示例之前,您必须能够访问深度学习数据集的本地副本。此示例使用的数据集包含从 0 到 9 的数字的合成图像。在以下代码中,将位置更改为指向您的数据集。

datasetLocation = fullfile(matlabroot,"toolbox","nnet", ...
    "nndemos","nndatasets","DigitDataset");

如果您要用更多资源运行试验,可以在云中的集群中运行此示例。

  • 将数据集上传到 Amazon S3 存储桶中。有关示例,请参阅在 AWS 中使用深度学习数据

  • 创建一个云集群。在 MATLAB 中,您可以直接通过 MATLAB 桌面在云中创建集群。有关详细信息,请参阅Create Cloud Cluster (Parallel Computing Toolbox)

  • 选择您的云集群作为默认集群,在主页选项卡的环境部分中,选择 Parallel > Select a Default Cluster

加载数据集

使用 imageDatastore 对象加载数据集。将数据集分成训练集、验证集和测试集。

imds = imageDatastore(datasetLocation, ...
 IncludeSubfolders=true, ...
 LabelSource="foldernames");

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.8,0.1);

要使用增强的图像数据训练网络,请创建 augmentedImageDatastore 对象。使用随机平移和水平翻转。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。

imageSize = [28 28 1];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    DataAugmentation=imageAugmenter);

并行训练网络

启动一个工作进程数量与 GPU 数量一样多的并行池。您可以使用 gpuDeviceCount (Parallel Computing Toolbox) 函数检查可用 GPU 的数量。MATLAB 为每个工作进程分配一个不同的 GPU。默认情况下,parpool 使用您的默认集群配置文件。如果您没有更改默认值,parpool 将打开基于进程的池。此示例是用一台具有 2 个 GPU 的计算机运行的。

numGPUs = gpuDeviceCount("available");
parpool(numGPUs);
Starting parallel pool (parpool) using the 'Processes' profile ...
Connected to parallel pool with 2 workers.

要在训练期间从工作进程发送训练进度信息,请使用 parallel.pool.DataQueue (Parallel Computing Toolbox) 对象。要了解有关如何在训练期间使用数据队列获取反馈的详细信息,请参阅示例使用 parfeval 训练多个深度学习网络

dataqueue = parallel.pool.DataQueue;

定义网络层和训练选项。为了提高代码可读性,您可以在一个单独的函数中定义它们,该函数返回多个网络架构和训练选项。在本例中,networkLayersAndOptions 返回一个网络层元胞数组和一个训练选项数组,二者长度相同。在 MATLAB 中打开此示例,然后点击 networkLayersAndOptions 以打开支持函数 networkLayersAndOptions。粘贴到您自己的网络层和选项中。该文件包含示例训练选项,说明如何使用输出函数向数据队列发送信息。

[layersCell,options] = networkLayersAndOptions(augmentedImdsTrain,imdsValidation,dataqueue);

准备训练进度图,并设置回调函数以便在每个工作进程向队列发送数据后更新这些图。preparePlotsupdatePlots 是此示例的支持函数。

numExperiments = numel(layersCell);
handles = preparePlots(numExperiments);

afterEach(dataqueue,@(data) updatePlots(handles,data));

要在并行工作进程中保存计算结果,请使用 future 对象。为每次训练的结果预分配一个 future 对象数组。

trainingFuture(1:numExperiments) = parallel.FevalFuture;

使用 for 循环遍历网络层和选项,并使用 parfeval (Parallel Computing Toolbox) 在并行工作进程上训练网络。要从 trainnet 请求两个输出参量,请指定 2 作为 parfeval 的第二个输入参量。

for i=1:numExperiments
    trainingFuture(i) = parfeval(@trainnet,2,augmentedImdsTrain,layersCell{i},"crossentropy",options(i));
end

parfeval 不会阻止 MATLAB,因此您可以在计算的同时继续工作。

要从 future 对象中获取结果,请使用 fetchOutputs 函数。对于本示例,获取经过训练的网络及其训练信息。fetchOutputs 会阻止 MATLAB,直到结果可用为止。此步骤可能需要几分钟。

[network,trainingInfo] = fetchOutputs(trainingFuture);

使用 save 函数将结果保存到磁盘。要稍后再次加载结果,请使用 load 函数。使用 sprintfdatetime,按照当前日期时间命名文件。

filename = sprintf("experiment-%s",datetime("now",Format="yyyyMMdd-HHmmss"));
save(filename,"network","trainingInfo");

绘制结果

在网络完成训练后,使用 trainingInfo 中的信息绘制其训练进度。对于此示例,创建一行图来显示绘制的训练准确度对迭代及验证准确性的图。

t = tiledlayout(2,numExperiments);
title(t,"Training Progress Plots")

for i=1:numExperiments
    nexttile
    hold on; grid on;
    ylim([0 100]);
    plot(trainingInfo(i).TrainingHistory.Iteration,trainingInfo(i).TrainingHistory.Accuracy);
    plot(trainingInfo(i).ValidationHistory.Iteration,trainingInfo(i).ValidationHistory.Accuracy,".k",MarkerSize=10);
    xlabel("Iteration")
    ylabel("Accuracy")
end

然后,创建另一张图来显示绘制的训练损失对迭代及验证损失的图。

for i=1:numExperiments
    nexttile
    hold on; grid on;
    ylim([0 10]);
    plot(trainingInfo(i).TrainingHistory.Iteration,trainingInfo(i).TrainingHistory.Loss);
    plot(trainingInfo(i).ValidationHistory.Iteration,trainingInfo(i).ValidationHistory.Loss,".k",MarkerSize=10);
    xlabel("Iteration")
    ylabel("Loss")
end

选择网络后,您可以使用它对测试数据 imdsTest 中的图像进行分类。要使用多个观测值进行预测,请使用 minibatchpredict 函数。要将预测分数转换为标签,请使用 scores2label 函数。

另请参阅

| | | (Parallel Computing Toolbox) | | | |

相关示例

详细信息