Need to plot the accuracy vs. epoch graph
allImages = imageDatastore('TrainingData', 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
%% Split data into training and test sets
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomize');
%% Load Pre-trained Network (AlexNet)
% AlexNet is a pre-trained network trained on 1000 object categories.
% AlexNet is available as a support package on File Exchange.
alex = alexnet;
%% Review Network Architecture
layers = alex.Layers
%% Modify Pre-trained Network
% AlexNet was trained to recognize 1000 classes; we need to modify it to
% recognize just 4 classes.
layers(23) = fullyConnectedLayer(4); % change this based on # of classes
layers(25) = classificationLayer
%% Perform Transfer Learning
% For transfer learning, we want to change the weights of the network only
% slightly. How much a network changes during training is controlled by the
% learning rates.
opts = trainingOptions('sgdm', 'InitialLearnRate', 0.001, ...
    'MaxEpochs', 5, 'MiniBatchSize', 16);
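If a live curve is enough, trainingOptions can open MATLAB's built-in training-progress figure, which plots mini-batch accuracy and loss against iteration (not epoch) and marks where each epoch begins. A minimal variant of the options above, assuming nothing else in the script changes:

```matlab
% Same solver and hyperparameters as the original options, plus
% 'Plots','training-progress', which opens a figure that updates during
% training with accuracy and loss per iteration, epochs marked on the axis.
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 5, ...
    'MiniBatchSize', 16, ...
    'Plots', 'training-progress');
```

The figure shows per-iteration values; averaging them per epoch needs the second output of trainNetwork.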
%% Set custom read function
% One of the great things about imageDatastore is that it lets you specify a
% custom read function. In this case, it simply resizes the input images to
% 227x227 pixels, which is what AlexNet expects. You do this by assigning
% ReadFcn a handle to a function that reads and pre-processes the image.
trainingImages.ReadFcn = @readFunctionTrain;
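The script assigns readFunctionTrain but does not show its code. A minimal sketch of what such a function might look like, assuming it only reads and resizes as the comment describes; the grayscale-to-RGB step is an added assumption, since AlexNet expects 3-channel input. Save it as readFunctionTrain.m on the MATLAB path:

```matlab
function I = readFunctionTrain(filename)
% READFUNCTIONTRAIN  Hypothetical read function (the original is not shown).
%   Reads an image file and resizes it to the 227x227 input AlexNet expects.
I = imread(filename);
if size(I, 3) == 1
    % Assumption: replicate single-channel images to 3 channels for AlexNet.
    I = repmat(I, [1 1 3]);
end
I = imresize(I, [227 227]);
end
```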
%% Train the Network
% This process usually takes about 5-20 minutes on a desktop GPU.
myNet = trainNetwork(trainingImages, layers, opts);
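To plot accuracy per epoch, request the second output of trainNetwork, i.e. change the call to `[myNet, info] = trainNetwork(trainingImages, layers, opts);`. Its TrainingAccuracy field holds the mini-batch accuracy, in percent, at every iteration. The sketch below is a hypothetical helper (plotAccuracyPerEpoch is an illustrative name, not a MATLAB function), saved as its own .m file, that averages those values within each epoch and plots them. It assumes every epoch ran the same number of iterations, which holds here as long as training is not stopped early.

```matlab
function epochAcc = plotAccuracyPerEpoch(info, numEpochs)
% PLOTACCURACYPEREPOCH  Hypothetical helper: plot mean training accuracy per epoch.
%   info      - second output of trainNetwork; info.TrainingAccuracy holds
%               one mini-batch accuracy (percent) per iteration
%   numEpochs - number of epochs trained (opts.MaxEpochs in the script)
% Assumption: every epoch contains the same number of iterations.

iterAcc = info.TrainingAccuracy(:);     % column vector, one entry per iteration
itersPerEpoch = numel(iterAcc) / numEpochs;
if itersPerEpoch ~= floor(itersPerEpoch)
    error('Iteration count %d is not a multiple of %d epochs.', ...
        numel(iterAcc), numEpochs);
end

% Column k of the reshaped matrix holds the iterations of epoch k,
% so the column means are the per-epoch averages.
epochAcc = mean(reshape(iterAcc, itersPerEpoch, numEpochs), 1);

figure
plot(1:numEpochs, epochAcc, '-o')
xlabel('Epoch')
ylabel('Mean training accuracy (%)')
title('Training accuracy vs. epoch')
grid on
end
```

Call it after training as `plotAccuracyPerEpoch(info, opts.MaxEpochs)`. Averaging is one design choice; taking the last iteration of each epoch is another. For a held-out curve, supplying 'ValidationData' (ideally a separate validation split, not the test set) and 'ValidationFrequency' to trainingOptions fills info.ValidationAccuracy at the validation iterations, with NaN elsewhere.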
%% Test Network Performance
% Now let's test the performance of our new "snack recognizer" on the test set.
testImages.ReadFcn = @readFunctionTrain;
predictedLabels = classify(myNet, testImages);
accuracy = mean(predictedLabels == testImages.Labels)
confusionchart(predictedLabels, testImages.Labels)
Hello, for the code above, I need to plot the accuracy vs epoch graph. How can I do that? Thank you!
Answers (1)
Joss Knight
2022-11-14
1 vote
FWIW, you shouldn't use ReadFcn for resizing images; it dramatically slows down file access. Use augmentedImageDatastore instead.
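A sketch of that suggestion, written as a hypothetical helper (makeAugmentedDatastores is an illustrative name, not a MATLAB function) that wraps the split datastores so images are resized to AlexNet's 227x227 input as they are read. The 'gray2rgb' color preprocessing is an added assumption for datasets that contain grayscale images:

```matlab
function [augTrain, augTest] = makeAugmentedDatastores(imdsTrain, imdsTest, inputSize)
% MAKEAUGMENTEDDATASTORES  Hypothetical helper: resize images on the fly.
%   imdsTrain, imdsTest - imageDatastores, e.g. from splitEachLabel
%   inputSize           - [height width] the network expects ([227 227] for AlexNet)
% 'ColorPreprocessing','gray2rgb' replicates grayscale images to 3 channels
% (an assumption; drop it if every image is already RGB).
augTrain = augmentedImageDatastore(inputSize, imdsTrain, ...
    'ColorPreprocessing', 'gray2rgb');
augTest = augmentedImageDatastore(inputSize, imdsTest, ...
    'ColorPreprocessing', 'gray2rgb');
end
```

In the script, this replaces both ReadFcn assignments: call `[augTrain, augTest] = makeAugmentedDatastores(trainingImages, testImages, [227 227]);`, pass augTrain to trainNetwork and augTest to classify, and keep comparing against testImages.Labels, since the test files are read in the same order.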