Need to improve semantic segmentation using a deep network and transfer learning

I'm trying to use a pre-trained network to do transfer learning for image segmentation to detect different terrain types, following the example in Multispectral Semantic Segmentation Using Deep Learning.
I've got a pre-trained U-Net, trained on the data from the example, that performs decently (accuracy in the 70-90% range).
%data and helper functions from the example page
imds = imageDatastore("train_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
pxds = pixelLabelDatastore("train_labels.png",classNames,pixelLabelIds);
dsTrain = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=1000);
dsTrain2 = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=500); %use during transfer learning
val_ds = imageDatastore("val_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
val_pxds = pixelLabelDatastore("val_labels.png",classNames,pixelLabelIds);
dsVal2 = randomPatchExtractionDatastore(val_ds,val_pxds,[256,256],PatchesPerImage=100);%use during transfer learning
lgraph = unetLayers([256,256,6], 18, 'EncoderDepth', 4);
options = trainingOptions("sgdm",...
InitialLearnRate=0.05, ...
Momentum=0.9,...
L2Regularization=0.0001,...
MaxEpochs=10,...
MiniBatchSize=8,...
LearnRateSchedule="piecewise",...
Shuffle="every-epoch",...
GradientThresholdMethod="l2norm",...
GradientThreshold=0.05, ...
Plots="training-progress", ...
VerboseFrequency=20);
net = trainNetwork(dsTrain,lgraph,options);
save("my_multispectralUnet.mat", "net");
Results from training: (training-progress plot not included)
To test the transfer learning part, I replace the final three layers.
data = load("my_multispectralUnet.mat");
basic_trained = data.net;
layersTransfer = basic_trained.Layers(1:end-3);
numClasses = 18; %there are 18 classes in the data
layers = [
layersTransfer
convolution2dLayer(1,numClasses, 'BiasL2Factor',0,'Padding', 'same','Name','new_Final-ConvolutionLayer', ...
'WeightLearnRateFactor', 10, 'BiasLearnRateFactor',10);
softmaxLayer('Name','new_final_softmax')
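pixelClassificationLayer('Name', 'new_final_classification')];
%build a layer graph from the layer array (consecutive layers are connected in order)
lgraph = layerGraph(layers);
%have to add back in the skip connections that make it U-shaped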
lgraph = connectLayers(lgraph, 'Encoder-Stage-1-ReLU-2','Decoder-Stage-4-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-2-ReLU-2','Decoder-Stage-3-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-3-ReLU-2','Decoder-Stage-2-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-4-DropOut','Decoder-Stage-1-DepthConcatenation/in2');
options = trainingOptions('sgdm', ...
'MiniBatchSize',8, ...
'MaxEpochs',5, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',dsVal2, ...
'ValidationFrequency',floor(100/8), ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(dsTrain2,lgraph,options);
save("trained_616.mat", "netTransfer");
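A possibly simpler way to do the head swap is sketched below as a hypothetical helper, replaceUnetHead (not part of any toolbox). It converts the trained network with layerGraph and calls replaceLayer on the three head layers. Because replaceLayer keeps the replaced layer's connections, the encoder-to-decoder skip connections never have to be re-added by hand. The sketch assumes the default head names that unetLayers assigns ('Final-ConvolutionLayer', 'Softmax-Layer', 'Segmentation-Layer'); check them with analyzeNetwork if your network differs.

```matlab
function lgraph = replaceUnetHead(net, numClasses)
% replaceUnetHead (hypothetical helper): swaps the classification head of a
% U-Net created with unetLayers while keeping every connection, including
% the encoder-to-decoder skip connections.
% Assumes the default unetLayers names for the three head layers.
if isa(net, 'nnet.cnn.LayerGraph')
    lgraph = net;               % already a layer graph (e.g. from unetLayers)
else
    lgraph = layerGraph(net);   % trained DAGNetwork -> editable layer graph
end
% New 1x1 convolution with randomly initialized weights and a higher learn rate
newConv = convolution2dLayer(1, numClasses, ...
    'Padding', 'same', 'BiasL2Factor', 0, ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10, ...
    'Name', 'new_Final-ConvolutionLayer');
% replaceLayer keeps the input and output connections of the replaced layer
lgraph = replaceLayer(lgraph, 'Final-ConvolutionLayer', newConv);
lgraph = replaceLayer(lgraph, 'Softmax-Layer', ...
    softmaxLayer('Name', 'new_final_softmax'));
lgraph = replaceLayer(lgraph, 'Segmentation-Layer', ...
    pixelClassificationLayer('Name', 'new_final_classification'));
end
```

Usage, assuming basic_trained, dsTrain2 and options as defined in the post: lgraph = replaceUnetHead(basic_trained, 18); netTransfer = trainNetwork(dsTrain2, lgraph, options);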
I am continuing to train with patches taken from the same dataset (as a proof of concept), so I expect the network to remain at least as good as it was originally. Instead, it is much worse.
Can anyone share insights as to what is going wrong? Am I specifying some parameters in the transfer part incorrectly?
Thanks
Ben 2023-6-20
@Allison - I suspect that replacing the final convolution is causing this issue. That layer has learnable parameters that were already trained, and they are being replaced with randomly initialized values.
This causes two problems. First, because those weights are now random, the network's outputs will not be accurate. Second, during training the resulting large errors will cause large updates to the rest of the network's pre-trained weights.
The first problem is unavoidable if you want to transfer from one set of pixel classes to another, but the second can be prevented by setting WeightLearnRateFactor and BiasLearnRateFactor for all the pre-trained layers to a small value, or even 0. I expect that setting those factors to 0 will also make the transfer-learning training faster.
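Setting the learn-rate factors of the pre-trained layers to 0 can be sketched as follows. freezeAllExcept is a hypothetical helper, similar in spirit to the freezeWeights function from the MathWorks transfer-learning example but not the same code: it zeroes every property whose name ends in LearnRateFactor on every layer except those listed in keepNames, and writes each modified layer back with replaceLayer so the U-Net connections stay untouched.

```matlab
function lgraph = freezeAllExcept(lgraph, keepNames)
% freezeAllExcept (hypothetical helper): sets every *LearnRateFactor
% property to 0 on all layers of a layer graph except those named in
% keepNames, so trainNetwork only updates the listed layers.
layerNames = {lgraph.Layers.Name};
for k = 1:numel(layerNames)
    name = layerNames{k};
    if any(strcmp(name, keepNames))
        continue  % leave the new head layers trainable
    end
    % Look the layer up by name, since replaceLayer may reorder lgraph.Layers
    layer = lgraph.Layers(strcmp({lgraph.Layers.Name}, name));
    props = properties(layer);
    isFactor = endsWith(props, "LearnRateFactor");
    if ~any(isFactor)
        continue  % no learnable parameters (ReLU, pooling, concatenation, ...)
    end
    for p = find(isFactor)'
        layer.(props{p}) = 0;  % freeze this parameter
    end
    lgraph = replaceLayer(lgraph, name, layer);
end
end
```

Usage, assuming lgraph already contains the new head layer named 'new_Final-ConvolutionLayer' (the softmax and pixel classification layers have no learnable parameters): lgraph = freezeAllExcept(lgraph, {'new_Final-ConvolutionLayer'}); netTransfer = trainNetwork(dsTrain2, lgraph, options); With the factors at 0, the large early errors from the randomly initialized head can no longer move the pre-trained encoder and decoder weights.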
Allison 2023-6-21
Thank you! I just started the training, and can already tell that freezing the transferred layers (rather than giving them very low learning rates) has helped considerably. And it is indeed going faster, too.
For future reference for folks, I used the freezeWeights support function from the example here: https://www.mathworks.com/help/deeplearning/ug/train-deep-learning-network-to-classify-new-images.html


Category: Deep Learning Toolbox

Release: R2021a
