% Build a labelled image datastore from the dataset folder. Subfolders are
% scanned as well, and each image is labelled with the name of the folder
% that contains it.
imds = imageDatastore( ...
    'C:\Users\Rayan\Desktop\9_8_balance_data\R_9_1_GSM_3', ...
    'LabelSource', 'foldernames', ...
    'IncludeSubfolders', true);

% Randomly assign 77% of the images of every label to the training set;
% the remaining images of each label form the validation set.
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.77, 'randomized');

% Number of images in the training set (one label per image).
numTrainImages = numel(imdsTrain.Labels);
% Load the pretrained ResNet-50 network and convert it to a layer graph so
% that its final layers can be swapped for ones sized to the new classes.
net = resnet50;
inputSize = net.Layers(1).InputSize;   % image size expected by the input layer
lgraph = layerGraph(net);

% findLayersToReplace is a supporting function shipped with the MATLAB
% examples, not a toolbox function. Put its folder on the path when it is
% not already reachable, rather than opening the file in the editor on
% every run (which does not make the function callable).
if exist('findLayersToReplace','file') ~= 2
    addpath(fullfile(matlabroot,'examples','nnet','main'));
end
[learnableLayer,classLayer] = findLayersToReplace(lgraph);

% One output per class present in the training labels.
numClasses = numel(categories(imdsTrain.Labels));

% Create a replacement for the last learnable layer, matching its type.
% The learn-rate factors of 10 make the new layer learn faster than the
% pretrained layers that are kept.
if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
else
    % Fail here with a clear message instead of hitting an
    % "undefined variable" error at the replaceLayer call below.
    error('transferLearning:unsupportedLayer', ...
        'Cannot replace learnable layer of unsupported class "%s".', ...
        class(learnableLayer));
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

% Replace the classification output layer with a fresh one.
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
% Keep the modified graph's layers and connections in the workspace.
% NOTE(review): neither variable is used in the lines visible here --
% confirm whether later code (e.g. layer freezing) relies on them.
layers = lgraph.Layers;
connections = lgraph.Connections;

% Wrap both datastores so images are resized to the network's input
% height and width when they are read. The trailing semicolons keep the
% datastore objects from being printed to the Command Window.
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
% Training configuration. miniBatchSize is defined once and used both for
% the validation frequency and for the training options, so the two cannot
% drift apart.
miniBatchSize = 10;

% Validate about once per epoch (iterations per epoch = images / batch size).
% Clamp to 1 because ValidationFrequency must be a positive integer, and the
% quotient floors to 0 when there are fewer images than one mini-batch.
valFrequency = max(1, floor(numel(augimdsTrain.Files)/miniBatchSize));

options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',60, ...
    'InitialLearnRate',0.00065, ...
    'Shuffle','every-epoch', ...
    'ValidationFrequency',valFrequency, ...
    'ValidationData',augimdsValidation, ...
    'Verbose',false, ...
    'Plots','training-progress');

% Fine-tune the modified layer graph on the training images; the result
% overwrites the pretrained network held in net.
net = trainNetwork(augimdsTrain,lgraph,options);