The "Train Network Using Federated Learning" example in the MathWorks documentation is not working
I am trying to run the Train Network Using Federated Learning example from the MathWorks documentation (https://in.mathworks.com/help/deeplearning/ug/train-network-using-federated-learning.html), but I am getting the following error:
Undefined function 'preprocessMiniBatch'.
The error says the function is undefined, even though I have included the preprocessMiniBatch function as given in the MathWorks documentation.
Every time I run the code I get this error, and I am unable to understand where I am making a mistake. I am using MATLAB R2023a on a CPU with 8 GB of RAM. I have been looking for a solution for a month now. Could someone please help me solve this problem? I will be grateful.
2 Comments
Walter Roberson
2023-12-25
If you are not already doing so, try putting preprocessMiniBatch into its own .m file (preprocessMiniBatch.m) in the current folder or another folder on the MATLAB path.
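A minimal sketch of what such a file could contain is below. It assumes the three-argument signature (XCell, YCell, classes) that the example's preProcess function handle appears to pass in, so compare it against the version in the documentation before relying on it. One plausible reason a separate file helps is that the example runs code on parallel pool workers inside spmd, and a standalone function file on the path is generally easier for the workers to resolve than a function defined locally inside a script.

```matlab
% preprocessMiniBatch.m
% Sketch only: assumes minibatchqueue passes the images and labels as cell
% arrays and that classes is the list of class names used by the example.
function [X, Y] = preprocessMiniBatch(XCell, YCell, classes)
    % Stack the individual H-by-W-by-C images along the 4th (batch) dimension.
    X = cat(4, XCell{:});

    % Concatenate the categorical labels into a 1-by-N row.
    Y = cat(2, YCell{:});

    % One-hot encode along dimension 1, giving a numClasses-by-N array.
    Y = onehotencode(Y, 1, ClassNames=classes);
end
```

Saving it as a function file (the file name must match the function name) rather than keeping it at the bottom of the example script is the only structural change this sketch is meant to illustrate.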
Answers (1)
Harsha Vardhan
2024-1-5
Hi Debojit Sharma,
I understand that you faced an error while using the Federated Learning example. It appears that you were able to resolve this issue following a comment from the community. Later, you wanted to plot a confusion matrix for this example.
A confusion matrix can be computed using the 'confusionmat' function and plotted using 'confusionchart'. Please check the relevant documentation here: https://www.mathworks.com/help/stats/confusionmat.html
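A small illustration with made-up labels (not data from the federated learning example): confusionmat, which is part of Statistics and Machine Learning Toolbox, returns the counts with one row per true class and one column per predicted class, and confusionchart can then display that matrix with class names.

```matlab
% Made-up labels for illustration only
classes = {'cat'; 'dog'};
trueLabels = categorical({'cat', 'dog', 'cat'}, classes);
predLabels = categorical({'cat', 'dog', 'dog'}, classes);

% For categorical inputs the row/column order follows the categories, i.e. classes
C = confusionmat(trueLabels, predLabels);   % C = [1 1; 0 1]

% Display the matrix with the class names as axis labels
figure;
confusionchart(C, classes);
```

confusionchart(trueLabels, predLabels) also accepts the label arrays directly if the numeric matrix itself is not needed.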
To integrate confusion matrix computation for the training and testing phases into your existing federated learning code, you can collect the predictions and actual labels from the global model for both the training and test datasets and then use these to compute the confusion matrices. There are other ways to compute the confusion matrix as well.
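As one example of such an alternative, here is a minimal sketch that accumulates the counts batch by batch with accumarray instead of storing every label. It assumes the labels are categorical arrays whose categories are exactly classes, in that order, which should be the case for the output of onehotdecode(..., classes, 1). The batchTrue and batchPred cell arrays are hypothetical stand-ins for the decoded labels produced inside the while hasdata(mbq) loop.

```matlab
% Hypothetical decoded labels for two mini-batches; in practice these would be
% the outputs of onehotdecode(..., classes, 1) inside the while hasdata(mbq) loop.
classes = {'cat'; 'dog'; 'fish'};
batchTrue = {categorical({'cat', 'dog'}, classes), categorical({'fish', 'cat'}, classes)};
batchPred = {categorical({'cat', 'cat'}, classes), categorical({'fish', 'dog'}, classes)};

numClasses = numel(classes);
C = zeros(numClasses);   % running confusion matrix: rows = true, columns = predicted

for k = 1:numel(batchTrue)
    % double() of a categorical gives each element's index into its categories,
    % which match the order of classes here.
    iTrue = double(batchTrue{k});
    iPred = double(batchPred{k});
    % Add one count per (true, predicted) pair in this batch.
    C = C + accumarray([iTrue(:) iPred(:)], 1, [numClasses numClasses]);
end
% C = [1 1 0; 1 0 0; 0 0 1]
```

The running matrix stays numClasses-by-numClasses no matter how large the dataset is, so nothing has to grow inside the loop; the trade-off is that the individual predictions are not kept for later inspection.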
The code modifications are shown below.
Just as datastores were created for the test and validation data, create a datastore for the training data as follows.
fileList = [];
labelList = [];
for i = 1:numWorkers
    tmp = imdsTestVal{i};
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);
end
imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;
[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");
augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);
%% Code for creating a datastore for training data
fileList = [];
labelList = [];
for i = 1:numWorkers
    % Gather the training files and labels held by each worker
    tmp = imdsTrain{i};
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);
end
imdsGlobalTrain = imageDatastore(fileList);
imdsGlobalTrain.Labels = labelList;
augimdsGlobalTrain = augmentedImageDatastore(inputSize(1:2),imdsGlobalTrain);
Similarly, alongside the existing validation queue, create a 'minibatchqueue' object for the training data.
mbqGlobalVal = minibatchqueue(augimdsGlobalVal, ...
MiniBatchSize=miniBatchSize, ...
MiniBatchFcn=preProcess, ...
MiniBatchFormat=["SSCB",""]);
%Code for creating a minibatchqueue for training data
mbqGlobalTrain = minibatchqueue(augimdsGlobalTrain, ...
MiniBatchSize=miniBatchSize, ...
MiniBatchFcn=preProcess, ...
MiniBatchFormat=["SSCB",""]);
After calculating the accuracy, you can plot the confusion matrices for the training and test data as follows.
accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes);
%Code for displaying training confusion matrix
trainConfusionMat = createConfusionMatrix(globalModel, mbqGlobalTrain, classes);
figure;
confusionchart(trainConfusionMat, classes);
title('Training Confusion Matrix');
%Code for displaying testing confusion matrix
testConfusionMat = createConfusionMatrix(globalModel, mbqGlobalTest, classes);
figure;
confusionchart(testConfusionMat, classes);
title('Testing Confusion Matrix');
The function below creates a confusion matrix using the 'confusionmat' MATLAB function.
%function for calculating the confusion matrix
function confusionMat = createConfusionMatrix(net, mbq, classes)
    allYPred = [];
    allTTest = [];
    % Rewind the queue, since computeAccuracy or an earlier call may already have consumed it
    reset(mbq);
    while hasdata(mbq)
        [XTest, TTest] = next(mbq);
        % Convert one-hot targets and network scores back to categorical labels
        TTest = onehotdecode(TTest, classes, 1)';
        YPred = predict(net, XTest);
        YPred = onehotdecode(YPred, classes, 1)';
        allTTest = [allTTest; TTest];
        allYPred = [allYPred; YPred];
    end
    % Rows = true class, columns = predicted class
    confusionMat = confusionmat(categorical(allTTest), categorical(allYPred));
end
Hope this helps in resolving your query!