How to change output classes for image classification in alexnet?

Hello everyone
I'm trying to implement the following example:
However, I get this error:
Error using trainNetwork (line 170)
Invalid training data. The output size (28224) of the last layer does not match the number of classes (24).
This is my code:
tic
clear;
clc;
outputFolder = fullfile('Spilt_Dataset');
rootFolder = fullfile(outputFolder , 'Categories');
categories = {'front_or_left' , 'front_or_right','hump' , 'left_turn','narrow_from_left' , 'narrows_from_right',...
'no_horron' , 'no_parking','no_u_turn' , 'overtaking_is_forbidden','parking' , 'pedestrian_crossing',...
'right_or_left' , 'right_turn','rotor' , 'slow','speed_30' , 'speed_40',...
'speed_50','speed_60','speed_80','speed_100','stop','u_turn'};
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','Foldernames');
tbl = countEachLabel(imds);
minSetCount = min(tbl{:,2});
imds = splitEachLabel(imds, minSetCount, 'randomize');
countEachLabel(imds);
front_or_left = find(imds.Labels == 'front_or_left',1);
front_or_right = find(imds.Labels == 'front_or_right',1);
hump = find(imds.Labels == 'hump',1);
left_turn = find(imds.Labels == 'left_turn',1);
narrow_from_left = find(imds.Labels == 'narrow_from_left',1);
narrows_from_right = find(imds.Labels == 'narrows_from_right',1);
no_horron = find(imds.Labels == 'no_horron',1);
no_parking = find(imds.Labels == 'no_parking',1);
no_u_turn = find(imds.Labels == 'no_u_turn',1);
overtaking_is_forbidden = find(imds.Labels == 'overtaking_is_forbidden',1);
parking = find(imds.Labels == 'parking',1);
pedestrian_crossing = find(imds.Labels == 'pedestrian_crossing',1);
right_or_left = find(imds.Labels == 'right_or_left',1);
right_turn = find(imds.Labels == 'right_turn',1);
rotor = find(imds.Labels == 'rotor',1);
slow = find(imds.Labels == 'slow',1);
speed_30 = find(imds.Labels == 'speed_30',1);
speed_40 = find(imds.Labels == 'speed_40',1);
speed_50 = find(imds.Labels == 'speed_50',1);
speed_60 = find(imds.Labels == 'speed_60',1);
speed_80 = find(imds.Labels == 'speed_80',1);
speed_100 = find(imds.Labels == 'speed_100',1);
stop = find(imds.Labels == 'stop',1);
u_turn = find(imds.Labels == 'u_turn',1);
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
net = googlenet;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
numClasses = numel(categories(imdsTrain.Labels))
newLearnableLayer = fullyConnectedLayer(numClasses, ...
'Name','new_fc', ...
'WeightLearnRateFactor',10, ...
'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(augimdsTrain,lgraph,options);

Accepted Answer

Srivardhan Gadila, 2021-4-16
If the classes listed from the line "front_or_left = find(imd...." onward in the code above are the only classes (24 in total), then the issue is probably that the value returned by
numClasses = numel(categories(imdsTrain.Labels))
is 28224 rather than 24. If that's the case, specify the value 24 directly in the line
newLearnableLayer = fullyConnectedLayer(24, ...
and run the code again.
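An additional note that goes beyond the accepted answer: the value 28224 most likely comes from the script itself. It assigns the cell array of folder names to a variable called `categories`, which shadows MATLAB's `categories` function, so `categories(imdsTrain.Labels)` appears to index into that cell array instead of listing the classes, and `numel` then counts training images. 28224 is exactly 24 × 1176, which would fit 1176 training images per class after the balanced split. The sketch below uses a small, hypothetical label set to show what the built-in function returns when it is not shadowed; the variable name `classNames` is an illustrative choice, not taken from the original code.

```matlab
% Hypothetical example: keep the class-name list in a variable that does not
% shadow the built-in categories() function.
classNames = {'front_or_left', 'front_or_right', 'hump'};   % shortened list

% Stand-in for imdsTrain.Labels: five images spread over three classes,
% with the category set taken from classNames.
labels = categorical({'hump'; 'front_or_left'; 'hump'; 'front_or_right'; 'hump'}, ...
                     classNames);

% With categories() not shadowed, this counts distinct classes...
numClasses = numel(categories(labels));

% ...whereas this counts images, which is what the script effectively computed.
numImages = numel(labels);

fprintf('%d classes, %d images\n', numClasses, numImages);
```

Applied to the question's script, this would mean renaming the variable (for example to `classNames`) both where the cell array is defined and in `fullfile(rootFolder, classNames)`; `numClasses` should then come out as 24 without hard-coding it.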
You can debug the issue by checking the layer activation sizes with the analyzeNetwork function.
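A minimal sketch of that check, using a small hypothetical layer stack in place of the modified GoogLeNet graph from the question (only the layer names 'new_fc' and 'new_classoutput' are borrowed from it). It looks up the fully connected layer by name, reads its `OutputSize` (which should equal the number of classes), and then opens the network analyzer. It assumes Deep Learning Toolbox is installed.

```matlab
% Hypothetical small network standing in for the question's lgraph.
layers = [
    imageInputLayer([32 32 3], 'Name', 'input')
    convolution2dLayer(3, 8, 'Name', 'conv')
    reluLayer('Name', 'relu')
    fullyConnectedLayer(24, 'Name', 'new_fc')
    softmaxLayer('Name', 'prob')
    classificationLayer('Name', 'new_classoutput')];
lgraph = layerGraph(layers);

% Find the fully connected layer by name and read its output size.
allLayers = lgraph.Layers;
isFc = strcmp({allLayers.Name}, 'new_fc');
fcOutputSize = allLayers(isFc).OutputSize;
fprintf('new_fc output size: %d\n', fcOutputSize);

% Interactive report of every layer's activation size and any errors.
analyzeNetwork(lgraph)
```

For the question's graph, the same lookup on 'new_fc' should report 24; a value of 28224 would confirm that `numClasses` was computed incorrectly.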
If the network looks correct, the issue might be with your data, i.e., the labels themselves.
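One label-related case worth ruling out is a categorical label array that still defines classes with no images, because `categories` counts every defined category whether or not it is used. The sketch uses a hypothetical label array in place of `imdsTrain.Labels`; `countcats` and `removecats` are standard functions for categorical arrays, and for an image datastore `countEachLabel(imdsTrain)` gives a similar per-class table.

```matlab
% Hypothetical labels: 'speed_30' is a defined category but has no images.
labels = categorical({'stop'; 'slow'; 'stop'; 'u_turn'}, ...
                     {'slow', 'speed_30', 'stop', 'u_turn'});

% One count per defined category, including the empty one.
countsBefore = countcats(labels);
numBefore = numel(categories(labels));

% Drop categories that have no observations.
labels = removecats(labels);
numAfter = numel(categories(labels));

fprintf('%d categories before, %d after removecats\n', numBefore, numAfter);
disp(table(categories(labels), countcats(labels), ...
     'VariableNames', {'Label', 'Count'}));
```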
If the issue persists, contact MathWorks Support.
