并行训练深度学习网络
此示例说明如何在本地计算机上运行多个深度学习试验。使用此示例作为模板,您可以修改网络层和训练选项,以满足您的具体应用需要。无论您有一个还是多个 GPU,都可以使用这种方法。如果您只有一个 GPU,网络会在后台逐个进行训练。本示例中的方法使您能够在进行深度学习试验时继续使用 MATLAB®。
您也可以使用 Experiment Manager 以交互方式并行训练多个深度网络。有关详细信息,请参阅Use Experiment Manager to Train Networks in Parallel。
准备数据集
在运行该示例之前,您必须能够访问深度学习数据集的本地副本。此示例使用的数据集包含从 0 到 9 的数字的合成图像。在以下代码中,将位置更改为指向您的数据集。
datasetLocation = fullfile(matlabroot,'toolbox','nnet', ... 'nndemos','nndatasets','DigitDataset');
如果您要用更多资源运行试验,可以在云中的集群中运行此示例。
将数据集上传到 Amazon S3 存储桶中。有关示例,请参阅在 AWS 中使用深度学习数据。
创建一个云集群。在 MATLAB 中,您可以直接通过 MATLAB 桌面在云中创建集群。有关详细信息,请参阅Create Cloud Cluster (Parallel Computing Toolbox)。
选择您的云集群作为默认集群,在主页选项卡的环境部分中,选择 Parallel > Select a Default Cluster。
加载数据集
使用 imageDatastore
对象加载数据集。将数据集分成训练集、验证集和测试集。
imds = imageDatastore(datasetLocation, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); [imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.8,0.1);
要使用增强的图像数据训练网络,请创建 augmentedImageDatastore
对象。使用随机平移和水平翻转。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。
imageSize = [28 28 1]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ... 'DataAugmentation',imageAugmenter);
并行训练网络
启动一个工作进程数量与 GPU 数量一样多的并行池。您可以使用 gpuDeviceCount
(Parallel Computing Toolbox) 函数检查可用 GPU 的数量。MATLAB 为每个工作进程分配一个不同的 GPU。默认情况下,parpool
使用您的默认集群配置文件。如果您没有更改默认值,则该配置文件为 local
。此示例是用一台具有 2 个 GPU 的计算机运行的。
numGPUs = gpuDeviceCount("available");
parpool(numGPUs);
Starting parallel pool (parpool) using the 'Processes' profile ... Connected to the parallel pool (number of workers: 2).
要在训练期间从工作进程发送训练进度信息,请使用 parallel.pool.DataQueue
(Parallel Computing Toolbox) 对象。要了解有关如何在训练期间使用数据队列获取反馈的详细信息,请参阅示例使用 parfeval 训练多个深度学习网络。
dataqueue = parallel.pool.DataQueue;
定义网络层和训练选项。为了提高代码可读性,您可以在一个单独的函数中定义它们,该函数返回多个网络架构和训练选项。在本例中,networkLayersAndOptions
返回一个网络层元胞数组和一个训练选项数组,二者长度相同。在 MATLAB 中打开此示例,然后点击 networkLayersAndOptions
以打开支持函数 networkLayersAndOptions
。粘贴到您自己的网络层和选项中。该文件包含示例训练选项,说明如何使用输出函数向数据队列发送信息。
[layersCell,options] = networkLayersAndOptions(augmentedImdsTrain,imdsValidation,dataqueue);
准备训练进度图,并设置回调函数以便在每个工作进程向队列发送数据后更新这些图。preparePlots
和 updatePlots
是此示例的支持函数。
handles = preparePlots(numel(layersCell));
afterEach(dataqueue,@(data) updatePlots(handles,data));
要在并行工作进程中保存计算结果,请使用 future 对象。为每次训练的结果预分配一个 future 对象数组。
trainingFuture(1:numel(layersCell)) = parallel.FevalFuture;
使用 for
循环遍历网络层和选项,并使用 parfeval
(Parallel Computing Toolbox) 在并行工作进程上训练网络。要从 trainNetwork
请求两个输出参数,请指定 2
作为 parfeval
的第二个输入参数。
for i=1:numel(layersCell) trainingFuture(i) = parfeval(@trainNetwork,2,augmentedImdsTrain,layersCell{i},options(i)); end
parfeval
不会阻止 MATLAB,因此您可以在计算的同时继续工作。
要从 future 对象中获取结果,请使用 fetchOutputs
函数。对于本示例,获取经过训练的网络及其训练信息。fetchOutputs
会阻止 MATLAB,直到结果可用为止。此步骤可能需要几分钟。
[network,trainingInfo] = fetchOutputs(trainingFuture);
使用 save
函数将结果保存到磁盘。要稍后再次加载结果,请使用 load
函数。使用 sprintf
和 datetime
,按照当前日期时间命名文件。
filename = sprintf('experiment-%s',datetime('now','Format','yyyyMMdd''T''HHmmss')); save(filename,'network','trainingInfo');
绘制结果
在网络完成训练后,使用 trainingInfo
中的信息绘制其训练进度。
使用子图为每个网络分发不同绘图。对于本示例,使用第一行子图绘制训练准确度对轮次编号及验证准确度的图。
figure('Units','normalized','Position',[0.1 0.1 0.6 0.6]); title('Training Progress Plots'); for i=1:numel(layersCell) subplot(2,numel(layersCell),i); hold on; grid on; ylim([0 100]); iterationsPerEpoch = floor(augmentedImdsTrain.NumObservations/options(i).MiniBatchSize); epoch = (1:numel(trainingInfo(i).TrainingAccuracy))/iterationsPerEpoch; plot(epoch,trainingInfo(i).TrainingAccuracy); plot(epoch,trainingInfo(i).ValidationAccuracy,'.k','MarkerSize',10); end subplot(2,numel(layersCell),1), ylabel('Accuracy');
然后,使用第二行子图绘制训练损失对轮次编号及验证损失的图。
for i=1:numel(layersCell) subplot(2,numel(layersCell),numel(layersCell) + i); hold on; grid on; ylim([0 10]); iterationsPerEpoch = floor(augmentedImdsTrain.NumObservations/options(i).MiniBatchSize); epoch = (1:numel(trainingInfo(i).TrainingAccuracy))/iterationsPerEpoch; plot(epoch,trainingInfo(i).TrainingLoss); plot(epoch,trainingInfo(i).ValidationLoss,'.k','MarkerSize',10); xlabel('Epoch'); end subplot(2,numel(layersCell),numel(layersCell)+1), ylabel('Loss');
在选择网络后,您可以使用 classify
并获得其基于测试数据 imdsTest
的准确度。
另请参阅
试验管理器 | augmentedImageDatastore
| imageDatastore
| parfeval
(Parallel Computing Toolbox) | fetchOutputs
| trainNetwork
| trainingOptions