Train Network Using Federated Learning
This example shows how to train a network using federated learning. Federated learning is a technique that enables you to train a network in a distributed, decentralized way [1].
Federated learning allows you to train a model using data from different sources without moving the data to a central location, even if the individual data sources do not match the overall distribution of the data set. This is known as non-independent and identically distributed (non-IID) data. Federated learning can be especially useful when the training data is large, or when there are privacy concerns about transferring the training data.
Instead of distributing data, the federated learning technique trains multiple models, each in the same location as a data source. You can create a global model that has learned from all the data sources by periodically collecting and combining the learnable parameters of the locally trained models. In this way, you can train a global model without centrally processing any training data.
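As a concrete illustration of the combining step (the notation here is introduced only for this explanation): suppose there are K data sources, source k holds n_k training observations, and W_k denotes the learnable parameters of the model trained at source k. A common combining rule, and the one this example uses later in the federatedAveraging function, is the data-size-weighted average

W_global = (n_1/n)*W_1 + (n_2/n)*W_2 + ... + (n_K/n)*W_K,   where n = n_1 + n_2 + ... + n_K.

Weighting in this way gives sources with more data proportionally more influence on the global model.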
This example uses federated learning to train a classification model in parallel using a highly non-IID data set. The model is trained using the digits data set, which consists of 10,000 handwritten images of the numbers 0 to 9. The example runs in parallel using 10 workers, each processing images of a single digit. By averaging the learnable parameters of the networks after each round of training, the models on each worker improve performance across all classes, without ever processing data of the other classes.
Although data privacy is one of the applications of federated learning, this example does not address the details of maintaining data privacy and security. Instead, it demonstrates the basic federated learning algorithm.
Set Up Parallel Environment
Create a parallel pool with the same number of workers as classes in the data set. For this example, use a process-based, local parallel pool with 10 workers.
cluster = parcluster("Processes");
cluster.NumWorkers = 10;
pool = parpool(cluster);
Starting parallel pool (parpool) using the 'Processes' profile ...
Connected to parallel pool with 10 workers.
numWorkers = pool.NumWorkers;
Load Data Set
All data used in this example is initially stored in a centralized location. To make this data highly non-IID, you need to distribute the data among the workers according to class. To create validation and test data sets, transfer a portion of data from the workers to the client. After the data is correctly set up, with training data of individual classes on the workers and test and validation data of all classes on the client, there is no further transfer of data during training.
Specify the folder containing the image data.
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");
Distribute the data among the workers. Each worker receives images of only one digit, such that worker 1 receives all the images of the number 0, worker 2 receives the images of the number 1, and so on.
Images of each digit are stored in a separate folder with the name of that digit. On each worker, use the fullfile function to specify the path to a specific class folder. Then, create an imageDatastore that contains all images of that digit. Next, use the splitEachLabel function to randomly separate 30% of the data for use in validation and testing. Finally, create an augmentedImageDatastore containing the training data.
inputSize = [28 28 1];
spmd
    digitDatasetPath = fullfile(digitDatasetPath,num2str(spmdIndex - 1));
    imds = imageDatastore(digitDatasetPath, ...
        IncludeSubfolders=true, ...
        LabelSource="foldernames");

    [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized");

    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
end
To test the performance of the combined global model during and after training, create test and validation data sets containing images from all classes. Combine the test and validation data from each worker into a single datastore. Then, split this datastore into two datastores that each contain 15% of the overall data: one for validating the network during training and the other for testing the network after training.
fileList = [];
labelList = [];
for i = 1:numWorkers
    tmp = imdsTestVal{i};

    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);
end

imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;

[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");

augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);
The data is now arranged such that each worker has data from a single class to train on, and the client holds validation and test data from all classes.
Define Network
Determine the number of classes in the data set.
classes = categories(imdsGlobalTest.Labels);
numClasses = numel(classes);
Define the network architecture.
layers = [
imageInputLayer(inputSize,Normalization="none")
convolution2dLayer(5,32)
reluLayer
maxPooling2dLayer(2)
convolution2dLayer(5,64)
reluLayer
maxPooling2dLayer(2)
fullyConnectedLayer(numClasses)
softmaxLayer];
Create a dlnetwork object from the layers.
net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [9×1 nnet.cnn.layer.Layer]
    Connections: [8×2 table]
     Learnables: [6×3 table]
          State: [0×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.
Define Model Loss Function
Create the function modelLoss, listed in the Model Loss Function section of this example, that takes a dlnetwork object and a mini-batch of input data with corresponding labels and returns the loss and the gradients of the loss with respect to the learnable parameters in the network.
Define Federated Averaging Function
Create the function federatedAveraging, listed in the Federated Averaging Function section of this example, that takes the learnable parameters of the networks on each worker and the normalization factor for each worker, and returns the averaged learnable parameters across all the networks. Use the average learnable parameters to update the global network and the network on each worker.
Define Compute Accuracy Function
Create the function computeAccuracy, listed in the Compute Accuracy Function section of this example, that takes a dlnetwork object, a data set inside a minibatchqueue object, and the list of classes, and returns the accuracy of the predictions across all observations in the data set.
Specify Training Options
During training, the workers periodically communicate their network learnable parameters to the client, so that the client can update the global model. The training is divided into rounds. At the end of each round of training, the learnable parameters are averaged and the global model is updated. The worker models are then replaced with the new global model, and training continues on the workers.
Train for 300 rounds, with 5 epochs per round. Training for a small number of epochs per round ensures that the networks on the workers do not diverge too far before they are averaged.
numRounds = 300;
numEpochsperRound = 5;
miniBatchSize = 100;
Specify the options for SGDM optimization. Specify a learn rate of 0.001 and a momentum of 0.
learnRate = 0.001;
momentum = 0;
Train Model
Create a function handle to the custom mini-batch preprocessing function preprocessMiniBatch (defined in the Mini-Batch Preprocessing Function section of this example).
On each worker, find the total number of training observations processed locally on that worker. Use this number to normalize the learnable parameters on each worker when you find the average learnable parameters after each communication round. This helps to balance the average if there is a difference between the amount of data on each worker.
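As an illustration of the normalization only, the short sketch below uses three hypothetical workers with 700, 690, and 710 local training observations (the variable names localCounts and exampleFactors are introduced only for this illustration). The training loop later in this example computes the factors in the same way from sizeOfLocalDataset.

% Hypothetical per-worker observation counts, for illustration only.
localCounts = [700 690 710];
exampleFactors = localCounts/sum(localCounts)
% exampleFactors is approximately [0.3333 0.3286 0.3381], and the factors always sum to 1.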
On each worker, create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:
- Preprocess the data using the custom mini-batch preprocessing function preprocessMiniBatch to convert the labels to one-hot encoded variables.
- Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.
- Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
preProcess = @(x,y)preprocessMiniBatch(x,y,classes);

spmd
    sizeOfLocalDataset = augimdsTrain.NumObservations;

    mbq = minibatchqueue(augimdsTrain, ...
        MiniBatchSize=miniBatchSize, ...
        MiniBatchFcn=preProcess, ...
        MiniBatchFormat=["SSCB",""]);
end
Create a minibatchqueue object that manages the validation data to use during training. Use the same settings as the minibatchqueue objects on each worker.
mbqGlobalVal = minibatchqueue(augimdsGlobalVal, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat=["SSCB",""]);
Initialize the trainingProgressMonitor object. Because the timer starts when you create the monitor, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor( ...
    Metrics="GlobalAccuracy", ...
    Info="CommunicationRound", ...
    XLabel="Communication Round");
Initialize the velocity parameter for the SGDM solver.
velocity = [];
Initialize the global model. To start, the global model has the same initial parameters as the untrained network on each worker.
globalModel = net;
Train the model using a custom training loop. For each communication round:
- Update the networks on the workers with the latest global network.
- Train the networks on the workers for five epochs.
- Find the average parameters of all the networks using the federatedAveraging function.
- Replace the global network parameters with the average value.
- Calculate the accuracy of the updated global network using the validation data.
- Update the global accuracy in the training progress monitor.
- Stop if the Stop property is true. The Stop property value of the TrainingProgressMonitor object changes to true when you click the Stop button.
For each epoch, shuffle the data and loop over mini-batches of data. For each mini-batch:
- Evaluate the model loss and gradients using the dlfeval and modelLoss functions.
- Update the local network parameters using the sgdmupdate function.
round = 0;
while round < numRounds && ~monitor.Stop
    round = round + 1;

    spmd
        % Send global updated parameters to each worker.
        net.Learnables.Value = globalModel.Learnables.Value;

        % Loop over epochs.
        for epoch = 1:numEpochsperRound
            % Shuffle data.
            shuffle(mbq);

            % Loop over mini-batches.
            while hasdata(mbq)
                % Read mini-batch of data.
                [X,T] = next(mbq);

                % Evaluate the model loss and gradients using dlfeval and the
                % modelLoss function.
                [loss,gradients] = dlfeval(@modelLoss,net,X,T);

                % Update the network parameters using the SGDM optimizer.
                [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);
            end
        end

        % Collect updated learnable parameters on each worker.
        workerLearnables = net.Learnables.Value;
    end

    % Find normalization factors for each worker based on ratio of data
    % processed on that worker.
    sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]);
    normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets;

    % Update the global model with new learnable parameters, normalized and
    % averaged across all workers.
    globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor);

    % Calculate the accuracy of the global model.
    accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes);

    % Update the training progress monitor.
    recordMetrics(monitor,round,GlobalAccuracy=accuracy);
    updateInfo(monitor,CommunicationRound=round + " of " + numRounds);
    monitor.Progress = 100*round/numRounds;
end
After the final round of training, update the network on each worker with the final average learnable parameters. This is important if you want to continue to use or train the network on the workers.
spmd
    net.Learnables.Value = globalModel.Learnables.Value;
end
Test Model
Test the classification accuracy of the model by comparing the predictions on the test set with the true labels.
Create a minibatchqueue object that manages the test data. Use the same settings as the minibatchqueue objects used during training and validation.
mbqGlobalTest = minibatchqueue(augimdsGlobalTest, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat="SSCB");
Use the computeAccuracy function to compute the predicted classes and calculate the accuracy of the predictions across all the test data.
accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
0.9827
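As an optional check that is not part of the original workflow, you can also inspect individual predictions of the trained global model. This minimal sketch classifies one mini-batch of test images by decoding the network scores into class labels; it assumes the variables defined above (globalModel, mbqGlobalTest, and classes) are still in the workspace, and the variable names scores, YPred, and miniBatchAccuracy are introduced only for this illustration.

% Reset the test queue so that data is available, then read one mini-batch.
shuffle(mbqGlobalTest);
[XTest,TTest] = next(mbqGlobalTest);

% Compute network scores and decode them into class labels.
scores = predict(globalModel,XTest);
YPred = onehotdecode(scores,classes,1)';
TTest = onehotdecode(TTest,classes,1)';

% Fraction of correct predictions in this mini-batch.
miniBatchAccuracy = mean(YPred == TTest)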
After you are done with your computations, you can delete your parallel pool. The gcp function returns the current parallel pool object so you can delete the pool.
delete(gcp("nocreate"));
Model Loss Function
The modelLoss function takes a dlnetwork object net, a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function. To compute the predictions of the network during training, use the forward function.
function [loss,gradients] = modelLoss(net,X,T)

    YPred = forward(net,X);

    loss = crossentropy(YPred,T);
    gradients = dlgradient(loss,net.Learnables);

end
Compute Accuracy Function
The computeAccuracy function takes a dlnetwork object net, a minibatchqueue object mbq, and the list of classes, and returns the accuracy of all the predictions on the data set provided. To compute the predictions of the network during validation or after training is finished, use the predict function.
function accuracy = computeAccuracy(net,mbq,classes)

    correctPredictions = [];

    shuffle(mbq);

    while hasdata(mbq)
        [XTest,TTest] = next(mbq);

        TTest = onehotdecode(TTest,classes,1)';

        YPred = predict(net,XTest);
        YPred = onehotdecode(YPred,classes,1)';

        correctPredictions = [correctPredictions; YPred == TTest];
    end

    predSum = sum(correctPredictions);
    accuracy = single(predSum./size(correctPredictions,1));

end
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses the data using the following steps:
1. Extract the image data from the incoming cell array and concatenate into a numeric array. Concatenating the image data over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.
2. Extract the label data from the incoming cell arrays and concatenate into a categorical array along the second dimension.
3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)

    % Concatenate.
    X = cat(4,XCell{1:end});

    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{1:end});

    % One-hot encode labels.
    Y = onehotencode(Y,1,ClassNames=classes);

end
Federated Averaging Function
The federatedAveraging function takes the learnable parameters of the networks on each worker and the normalization factor for each worker, and returns the averaged learnable parameters across all the networks. Use the average learnable parameters to update the global network and the network on each worker.
function learnables = federatedAveraging(workerLearnables,normalizationFactor)

    numWorkers = size(normalizationFactor,2);

    % Initialize container for averaged learnables with same size as existing
    % learnables. Use learnables of first worker network as an example.
    exampleLearnables = workerLearnables{1};
    learnables = cell(height(exampleLearnables),1);

    for i = 1:height(learnables)
        learnables{i} = zeros(size(exampleLearnables{i}),"like",(exampleLearnables{i}));
    end

    % Add the normalized learnable parameters of all workers to
    % calculate average values.
    for i = 1:numWorkers
        tmp = workerLearnables{i};
        for values = 1:numel(learnables)
            learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values};
        end
    end

end
References
[1] McMahan, H. Brendan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Agüera y Arcas. "Communication-Efficient Learning of Deep Networks from Decentralized Data." arXiv preprint arXiv:1602.05629 (2017). https://arxiv.org/abs/1602.05629
See Also
dlarray | dlnetwork | sgdmupdate | dlupdate | dlfeval | dlgradient | minibatchqueue