- augmentedImageDatastore: https://www.mathworks.com/help/deeplearning/ref/augmentedimagedatastore.html
- imageDataAugmenter: https://www.mathworks.com/help/deeplearning/ref/imagedataaugmenter.html
- Flowers dataset: https://www.kaggle.com/datasets/imsparsh/flowers-dataset
Invalid training data. For image, sequence-to-label, and feature classification tasks, responses must be categorical.
This is my code:
imds = imageDatastore("Train\", ...
'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
targetSize = [224,224];
% Resize the images in the training and validation sets
imdsTrainResized = transform(imdsTrain, @(x) imresize(x, targetSize));
imdsValidationResized = transform(imdsValidation, @(x) imresize(x, targetSize));
% Convert labels to categorical for each underlying ImageDatastore
imdsTrain.Labels = categorical(imdsTrain.Labels);
imdsTrainResized.UnderlyingDatastores{1,1}.Labels = categorical(imdsTrainResized.UnderlyingDatastores{1,1}.Labels);
imdsValidation.Labels = categorical(imdsValidation.Labels);
imdsValidationResized.UnderlyingDatastores{1,1}.Labels = categorical(imdsValidationResized.UnderlyingDatastores{1,1}.Labels);
% Combine the resized datastores with the original datastores
imdsTrainCombined = combine(imdsTrain, imdsTrainResized);
imdsValidationCombined = combine(imdsValidation, imdsValidationResized);
% Train the network
net = mobilenetv2('Weights','none');
miniBatchSize = 10;
valFrequency = floor(numel(imdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
'MiniBatchSize',miniBatchSize, ...
'MaxEpochs',6, ...
'InitialLearnRate',3e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidationCombined, ...
'ValidationFrequency',valFrequency, ...
'Verbose',false, ...
'Plots','training-progress',....
'ExecutionEnvironment', "auto");
Trained_net = trainNetwork(imdsTrainCombined, net, options);
Answers (1)
Saarthak Gupta
2023-12-27
Hi Susitra,
I understand you are getting the error "Invalid training data. For image, sequence-to-label, and feature classification tasks, responses must be categorical." when trying to train a MobileNetV2 network on your data.
Since the original data was not provided, I used the "Flowers" dataset from Kaggle to reproduce the error.
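The problem is in how the datastores are built. `transform(imdsTrain, @(x) imresize(x, targetSize))` returns only the resized image and drops the label, and `combine(imdsTrain, imdsTrainResized)` then pairs each original image with its resized copy. `trainNetwork` treats the second column of the combined datastore as the responses, so it receives an image where it expects a categorical label. (Labels created with 'LabelSource','foldernames' are already categorical, so the `categorical(...)` conversions do not change anything.) To fix this, consider using the `augmentedImageDatastore` function from the Deep Learning Toolbox, which resizes and optionally augments the images while keeping each image paired with its categorical label.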
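To make the failure visible, here is a minimal sketch that builds a hypothetical two-class toy dataset of random RGB images in `tempdir` (the folder names `toyFlowers`, `daisy`, and `rose` are made up for illustration) and reads one sample from the question's `transform` + `combine` pattern. It also shows one possible alternative that keeps `transform` but returns `{image, label}` pairs by reading the label from the `info` struct with `'IncludeInfo', true`. It assumes `imresize` is available, as it is in the question's code.

```matlab
% Hypothetical toy dataset: two class folders of random RGB images with
% different sizes (stand-ins for the real training folders).
root = fullfile(tempdir, 'toyFlowers');
classes = {'daisy', 'rose'};
for k = 1:numel(classes)
    folder = fullfile(root, classes{k});
    if ~exist(folder, 'dir')
        mkdir(folder);
    end
    for n = 1:3
        img = uint8(randi(255, 100 + 10*n, 120, 3));
        imwrite(img, fullfile(folder, sprintf('img%d.png', n)));
    end
end
imds = imageDatastore(root, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
targetSize = [224 224];

% Pattern from the question: transform() returns only the image, and
% combine() pairs it with the original image, so column 2 is an image.
resized = transform(imds, @(x) imresize(x, targetSize));
bad = combine(imds, resized);
badSample = read(bad);            % 1x2 cell: {original image, resized image}
disp(class(badSample{2}))         % uint8, not categorical

% Alternative: take the label from the info struct and return {image, label}.
good = transform(imds, ...
    @(x, info) deal({imresize(x, targetSize), info.Label}, info), ...
    'IncludeInfo', true);
goodSample = read(good);          % 1x2 cell: {224x224x3 image, categorical label}
disp(class(goodSample{2}))        % categorical
```

A datastore like `good` can be passed to `trainNetwork` directly, but `augmentedImageDatastore` does the same pairing for you and adds augmentation options.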
Refer to the following code:
imds = imageDatastore('flowers', ...
'IncludeSubfolders', true, ...
'LabelSource', 'foldernames');
% Define output image size for the network. MobileNetV2 takes inputs of
% size [224 224 3]
inputSize = [224 224];
% Split the datastore into training and validation sets.
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');
augmenter = imageDataAugmenter( ...
'RandXReflection', true, ...
'RandXTranslation', [-10 10], ...
'RandYTranslation', [-10 10], ...
'RandRotation', [-20 20], ...
'RandScale', [0.8 1.2]);
% Resize and augment the training images; only resize the validation images.
augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain, 'DataAugmentation', augmenter);
augimdsValidation = augmentedImageDatastore(inputSize, imdsValidation);
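% With 'Weights','none', mobilenetv2 returns an untrained layer graph whose
% final fully connected layer has 1000 outputs (ImageNet classes). Replace it
% and the classification layer so the output matches the number of classes.
lgraph = mobilenetv2('Weights','none');
numClasses = numel(categories(imdsTrain.Labels));
lgraph = replaceLayer(lgraph, 'Logits', fullyConnectedLayer(numClasses, 'Name', 'Logits'));
lgraph = replaceLayer(lgraph, 'ClassificationLayer_Logits', classificationLayer('Name', 'ClassificationLayer_Logits'));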
% Set training options.
miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files) / miniBatchSize);
options = trainingOptions('sgdm', ...
'MiniBatchSize', miniBatchSize, ...
'MaxEpochs', 6, ...
'InitialLearnRate', 3e-4, ...
'Shuffle', 'every-epoch', ...
'ValidationData', augimdsValidation, ...
'ValidationFrequency', valFrequency, ...
'Verbose', false, ...
'Plots', 'training-progress', ...
'ExecutionEnvironment', "auto");
Trained_net = trainNetwork(augimdsTrain, lgraph, options);
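As a quick check before training, the sketch below reads one mini-batch from an `augmentedImageDatastore` and confirms that it yields a table whose `input` column holds 224-by-224-by-3 images and whose `response` column is categorical. It again uses a hypothetical toy dataset written to `tempdir` (folder names are made up), and it adds the `'ColorPreprocessing','gray2rgb'` option, which matters only if some of your images are grayscale, since MobileNetV2 expects three channels.

```matlab
% Hypothetical toy dataset: each class folder holds one RGB and one
% grayscale image of different sizes.
root = fullfile(tempdir, 'toyFlowersAug');
classes = {'daisy', 'rose'};
for k = 1:numel(classes)
    folder = fullfile(root, classes{k});
    if ~exist(folder, 'dir')
        mkdir(folder);
    end
    imwrite(uint8(randi(255, 150, 200, 3)), fullfile(folder, 'rgb.png'));   % RGB
    imwrite(uint8(randi(255, 90, 120)), fullfile(folder, 'gray.png'));      % grayscale
end
imds = imageDatastore(root, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');

augmenter = imageDataAugmenter('RandXReflection', true);
augimds = augmentedImageDatastore([224 224], imds, ...
    'DataAugmentation', augmenter, ...
    'ColorPreprocessing', 'gray2rgb');   % expand 1-channel images to 3 channels

% read returns a table with variables "input" (images) and "response" (labels).
batch = read(augimds);
disp(class(batch.response))   % categorical
```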
For further details, see the MATLAB documentation for `augmentedImageDatastore` and `imageDataAugmenter`.