How to add augmentation to training and validation data?

Hi,
I am trying to resize and add noise to the training and validation images used to train AlexNet, but I keep getting the following error:
The value of 'ValidationData' is invalid. Invalid transform function defined on datastore.
opts = nnet.cnn.TrainingOptionsSGDM(varargin{:});
Caused by:
Too many input arguments.
How can I correct this?
Here is the code I am using:
%load network parameters
params = load("C:\Users\namit\MATLAB Drive\Individualproject\params_2023_02_13__20_34_32.mat");
%load data
digitDatasetPath = fullfile('Tomato - Copy/');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
labelCount = countEachLabel(imds)
img = readimage(imds,1);
size(img)
%Divide the data into training and validation data sets
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomize');
%resize images
augimdsTrain = augmentedImageDatastore([227 227 3],imdsValidation);
augimdsValidation = augmentedImageDatastore([227 227 3],imdsTrain);
%apply noise to images
dsTrain = transform(augimdsTrain,@preprocessForTraining, IncludeInfo=true);
dsValidation = transform(augimdsValidation,@preprocessForTraining, IncludeInfo=true);
%Define the convolutional neural network architecture.
layers = [
imageInputLayer([227 227 3],"Name","data","Mean",params.data.Mean)
convolution2dLayer([11 11],96,"Name","conv1","BiasLearnRateFactor",2,"Stride",[4 4],"Bias",params.conv1.Bias,"Weights",params.conv1.Weights)
reluLayer("Name","relu1")
crossChannelNormalizationLayer(5,"Name","norm1","K",1)
maxPooling2dLayer([3 3],"Name","pool1","Stride",[2 2])
groupedConvolution2dLayer([5 5],128,2,"Name","conv2","BiasLearnRateFactor",2,"Padding",[2 2 2 2],"Bias",params.conv2.Bias,"Weights",params.conv2.Weights)
reluLayer("Name","relu2")
crossChannelNormalizationLayer(5,"Name","norm2","K",1)
maxPooling2dLayer([3 3],"Name","pool2","Stride",[2 2])
convolution2dLayer([3 3],384,"Name","conv3","BiasLearnRateFactor",2,"Padding",[1 1 1 1],"Bias",params.conv3.Bias,"Weights",params.conv3.Weights)
reluLayer("Name","relu3")
groupedConvolution2dLayer([3 3],192,2,"Name","conv4","BiasLearnRateFactor",2,"Padding",[1 1 1 1],"Bias",params.conv4.Bias,"Weights",params.conv4.Weights)
reluLayer("Name","relu4")
groupedConvolution2dLayer([3 3],128,2,"Name","conv5","BiasLearnRateFactor",2,"Padding",[1 1 1 1],"Bias",params.conv5.Bias,"Weights",params.conv5.Weights)
reluLayer("Name","relu5")
maxPooling2dLayer([3 3],"Name","pool5","Stride",[2 2])
fullyConnectedLayer(4096,"Name","fc6")
reluLayer("Name","relu6")
dropoutLayer(0.5,"Name","drop6")
fullyConnectedLayer(4096,"Name","fc7")
reluLayer("Name","relu7")
dropoutLayer(0.5,"Name","drop7")
fullyConnectedLayer(6,"Name","fc8")
softmaxLayer("Name","prob")
classificationLayer("Name","classoutput")];
options = trainingOptions("sgdm",...
"ExecutionEnvironment","auto",...
"InitialLearnRate",0.001,...
"MaxEpochs",10,...
"Shuffle","every-epoch",...
"Plots","training-progress",...
"ValidationData",dsValidation);
%train network
net = trainNetwork(dsTrain,layers,options);
%function used
function dataOut = preprocessForTraining(data)
dataOut = data;
for idx = 1:size(data,1)
dataOut{idx} = imnoise(data{idx},'salt & pepper');
end
end

Answers (1)

Ashu 2023-2-21
Check the following points to correct your code.
  • There is a logic error where you resize the images: the training and validation datastores are swapped. Use the following instead:
augimdsTrain = augmentedImageDatastore([227 227 3],imdsTrain);
augimdsValidation = augmentedImageDatastore([227 227 3],imdsValidation);
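  • Because you set 'IncludeInfo' to true, 'transform' calls your function with two inputs (the data and its info) and expects two outputs. 'preprocessForTraining' accepts only one input, which is what produces "Too many input arguments". In this case, the transformation function must have this signature: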
function [dataOut,infoOut] = transformFcn(ds1_data,ds2_data,...dsN_data,ds1_info,ds2_info...dsN_info)
..
end
  • If you don't want to change your 'preprocessForTraining' function, you can instead remove the 'IncludeInfo' argument from the call to 'transform'.
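For a single datastore, the two-input, two-output signature described above might look like the sketch below. The function name addNoiseWithInfo and the "healthy" label are hypothetical placeholders. The sketch also assumes each batch arrives as a table whose image column is named input, which is the format documented for reading from an augmentedImageDatastore; if your data arrives as a cell array, the indexing would need to change.

```matlab
% Synthetic one-row batch shaped like the table that reading from an
% augmentedImageDatastore is assumed to return: an 'input' column holding
% images in cells and a 'response' column holding labels.
batch = table({uint8(128*ones(8,8,3))}, categorical("healthy"), ...
    'VariableNames', {'input','response'});

% Call the function the same way transform would with IncludeInfo=true.
[noisy, infoOut] = addNoiseWithInfo(batch, struct());
disp(size(noisy.input{1}))

function [dataOut, infoOut] = addNoiseWithInfo(data, info)
% Hypothetical transform function for use with IncludeInfo=true.
% Adds salt & pepper noise (imnoise's default density) to every image
% in the batch and passes the info through unchanged.
dataOut = data;
for idx = 1:height(data)
    dataOut.input{idx} = imnoise(data.input{idx}, 'salt & pepper');
end
infoOut = info;
end
```

With a function like this, the original call transform(augimdsTrain,@addNoiseWithInfo,IncludeInfo=true) should no longer fail on the number of arguments. If you drop 'IncludeInfo' and keep a one-input function instead, the same table indexing (data.input{idx}) would likely still be needed under the assumption above.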
To learn more about image data augmentation, you can refer to the following link