
Use parfeval to Train Multiple Deep Learning Networks

This example shows how to use parfeval to perform a parameter sweep on the depth of the network architecture for a deep learning network and retrieve data during training.

Deep learning training often takes hours or days, and searching for good architectures can be difficult. With parallel computing, you can speed up and automate your search for good models. If you have access to a machine with multiple graphics processing units (GPUs), you can complete this example on a local copy of the data set with a local parallel pool. If you want to use more resources, you can scale up deep learning training to the cloud. This example shows how to use parfeval to perform a parameter sweep on the depth of a network architecture in a cluster in the cloud. Using parfeval allows you to train in the background without blocking MATLAB, and lets you stop early if the results are satisfactory. You can modify the script to do a parameter sweep on any other parameter. This example also shows how to obtain feedback from the workers during computation by using DataQueue.

Requirements

Before you can run this example, you need to configure a cluster and upload your data to the cloud. In MATLAB, you can create clusters in the cloud directly from the MATLAB Desktop. On the Home tab, in the Parallel menu, select Create and Manage Clusters. In the Cluster Profile Manager, click Create Cloud Cluster. Alternatively, you can use MathWorks Cloud Center to create and access compute clusters. For more information, see Getting Started with Cloud Center. For this example, ensure that your cluster is set as default on the MATLAB Home tab, in Parallel > Select a Default Cluster. After that, upload your data to an Amazon S3 bucket and use it directly from MATLAB. This example uses a copy of the CIFAR-10 data set that is already stored in Amazon S3. For instructions, see Work with Deep Learning Data in AWS.

Load Data Set from the Cloud

Load the training and test data sets from the cloud using imageDatastore. Split the training data set into training and validation sets, and keep the test data set to test the best network from the parameter sweep. In this example, you use a copy of the CIFAR-10 data set stored in Amazon S3. To ensure that the workers have access to the datastore in the cloud, make sure that the environment variables for the AWS credentials are set correctly. See Work with Deep Learning Data in AWS.
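As a minimal sketch, you can set the credentials for the current MATLAB session with setenv. The variable names below are the standard AWS environment variables that MATLAB reads for S3 access; the values, including the region us-east-1, are placeholders that you must replace with your own. Avoid saving real credentials in a script that you share, and see Work with Deep Learning Data in AWS for how to make them available to the workers.

```matlab
% Placeholder values: replace them with your own credentials and region.
setenv("AWS_ACCESS_KEY_ID", "YOUR_AWS_ACCESS_KEY_ID");
setenv("AWS_SECRET_ACCESS_KEY", "YOUR_AWS_SECRET_ACCESS_KEY");
setenv("AWS_DEFAULT_REGION", "us-east-1");   % hypothetical region
```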

imds = imageDatastore("s3://cifar10cloud/cifar10/train", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

imdsTest = imageDatastore("s3://cifar10cloud/cifar10/test", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

Train the network with augmented image data by creating an augmentedImageDatastore object. Use random translations and horizontal reflections. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augmentedImdsTrain=augmentedImageDatastore(imageSize,imdsTrain, ...
    DataAugmentation=imageAugmenter, ...
    OutputSizeMode="randcrop");
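To see roughly what happens to one training image, the following sketch applies a random horizontal reflection and a random translation of up to 4 pixels in each direction to a stand-in array. It is a simplified, hand-written illustration, not the imageDataAugmenter implementation: it assumes that pixels uncovered by the shift are filled with zeros and that a positive x shift moves the content to the right.

```matlab
% Conceptual sketch only; this is NOT how imageDataAugmenter is implemented.
img = reshape(1:32*32*3, [32 32 3]);  % stand-in 32-by-32 RGB image; every value is nonzero
pixelRange = [-4 4];

if rand < 0.5
    img = flip(img, 2);               % horizontal reflection: reverse the column order
end

tx = randi(pixelRange);               % random horizontal shift, in pixels
ty = randi(pixelRange);               % random vertical shift, in pixels

shifted = zeros(size(img));           % pixels uncovered by the shift stay 0
rows = max(1, 1+ty):min(32, 32+ty);   % destination rows that receive image data
cols = max(1, 1+tx):min(32, 32+tx);   % destination columns that receive image data
shifted(rows, cols, :) = img(rows - ty, cols - tx, :);
```

The output keeps the 32-by-32-by-3 size, which matches imageSize, so the network always sees inputs of the same size.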

Train Several Networks Simultaneously

Specify the training options. Set the mini-batch size and scale the initial learning rate linearly according to the mini-batch size. Set the validation frequency so that trainnet validates the network once per epoch.

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ... % Set the mini-batch size
    Verbose=false, ... % Do not send command line output.
    Metrics="accuracy", ...
    InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
    L2Regularization=1e-10, ...
    MaxEpochs=30, ...
    Shuffle="every-epoch", ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=validationFrequency);
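For the data in this example, these settings work out as follows. The sketch assumes that the S3 copy holds the standard 50,000 CIFAR-10 training images (5,000 per class), so splitEachLabel(imds,0.9) keeps 45,000 images for training.

```matlab
numTrainImages = 45000;   % assumption: 90% of the 50,000 CIFAR-10 training images
miniBatchSize = 128;

% Linear scaling: a rate of 0.1 for a mini-batch of 256 becomes 0.05 for 128.
initialLearnRate = 1e-1 * miniBatchSize/256;

% Number of full mini-batches in the training set: floor(351.5625) = 351.
validationFrequency = floor(numTrainImages/miniBatchSize);
```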

Specify the depths for the network architecture on which to do a parameter sweep. Perform a parallel parameter sweep, training several networks simultaneously by using parfeval. Use a loop to iterate through the different network architectures in the sweep. The helper function createNetworkArchitecture, defined at the end of the script, takes an input argument that controls the depth of the network and creates an architecture for CIFAR-10. Use parfeval to offload the computations performed by trainnet to a worker in the cluster. The second input argument of parfeval, 2, is the number of outputs to request from trainnet: the trained network and the training information. parfeval returns a future variable that holds these outputs when the computations are done.

By default, the trainnet function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

netDepths = 1:4;
numExperiments = numel(netDepths);
for idx = 1:numExperiments
    networksFuture(idx) = parfeval(@trainnet,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),"crossentropy",options);
end
Starting parallel pool (parpool) using the 'MyCluster' profile ...
Connected to parallel pool with 4 workers (PreferredPoolNumWorkers).

parfeval does not block MATLAB, which means you can continue executing commands. In this case, obtain the trained networks and their training information by using fetchOutputs on networksFuture. The fetchOutputs function waits until the future variables finish.

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

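Obtain the final validation accuracies of the networks from the trainingInfo output. The Accuracy values in ValidationHistory are percentages, whereas the test accuracy computed later is a fraction between 0 and 1.

accuracies = zeros(1,numExperiments);
for idx = 1:numExperiments
    validationHistory = trainingInfo(idx).ValidationHistory;
    accuracies(idx) = validationHistory.Accuracy(end);
end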

accuracies
accuracies = 1×4

   70.7200   78.8200   76.1000   78.0200

Select the best network in terms of accuracy.

[~, I] = max(accuracies);
bestNetwork = trainedNetworks(I(1));
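Because max returns a single index, I(1) is simply that index. The sketch below reuses the accuracies printed above. It also shows that when several networks tie for the highest accuracy, max returns the first of them, which here would be the shallowest tied network because netDepths is in increasing order.

```matlab
% Final validation accuracies (percent) copied from the output above.
accuracies = [70.72 78.82 76.10 78.02];
[bestAccuracy, I] = max(accuracies);   % I is a scalar index

% With a tie, max returns the index of the first maximum.
[~, firstOfTies] = max([75 80 80]);
```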

Test its performance against the test data set. To make predictions with multiple observations, use the minibatchpredict function. To convert the prediction scores to labels, use the scores2label function. The minibatchpredict function automatically uses a GPU if one is available.

classNames = categories(imdsTest.Labels);
scores = minibatchpredict(bestNetwork,imdsTest);
Y = scores2label(scores,classNames);
accuracy = sum(Y == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7798

Plot the confusion matrix for the test data.

figure
confusionchart(imdsTest.Labels,Y,RowSummary="row-normalized",ColumnSummary="column-normalized");
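confusionchart computes the counts and summaries for you. To make the row and column summaries concrete, the sketch below computes them by hand for a small, made-up set of class indices (hypothetical data, not results from this example). The diagonal of the row-normalized matrix is the per-class recall, the diagonal of the column-normalized matrix is the per-class precision, and the trace of the count matrix divided by the number of observations is the overall accuracy.

```matlab
% Hypothetical true and predicted class indices for 8 observations and 3 classes.
trueIdx = [1 1 1 2 2 2 3 3];
predIdx = [1 1 2 2 2 3 3 1];
numClasses = 3;

% C(i,j) counts observations of true class i that were predicted as class j.
C = accumarray([trueIdx(:) predIdx(:)], 1, [numClasses numClasses]);

% Row-normalized diagonal: fraction of each true class classified correctly (recall).
recall = diag(C) ./ sum(C, 2);

% Column-normalized diagonal: fraction of each predicted class that is correct (precision).
precision = diag(C)' ./ sum(C, 1);

% Overall accuracy: correctly classified observations divided by all observations.
accuracy = trace(C) / sum(C(:));
```

The last line gives the same value as comparing the label vectors directly, as in the test accuracy calculation above.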

Send Feedback Data During Training

Prepare and initialize plots that show the training progress in each of the workers. Use animatedline for a convenient way to show changing data.

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel("Iteration");
    ylabel("Training accuracy");
    lines(i) = animatedline;
end

Send the training progress data from the workers to the client by using DataQueue, and then plot the data. Update the plots each time the workers send training progress feedback by using afterEach. Each message, opts, is a cell array that contains the index of the network in the sweep, the training iteration, and the training accuracy. The syntax opts{:} expands the cell array into the separate input arguments of updatePlot.

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines,opts{:}));
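The opts{:} syntax turns a cell array into a comma-separated list of arguments. The sketch below uses a hypothetical stand-in function, describeProgress, that formats its inputs instead of plotting, so that it runs without a figure or a parallel pool.

```matlab
% Hypothetical stand-in for updatePlot that formats its inputs instead of plotting.
describeProgress = @(idx, iter, acc) sprintf('network %d, iteration %d, accuracy %.1f', idx, iter, acc);

opts = {2, 150, 64.5};            % same layout as the cell array sent by the workers
msg = describeProgress(opts{:});  % equivalent to describeProgress(2, 150, 64.5)
```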

Specify the depths for the network architecture on which to do a parameter sweep, and perform the parallel parameter sweep using parfeval. To allow the workers to access any helper function in this script, add the script to the current pool as an attached file. Define an output function in the training options to send the training progress from the workers to the client. Because the output function uses the loop index idx to identify the network, create the training options inside the for loop.

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)
    
    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
    
    options = trainingOptions("sgdm", ...
        OutputFcn=@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        MiniBatchSize=miniBatchSize, ... % Set the mini-batch size.
        Verbose=false, ... % Do not send command line output.
        InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
        Metrics="accuracy", ...
        L2Regularization=1e-10, ...
        MaxEpochs=30, ...
        Shuffle="every-epoch", ...
        ValidationData=imdsValidation, ...
        ValidationFrequency=validationFrequency);
    
    networksFuture(idx) = parfeval(@trainnet,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),"crossentropy",options);
end

parfeval invokes trainnet on a worker in the cluster. Computations happen in the background, so you can continue working in MATLAB. If you want to stop a parfeval computation, you can call cancel on its corresponding future variable. For example, if you observe that a network is underperforming, you can cancel its future. When you do so, the next queued future variable starts its computations.

In this case, fetch the trained networks and their training information by invoking fetchOutputs on the future variables.

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

Obtain the final validation accuracy for each network.

accuracies = zeros(1,numExperiments);
for idx = 1:numExperiments
    validationHistory = trainingInfo(idx).ValidationHistory;
    accuracies(idx) = validationHistory.Accuracy(end);
end

accuracies
accuracies = 1×4

   71.4600   78.3600   74.4000   79.3800

Helper Functions

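Define a network architecture for the CIFAR-10 data set with a function, and use an input argument to adjust the depth of the network. To simplify the code, build the network from convolutional blocks, each of which stacks netDepth copies of a convolution, batch normalization, and ReLU layer. The convolutional layers use "same" padding, so they preserve the spatial size of their input, while the pooling layers downsample the spatial dimensions.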
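As a check on the layer sizes in createNetworkArchitecture, the sketch below applies the output-size formula for an unpadded pooling layer, floor((in - poolSize)/stride) + 1, to the 32-by-32 CIFAR-10 images. It assumes that averagePooling2dLayer(8) uses its default stride of 1.

```matlab
% Output size of an unpadded pooling layer along one spatial dimension.
poolOut = @(in, poolSize, stride) floor((in - poolSize)/stride) + 1;

sz = zeros(1, 4);
sz(1) = 32;                     % imageInputLayer; "same"-padded convolutions keep this size
sz(2) = poolOut(sz(1), 2, 2);   % maxPooling2dLayer(2,Stride=2): 32 -> 16
sz(3) = poolOut(sz(2), 2, 2);   % maxPooling2dLayer(2,Stride=2): 16 -> 8
sz(4) = poolOut(sz(3), 8, 1);   % averagePooling2dLayer(8): 8 -> 1
```

The final 1-by-1 spatial output means the fully connected layer receives one value per channel of the last convolutional block.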

function layers = createNetworkArchitecture(netDepth)
imageSize = [32 32 3];
netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    ];
end
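The netWidth formula shrinks the number of filters as the depth grows. The sketch below evaluates it for the depths in the sweep and computes width^2 * depth, a rough proxy for the number of 3-by-3 convolution weights in one block. The proxy stays roughly constant, which suggests that the formula keeps the per-block cost comparable across depths. The proxy ignores the first convolution in each block, whose input channel count differs, and the batch normalization parameters.

```matlab
netDepths = 1:4;
netWidth = round(16 ./ sqrt(netDepths));   % same formula as createNetworkArchitecture

% Rough proxy for the 3-by-3 convolution weights per block: width^2 * depth.
relativeCost = netWidth.^2 .* netDepths;
```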

Define a function to create a convolutional block in the network architecture.

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end
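repmat stacks copies of the three-layer pattern vertically, so a block with numConvLayers equal to 2 contains six layers in the order convolution, batch normalization, ReLU, convolution, batch normalization, ReLU. The sketch below illustrates the ordering with a hypothetical cell array of layer names in place of layer objects.

```matlab
% Hypothetical stand-in: layer names as text instead of layer objects.
blockPattern = {'convolution'; 'batchnorm'; 'relu'};
stacked = repmat(blockPattern, 2, 1);   % mirrors convolutionalBlock(numFilters,2)
```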

Define a function to send the training progress to the client through DataQueue.

function stop = sendTrainingProgress(D,idx,info)
if info.State == "iteration" && ~isempty(info.TrainingAccuracy)
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
stop = false;
end
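The output function forwards only iterations that have a training accuracy value. The sketch below runs the same filtering logic over a few hypothetical info structures and collects the messages in a cell array, which stands in for the DataQueue. It uses strcmp in place of == so that the comparison also works when State is a character vector.

```matlab
% Hypothetical info structures, mimicking a few calls to the output function.
infos = struct('State', {'start', 'iteration', 'iteration', 'done'}, ...
               'Iteration', {0, 1, 2, 2}, ...
               'TrainingAccuracy', {[], [], 55.5, []});

sent = {};   % stand-in for the DataQueue: collects what send(D,...) would transmit
idx = 3;     % hypothetical index of the network in the sweep
for k = 1:numel(infos)
    info = infos(k);
    if strcmp(info.State, 'iteration') && ~isempty(info.TrainingAccuracy)
        sent{end+1} = {idx, info.Iteration, info.TrainingAccuracy};
    end
end
```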

Define an update function to update the plots when a worker sends an intermediate result.

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end
