Need to improve semantic segmentation using a deep network and transfer learning
I'm trying to use a pre-trained network for transfer learning in image segmentation, to detect different terrain types, following the example Multispectral Semantic Segmentation Using Deep Learning.
I've trained a U-Net on the data from that example, and it performs decently (accuracy in the 70-90% range).
%data, the matRead6Channels helper, classNames and pixelLabelIds all come from the example page
imds = imageDatastore("train_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
pxds = pixelLabelDatastore("train_labels.png",classNames,pixelLabelIds);
dsTrain = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=1000);
dsTrain2 = randomPatchExtractionDatastore(imds,pxds,[256,256],PatchesPerImage=500); %use during transfer learning
val_ds = imageDatastore("val_data.mat",FileExtensions=".mat",ReadFcn=@matRead6Channels);
val_pxds = pixelLabelDatastore("val_labels.png",classNames,pixelLabelIds);
dsVal2 = randomPatchExtractionDatastore(val_ds,val_pxds,[256,256],PatchesPerImage=100);%use during transfer learning
lgraph = unetLayers([256,256,6], 18, 'EncoderDepth', 4);
options = trainingOptions("sgdm",...
InitialLearnRate=0.05, ...
Momentum=0.9,...
L2Regularization=0.0001,...
MaxEpochs=10,...
MiniBatchSize=8,...
LearnRateSchedule="piecewise",...
Shuffle="every-epoch",...
GradientThresholdMethod="l2norm",...
GradientThreshold=0.05, ...
Plots="training-progress", ...
VerboseFrequency=20);
net = trainNetwork(dsTrain,lgraph,options);
save("my_multispectralUnet.mat", "net");
Results from training: [training-progress plot not preserved]

To test the transfer learning part, I replace the final three layers.
data = load("my_multispectralUnet.mat");
basic_trained = data.net;
layersTransfer = basic_trained.Layers(1:end-3); %drop the final convolution, softmax and pixel classification layers
numClasses = 18; %there are 18 classes in the data
layers = [
layersTransfer
convolution2dLayer(1,numClasses, 'BiasL2Factor',0,'Padding', 'same','Name','new_Final-ConvolutionLayer', ...
'WeightLearnRateFactor', 10, 'BiasLearnRateFactor',10)
softmaxLayer('Name','new_final_softmax')
pixelClassificationLayer('Name', 'new_final_classification')];
%layerGraph connects the layers in sequence, so the skip connections that make it U-shaped have to be added back
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, 'Encoder-Stage-1-ReLU-2','Decoder-Stage-4-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-2-ReLU-2','Decoder-Stage-3-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-3-ReLU-2','Decoder-Stage-2-DepthConcatenation/in2');
lgraph = connectLayers(lgraph, 'Encoder-Stage-4-DropOut','Decoder-Stage-1-DepthConcatenation/in2');
options = trainingOptions('sgdm', ...
'MiniBatchSize',8, ...
'MaxEpochs',5, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',dsVal2, ...
'ValidationFrequency',floor(100/8), ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(dsTrain2,lgraph,options);
save("trained_616.mat", "netTransfer");
As a proof of concept, I am continuing training with patches taken from the same dataset, so I expected the network to remain at least as good as it was originally. Instead, it is much worse.
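Since the number of classes is unchanged (18 before and after), a minimal baseline sketch is to keep the trained head and simply continue training the saved network. It assumes `dsTrain2`, `dsVal2` and the transfer-learning `options` from the code above are still in the workspace; `netContinued` is a hypothetical name. If validation accuracy stays close to the original here, that would suggest the drop comes from replacing the final layers rather than from the data.

```matlab
% Baseline sketch: fine-tune the saved U-Net without replacing any layers.
% Assumes dsTrain2, dsVal2 and the transfer-learning "options" from the question exist.
data = load("my_multispectralUnet.mat");

% layerGraph(net) keeps the trained weights and the U-Net skip connections,
% so no connectLayers calls are needed.
lgraph = layerGraph(data.net);

% trainNetwork starts from the existing weights instead of re-initializing them.
netContinued = trainNetwork(dsTrain2, lgraph, options);  % hypothetical variable name
```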

Can anyone share insights as to what is going wrong? Am I specifying some parameters in the transfer part incorrectly?
Thanks
5 Comments
Ben
2023-6-20
@Allison - I suspect that replacing the final convolution is causing this issue: that layer has learnable parameters that were already trained, and they are being replaced with randomly initialized values.
This causes two issues. First, because those weights are now randomly initialized, the network's outputs will not be accurate. Second, during training the resulting large errors cause large updates to the rest of the network's pre-trained weights.
The first issue is unavoidable if you want to perform transfer learning from one set of pixel classes to another, but the second can be prevented by setting WeightLearnRateFactor and BiasLearnRateFactor for all the pre-trained layers to a small value, or even 0. I expect that setting those factors to 0 would also make the transfer learning training quicker.
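A minimal sketch of this suggestion, assuming the `my_multispectralUnet.mat` file saved in the question and the default `unetLayers` layer names `Final-ConvolutionLayer` and `Segmentation-Layer` (these names are an assumption; check them with `data.net.Layers(end-2:end)` first). It starts from `layerGraph(data.net)` so the skip connections are kept, sets both learn-rate factors to 0 on every pre-trained layer that has them, and then swaps in a new, trainable head with `replaceLayer`.

```matlab
% Sketch: freeze the pre-trained U-Net and train only a new head.
% Assumes my_multispectralUnet.mat from the question and default unetLayers names.
data = load("my_multispectralUnet.mat");
lgraph = layerGraph(data.net);   % keeps trained weights and skip connections

% Freeze every pre-trained layer that has learnable weights (convolutions and
% transposed convolutions); ReLU, pooling and dropout layers have no such property.
pretrained = lgraph.Layers;
for i = 1:numel(pretrained)
    layer = pretrained(i);
    if isprop(layer, 'WeightLearnRateFactor')
        layer.WeightLearnRateFactor = 0;   % 0 freezes; a small value fine-tunes slowly
        layer.BiasLearnRateFactor = 0;
        lgraph = replaceLayer(lgraph, layer.Name, layer);
    end
end

% Replace the final 1x1 convolution with a new, randomly initialized one
% that learns faster than the (frozen) rest of the network.
numClasses = 18;
newConv = convolution2dLayer(1, numClasses, Name="new_Final-ConvolutionLayer", ...
    WeightLearnRateFactor=10, BiasLearnRateFactor=10);
lgraph = replaceLayer(lgraph, 'Final-ConvolutionLayer', newConv);

% A new classification layer is only strictly needed if the classes change.
lgraph = replaceLayer(lgraph, 'Segmentation-Layer', ...
    pixelClassificationLayer(Name="new_final_classification"));
```

The resulting `lgraph` can then be passed to `trainNetwork(dsTrain2, lgraph, options)` as in the question.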
Answers (0)
Category: Deep Learning Toolbox