How to change output classes for image classification in alexnet?
Hello everyone
I'm trying to implement the following example:
but I got this error:
Error using trainNetwork (line 170)
Invalid training data. The output size (28224) of the last layer does not match the number of classes (24).
and this is my code:
tic
clear;
clc;
outputFolder = fullfile('Spilt_Dataset');
rootFolder = fullfile(outputFolder , 'Categories');
categories = {'front_or_left' , 'front_or_right','hump' , 'left_turn','narrow_from_left' , 'narrows_from_right',...
'no_horron' , 'no_parking','no_u_turn' , 'overtaking_is_forbidden','parking' , 'pedestrian_crossing',...
'right_or_left' , 'right_turn','rotor' , 'slow','speed_30' , 'speed_40',...
'speed_50','speed_60','speed_80','speed_100','stop','u_turn'};
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','Foldernames');
tbl = countEachLabel(imds);
minSetCount = min(tbl{:,2});
imds = splitEachLabel(imds, minSetCount, 'randomize');
countEachLabel(imds);
front_or_left = find(imds.Labels == 'front_or_left',1);
front_or_right = find(imds.Labels == 'front_or_right',1);
hump = find(imds.Labels == 'hump',1);
left_turn = find(imds.Labels == 'left_turn',1);
narrow_from_left = find(imds.Labels == 'narrow_from_left',1);
narrows_from_right = find(imds.Labels == 'narrows_from_right',1);
no_horron = find(imds.Labels == 'no_horron',1);
no_parking = find(imds.Labels == 'no_parking',1);
no_u_turn = find(imds.Labels == 'no_u_turn',1);
overtaking_is_forbidden = find(imds.Labels == 'overtaking_is_forbidden',1);
parking = find(imds.Labels == 'parking',1);
pedestrian_crossing = find(imds.Labels == 'pedestrian_crossing',1);
right_or_left = find(imds.Labels == 'right_or_left',1);
right_turn = find(imds.Labels == 'right_turn',1);
rotor = find(imds.Labels == 'rotor',1);
slow = find(imds.Labels == 'slow',1);
speed_30 = find(imds.Labels == 'speed_30',1);
speed_40 = find(imds.Labels == 'speed_40',1);
speed_50 = find(imds.Labels == 'speed_50',1);
speed_60 = find(imds.Labels == 'speed_60',1);
speed_80 = find(imds.Labels == 'speed_80',1);
speed_100 = find(imds.Labels == 'speed_100',1);
stop = find(imds.Labels == 'stop',1);
u_turn = find(imds.Labels == 'u_turn',1);
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
net = googlenet;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
numClasses = numel(categories(imdsTrain.Labels))
newLearnableLayer = fullyConnectedLayer(numClasses, ...
'Name','new_fc', ...
'WeightLearnRateFactor',10, ...
'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(augimdsTrain,lgraph,options);
Accepted Answer
Srivardhan Gadila
2021-4-16
If the classes listed starting at the line "front_or_left = find(imd...." in the code above are the only classes, 24 in total, then the issue might be that the value returned by
numClasses = numel(categories(imdsTrain.Labels))
is 28224 rather than 24. If that is the case, specify the value 24 directly in the line
newLearnableLayer = fullyConnectedLayer(24, ...
and try running the code again.
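A possible explanation for the 28224, offered as an assumption rather than something verified against the original data: the script stores the folder names in a variable called `categories`, and in MATLAB a variable takes precedence over a function of the same name. The expression `categories(imdsTrain.Labels)` would then index that cell array instead of calling the `categories` function, which may yield one element per image rather than one per class. The sketch below uses stand-in labels and a hypothetical variable name, `classNames`, to show the count once the name clash is gone.

```matlab
% Hypothetical name: classNames replaces the variable called categories,
% so the categories function is no longer hidden.
classNames = {'slow', 'stop', 'u_turn'};

% Stand-in for imdsTrain.Labels: five images spread over three classes.
labels = categorical({'stop'; 'stop'; 'slow'; 'u_turn'; 'slow'}, classNames);

% With no variable in the way, categories() returns the class list.
numClasses = numel(categories(labels));

fprintf('%d images, %d classes\n', numel(labels), numClasses);
```

In the original script, the same rename would apply to the `categories = {...}` assignment and to the `imageDatastore(fullfile(rootFolder,categories), ...)` call.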
You can debug the issue by checking the layer activation sizes with the analyzeNetwork function.
If the network is fine, then the issue might lie in your data, i.e., in the labels themselves.
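Both checks can also be scripted before calling trainNetwork. This is a rough sketch with stand-in labels: in the real script `labels` would be `imdsTrain.Labels`, and `analyzeNetwork(lgraph)` would show the same sizes interactively. `fullyConnectedLayer` and its `OutputSize` property come from Deep Learning Toolbox; the variable names are illustrative.

```matlab
% Stand-in for imdsTrain.Labels (assumption: three classes, four images).
labels = categorical({'stop'; 'slow'; 'stop'; 'u_turn'});

% Count the classes that actually occur in the labels.
numClasses = numel(unique(labels));

% Build the replacement layer the same way as in the question.
newLearnableLayer = fullyConnectedLayer(numClasses, 'Name', 'new_fc');

% The final fully connected layer must output one value per class.
assert(newLearnableLayer.OutputSize == numClasses, ...
    'Layer outputs %d values but the data has %d classes.', ...
    newLearnableLayer.OutputSize, numClasses);

% Per-class image counts help spot misspelled or empty label folders.
summary(labels)
```

If the assertion fails, the class count is the first thing to inspect; if it passes and training still errors, the per-class summary is a quick way to look for unexpected labels.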