Faster R-CNN layer error

Brenon, 2024-7-20, 23:34
Edited: Walter Roberson, 2024-7-21, 4:30
I have been at this for a while and cannot figure it out, and I have spent several days lost in the documentation. I am attempting to train a Faster R-CNN model with a pretrained ResNet backbone. Some documentation says not to use lgraph and to use dlnetwork instead, so I tried it that way as well and got the same error. Other documentation says not to use net = resnet50 either, and to use [net,classNames] = imagePretrainedNetwork instead. The problem is that I cannot figure out how to fit all of these pieces together. The dataset can be downloaded for free here: https://www.flir.com/oem/adas/adas-dataset-form/
When training starts, the trainer reports only 3 object classes (car, light, person), although numClasses is set to 16 in the code. I also used analyzeNetwork and Deep Network Designer to inspect the layers, and the boxDeltas and R-CNN classification layers appear to have the number of outputs I expected. Any help is greatly appreciated!
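The layer check done in analyzeNetwork and Deep Network Designer can also be scripted. Below is a minimal sketch, assuming the layer graph returned by fasterRCNNLayers is held in a variable (named dlnetwork in the script in this post). fcOutputSizes is a hypothetical helper, not a MathWorks function: it lists every fully connected layer and its OutputSize, so the classification and box regression FC layers can be compared against numClasses + 1 and 4 × numClasses, the rule quoted in the error message.

```matlab
% Hypothetical helper (not a MathWorks function): return the name and
% OutputSize of every fully connected layer in a layer array, for example
% the Layers property of the LayerGraph returned by fasterRCNNLayers.
function [names, sizes] = fcOutputSizes(layers)
    isFC = false(numel(layers), 1);
    for i = 1:numel(layers)
        % Mark layers that are fully connected layers
        isFC(i) = isa(layers(i), 'nnet.cnn.layer.FullyConnectedLayer');
    end
    fcLayers = layers(isFC);
    names = string({fcLayers.Name})';   % layer names as a string column
    sizes = [fcLayers.OutputSize]';     % number of outputs of each FC layer
end
```

Usage in the script would be something like `[names, sizes] = fcOutputSizes(dlnetwork.Layers); disp(table(names, sizes))`. Checking sizes in code makes it easier to compare several variants of the network than inspecting each one in the app.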
Here is the code so far (some parts generated by ChatGPT, others taken from the official documentation). I have several versions of it with slight variations:
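%% Define the custom read function
% Note: in a MATLAB script file, local functions such as this one must be
% placed at the end of the file, after all of the script code.
% Replicates single-channel (grayscale/thermal) images to 3 channels so they
% match the 3-channel input of the ResNet-50 backbone.
function imgOut = ensureRGB(imgIn)
    [~, ~, numChannels] = size(imgIn);
    if numChannels == 1
        imgOut = repmat(imgIn, [1 1 3]);
    else
        imgOut = imgIn;
    end
end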
%% Define the paths
imageFolder = "C:\Users\User\Desktop\FLIR_Thermal_Dataset\FLIR_ADAS_v2\images_thermal_train";
annotationFolder = "C:\Users\User\Documents\MATLAB\trainingData.mat";
% MATLAB-format annotations (there is a function to convert the original data into this .mat file if anyone needs it)
matFile = "C:\Users\User\Documents\MATLAB\trainingData.mat";
%% Load the training data from the MAT-file
load(matFile, 'trainingData');
% Shuffle the training data
rng(0);
shuffledIdx = randperm(height(trainingData));
trainingData = trainingData(shuffledIdx,:);
%% Create image datastore with custom read function and specify file extensions
imds = imageDatastore(trainingData.imageFilename, ...
'ReadFcn', @(filename) ensureRGB(imread(filename)), ...
'FileExtensions', {'.jpg', '.jpeg', '.png', '.bmp'});
%% Create box label datastore
blds = boxLabelDatastore(trainingData(:, {'bbox', 'label'}));
%% Combine the datastores
ds = combine(imds, blds);
%% Verify with a sample image
sampleImg = readimage(imds, 1);
% Use names that do not shadow the height function called on trainingData above
[imgHeight, imgWidth, numChannels] = size(sampleImg);
disp(['Sample Image Number of Channels: ', num2str(numChannels)]);
%% Define number of classes
numClasses = 16;
%% Define input image size and anchor boxes
inputImageSize = [512 640 3];
anchorBoxes = [32 32; 64 64; 128 128];
%% Load the ResNet-50 network
lgraph = layerGraph(resnet50);
% Specify the feature extraction layer
featureLayer = 'activation_40_relu';
% Create the Faster R-CNN layers (fasterRCNNLayers returns a LayerGraph object)
dlnetwork = fasterRCNNLayers(inputImageSize, numClasses, anchorBoxes, lgraph, featureLayer);
%% Analyze the network to ensure all layers are correct
analyzeNetwork(dlnetwork);
%% Define training options
options = trainingOptions('sgdm', ...
'MiniBatchSize', 16, ...
'InitialLearnRate', 1e-4, ...
'MaxEpochs', 10, ...
'Verbose', true, ...
'Shuffle', 'every-epoch', ...
'Plots', 'training-progress');
% Train the network
detector = trainFasterRCNNObjectDetector(ds, dlnetwork, options);
ERROR:
Training a Faster R-CNN Object Detector for the following object classes:
* car
* light
* person
Error using trainFasterRCNNObjectDetector (line 33)
Invalid network.

Error in untitled (line 74)
detector = trainFasterRCNNObjectDetector(ds, dlnetwork, options);

Caused by:
    Layer 'boxDeltas': The input size must be 1×1×12. This R-CNN box regression layer expects the third input dimension to be 4 times the number of object classes the network should detect (3 classes). See the documentation for more details about creating Fast or Faster R-CNN networks.
    Layer 'rcnnClassification': The input size must be 1×1×4. The classification layer expects the third input dimension to be the number of object classes the network should detect (3 classes) plus 1. The additional class is required for the "background" class. See the documentation for more details about creating Fast or Faster R-CNN networks.
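The sizes in this message follow a simple rule: with K object classes, the box regression layer needs 4K input channels and the classification layer needs K + 1 (the extra one is background). The class list printed at the start of training appears to come from the training data, so K = 3 here gives 12 and 4, whereas a network built with numClasses = 16 would be expected to have heads sized 64 and 17. Below is a minimal sketch for computing K and the expected sizes from the data. expectedHeadSizes is a hypothetical helper, and it assumes the label column of trainingData is a cell array holding one categorical vector per image, the table format boxLabelDatastore accepts.

```matlab
% Hypothetical helper (not a MathWorks function): given the 'label' column of
% a box-label table (a cell array with one categorical vector per image),
% return the classes that actually occur and the channel counts the R-CNN
% heads must produce for them.
function [classNames, boxDeltaChannels, clsChannels] = expectedHeadSizes(labelColumn)
    allLabels = vertcat(labelColumn{:});   % stack the labels of every image
    classNames = unique(allLabels);        % keep only classes that occur
    K = numel(classNames);
    boxDeltaChannels = 4 * K;              % 4 box offsets per class
    clsChannels = K + 1;                   % K classes plus background
end
```

In the script this could be called as `[classNames, nBox, nCls] = expectedHeadSizes(trainingData.label)`. Passing `numel(classNames)` as numClasses to fasterRCNNLayers should, under these assumptions, make the head sizes agree with what the trainer checks.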
So far I have tried:
1) using dlnetwork instead of lgraph
2) using [net,classNames] = imagePretrainedNetwork instead of net=resnet50
3) manually changing the layers in the designer
4) converting the images from 1 channel to 3 (when loaded into my Python environment the images had three channels; in MATLAB they showed one)
5) resizing the images
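On item 5: when the images are resized, the bounding boxes have to be scaled by the same factors, or they no longer line up with the objects. Below is a minimal sketch modeled on the preprocessing helper used in the MathWorks Faster R-CNN examples. resizeImageAndBoxes is a hypothetical name, and the sketch assumes each read from the combined datastore ds returns {image, boxes, labels}. It also replicates single-channel thermal images to three channels, like ensureRGB in the script above. The two-element scale follows the [rows cols] order used in that example helper.

```matlab
% Hypothetical helper: resize one {image, boxes, labels} sample to
% targetSize and scale its boxes by the same per-dimension factors.
function data = resizeImageAndBoxes(data, targetSize)
    img = data{1};
    if size(img, 3) == 1
        img = repmat(img, [1 1 3]);            % replicate gray -> 3 channels
    end
    sz = size(img, [1 2]);                     % [rows cols] before resizing
    scale = targetSize(1:2) ./ sz;             % row and column scale factors
    data{1} = imresize(img, targetSize(1:2));  % resize the image
    data{2} = bboxresize(data{2}, scale);      % scale the boxes to match
end
```

A possible usage is `dsResized = transform(ds, @(data) resizeImageAndBoxes(data, inputImageSize));`. Doing the resize inside transform keeps each image and its boxes together, so they cannot drift out of sync.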

Answers (0)

Release: R2024a
