Error in deep learning using GoogLeNet
I am using transfer learning with GoogLeNet for binary image classification. Here is the code:
imds = imageDatastore('RGBNewFrames_sameSz', ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
%% Divide the data into training and validation data sets.
% Use 70% of the images for training and 30% for validation
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
%% Load Pretrained Network
net = googlenet;
analyzeNetwork(net)
inputSize = net.Layers(1).InputSize
%% Replace Final Layers
% The last three layers of the pretrained network net are configured for 1000 classes.
% These three layers must be fine-tuned for the new classification problem.
% Extract all layers, except the last three, from the pretrained network.
layersTransfer = net.Layers(1:end-3);
% transfer the layers to the new classification task by replacing the last three layers with a fully connected layer, a softmax layer, and a classification output layer.
numClasses = numel(categories(imdsTrain.Labels))
layers = [
layersTransfer
fullyConnectedLayer(numClasses,'WeightLearnRateFactor',10,'BiasLearnRateFactor',10,'Name', 'FC_Layer')
softmaxLayer('Name', 'soft_mxLayer')
classificationLayer('Name', 'OutputLayer')];
%% Train Network
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:3),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',2, ...%6
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(augimdsTrain,layers,options);
%% Classify Validation Images
[YPred,scores] = classify(netTransfer,augimdsValidation);
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
I was getting the following error:
Error using trainNetwork (line 170)
Invalid network.
Error in Ex2_Transfer_learning_googlenet (line 73)
netTransfer = trainNetwork(augimdsTrain,lgraph,options);
Caused by:
Layer 'inception_3a-3x3_reduce': Input size mismatch. Size of input to this layer is different from the expected input size.
Inputs to this layer:
from layer 'inception_3a-relu_1x1' (28×28×64 output)
Layer 'inception_3a-output': Unused input. Each layer input must be connected to the output of another layer.
Detected unused inputs:
input 'in2'
input 'in3'
input 'in4'
Layer 'inception_3b-output': Unused input. Each layer input must be connected to the output of another layer.
Detected unused inputs:
input 'in2'
So I used the solution provided at https://in.mathworks.com/matlabcentral/answers/411767-trainnetwork-invalid-network
Based on this, I made the following changes to the code:
lgraph = layerGraph(layers);
netTransfer = trainNetwork(augimdsTrain,lgraph,options);
This resulted in the error:
Layer names in layer array must be non-empty.
I checked and found that the fully connected and softmax layers did not have names, so I explicitly assigned names to these layers and ran the code again.
That fixed the layer-name error, but now I am getting the original error again (it is also shown above for reference):
Input size mismatch. Size of input to this layer is different from the expected input size.
Kindly help me resolve this error. I tried everything I could think of to debug it but could not solve it.
I get the same error with ResNet as well. At present it works only with AlexNet.
Any help on this will be greatly appreciated.
Thanks
Accepted Answer
Srivardhan Gadila
2021-11-15
As David suggested, you can use the analyzeNetwork function to analyze your new network and debug the issue.
The issue with your workflow is that GoogLeNet is a DAGNetwork. When you collect all the layers except the last three into the "layersTransfer" array, you keep only the layers themselves; the information about how they are connected (the Connections property) is lost here:
layersTransfer = net.Layers(1:end-3);
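To see what is lost, you can compare the pretrained network's Connections table with the plain layer array. The following is a minimal sketch, assuming Deep Learning Toolbox and the GoogLeNet support package are installed; the layer names are taken from the error messages in the question.

```matlab
% Sketch: inspect the connections that net.Layers(1:end-3) discards.
% Assumes Deep Learning Toolbox and the GoogLeNet support package are installed.
net = googlenet;

% A DAGNetwork stores its edges in a table with Source/Destination columns.
conns = net.Connections;

% Which layer actually feeds 'inception_3a-3x3_reduce' in the real graph?
srcReduce = conns.Source(strcmp(conns.Destination,'inception_3a-3x3_reduce'))

% The depth concatenation layer 'inception_3a-output' has several inputs
% (in1, in2, ...); list every edge that ends at one of them.
concatEdges = conns(startsWith(conns.Destination,'inception_3a-output/'),:)

% The layer array keeps the layers, but carries no record of these edges.
layersTransfer = net.Layers(1:end-3);
class(layersTransfer)
```

In the real graph the 3x3 reduce layer is fed by an earlier pooling layer, and the concatenation layer receives one edge per branch; neither fact survives in layersTransfer.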
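So when you pass this new layer array to the trainNetwork function, it assumes that the network is a SeriesNetwork in which all layers are connected one after another. That is exactly what the error messages report: in a serial chain, 'inception_3a-3x3_reduce' receives the 28×28×64 output of 'inception_3a-relu_1x1', which is not its input in the original graph, and each depth concatenation layer such as 'inception_3a-output' gets only its first input connected, leaving 'in2', 'in3', ... unused. You can check this using the analyzeNetwork function: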
layers = [
layersTransfer
fullyConnectedLayer(numClasses,'WeightLearnRateFactor',10,'BiasLearnRateFactor',10,'Name', 'FC_Layer')
softmaxLayer('Name', 'soft_mxLayer')
classificationLayer('Name', 'OutputLayer')];
analyzeNetwork(layers)
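If you want to see the serial wiring programmatically rather than in the app, you can build a layer graph from the array and query the same edges. This is a minimal sketch under the same assumptions (Deep Learning Toolbox and the GoogLeNet support package installed), with numClasses assumed to be 2 for the binary problem in the question.

```matlab
% Sketch: show how a plain layer array is wired when treated as a graph.
net = googlenet;
numClasses = 2;   % assumed: binary classification, as in the question

layers = [
    net.Layers(1:end-3)
    fullyConnectedLayer(numClasses,'Name','FC_Layer')
    softmaxLayer('Name','soft_mxLayer')
    classificationLayer('Name','OutputLayer')];

% layerGraph connects the array elements one after another.
lgraphSerial = layerGraph(layers);
conns = lgraphSerial.Connections;

% In the serial chain, 'inception_3a-3x3_reduce' is fed by whatever layer
% precedes it in the array (the 64-channel ReLU named in the error).
srcReduce = conns.Source(strcmp(conns.Destination,'inception_3a-3x3_reduce'))

% Only the first input of the depth concatenation layer gets connected.
concatEdges = conns(startsWith(conns.Destination,'inception_3a-output/'),:)
```

Building the graph succeeds; the mismatch only surfaces when trainNetwork (or analyzeNetwork) checks the layer sizes, which matches what the question describes.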
Hence the correct workflow would be as follows:
Get the layer graph, rather than the layer array, from the pretrained network:
lgraph = layerGraph(net);
Replace the fully connected layer and the classification layer with your newly defined layers (the softmax layer "prob" has no learnable parameters, so it can stay as is):
fcLayer = fullyConnectedLayer(numClasses,'WeightLearnRateFactor',10,'BiasLearnRateFactor',10,'Name', 'FC_Layer');
clsLayer = classificationLayer('Name', 'OutputLayer');
lgraphNew = replaceLayer(lgraph,"loss3-classifier",fcLayer);
lgraphNew = replaceLayer(lgraphNew,"output",clsLayer);
analyzeNetwork(lgraphNew)
Now train this newly created layerGraph "lgraphNew" with the trainNetwork function.
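Putting it together, here is a sketch of the question's script with only the network-construction and training steps changed to the layerGraph workflow. It assumes the 'RGBNewFrames_sameSz' folder from the question (one subfolder per class) and the GoogLeNet support package; the data pipeline and training options are copied from the question.

```matlab
% Sketch: question's script, with the layer array replaced by a layer graph.
imds = imageDatastore('RGBNewFrames_sameSz', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

net = googlenet;
inputSize = net.Layers(1).InputSize;
numClasses = numel(categories(imdsTrain.Labels));

% Keep the DAG structure: edit a layer graph instead of a layer array.
lgraph = layerGraph(net);
fcLayer = fullyConnectedLayer(numClasses, ...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',10,'Name','FC_Layer');
clsLayer = classificationLayer('Name','OutputLayer');
lgraphNew = replaceLayer(lgraph,'loss3-classifier',fcLayer);
lgraphNew = replaceLayer(lgraphNew,'output',clsLayer);

% Same augmentation as in the question.
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:3),imdsValidation);

% Same training options as in the question.
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',2, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');

% Pass the layer graph, not the layer array.
netTransfer = trainNetwork(augimdsTrain,lgraphNew,options);

% Classify the validation images and compute accuracy.
[YPred,scores] = classify(netTransfer,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
```

The only structural change from the original script is that trainNetwork receives lgraphNew, so the inception branches keep their original connections.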
For more information, refer to the documentation for SeriesNetwork, Layer, DAGNetwork, layerGraph, and the object functions of layerGraph.
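The layer names "loss3-classifier" and "output" are specific to GoogLeNet. ResNet is also a DAG network, so the same layerGraph/replaceLayer approach should apply to it, but its final layers have different names. One option is to locate the last fully connected layer and the classification output layer by class instead of by name. The sketch below is a hypothetical helper (replaceHead is not a toolbox function) and assumes the pretrained network ends with one such pair, as GoogLeNet and ResNet-50 do.

```matlab
function lgraphNew = replaceHead(net,numClasses)
% REPLACEHEAD  Hypothetical helper: swap the final fully connected layer and
% classification output layer of a pretrained DAG network for new ones.
% Layers are found by class, so network-specific names are not needed.
lgraph = layerGraph(net);
layers = lgraph.Layers;

fcName = '';
clsName = '';
for k = 1:numel(layers)
    if isa(layers(k),'nnet.cnn.layer.FullyConnectedLayer')
        fcName = layers(k).Name;      % later matches overwrite earlier ones
    elseif isa(layers(k),'nnet.cnn.layer.ClassificationOutputLayer')
        clsName = layers(k).Name;
    end
end
if isempty(fcName) || isempty(clsName)
    error('replaceHead:notFound', ...
        'Could not find a fully connected and a classification output layer.');
end

% New head, configured as in the accepted answer.
fcLayer = fullyConnectedLayer(numClasses, ...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',10,'Name','FC_Layer');
clsLayer = classificationLayer('Name','OutputLayer');

lgraphNew = replaceLayer(lgraph,fcName,fcLayer);
lgraphNew = replaceLayer(lgraphNew,clsName,clsLayer);
end
```

Saved as replaceHead.m on the MATLAB path, it could be used as lgraphNew = replaceHead(resnet50,numClasses); (which needs the ResNet-50 support package), followed by analyzeNetwork(lgraphNew) to check the result before training. Networks that end in a convolution instead of a fully connected layer, such as SqueezeNet, will hit the error branch and need different handling.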
Srivardhan Gadila
2021-11-19
You can refer to the topics, examples and functions from the following pages:
More Answers (1)
David Willingham
2021-11-15
Hi,
I'd recommend running the network analyser to see where the issues are with the modifications that have been made to the network. Here's the doc page for it:
Note: You can also visualize the network in the Deep Network Designer app (the network analyser can also be launched from within it): https://www.mathworks.com/help/deeplearning/ref/deepnetworkdesigner-app.html
What information does the network analyser provide you?
Srivardhan Gadila
2021-11-16
@Sushma TV, as I said in my answer, to check your newly created network for issues you should run the analyzeNetwork function on it as well:
analyzeNetwork(layers)
Because the pretrained network itself has no issues, running analyzeNetwork on the pretrained network will not display any errors.