
Train a Semantic Segmentation Network

Load the training data.

dataSetDir = fullfile(toolboxdir("vision"),"visiondata","triangleImages");
imageDir = fullfile(dataSetDir,"trainingImages");
labelDir = fullfile(dataSetDir,"trainingLabels");

Create an image datastore for the images.

imds = imageDatastore(imageDir);

Create a pixelLabelDatastore for the ground truth pixel labels.

classNames = ["triangle" "background"];
labelIDs   = [255 0];
pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);
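The pairing of classNames and labelIDs can be illustrated without the datastore. The following is a minimal sketch, assuming a small hypothetical label image (labelImage is made up for illustration and is not part of the data set). It uses the base MATLAB categorical function to map each pixel value to a class name in the same order as the two vectors above.

```matlab
% Hypothetical 3-by-3 label image; real label images are read from labelDir.
labelImage = [0 0 255; 0 255 255; 0 0 0];

classNames = ["triangle" "background"];
labelIDs   = [255 0];

% Map pixel value 255 to "triangle" and pixel value 0 to "background".
C = categorical(labelImage,labelIDs,classNames);

% Count the pixels assigned to the triangle class.
numTriangle = nnz(C == "triangle");
```

The order matters: the first entry of labelIDs is paired with the first entry of classNames, and so on.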

Visualize training images and ground truth pixel labels.

I = read(imds);
C = read(pxds);

I = imresize(I,5,"nearest");
L = imresize(uint8(C{1}),5,"nearest");
imshowpair(I,L,"montage")


Combine the image and pixel label datastores for training.

trainingData = pixelLabelImageDatastore(imds,pxds);

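Create a simple semantic segmentation network based on a downsampling and upsampling design. A max pooling layer halves the spatial resolution, and a transposed convolution layer restores it, so the network produces one class prediction for each input pixel.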

numFilters = 64;
filterSize = 3;
numClasses = 2;
layers = [
    imageInputLayer([32 32 1])
    convolution2dLayer(filterSize,numFilters,Padding=1)
    reluLayer()
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(filterSize,numFilters,Padding=1)
    reluLayer()
    transposedConv2dLayer(4,numFilters,Stride=2,Cropping=1)
    convolution2dLayer(1,numClasses)
    softmaxLayer()
    ];
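To see why the output has the same spatial size as the 32-by-32 input, the layer sizes can be worked through by hand. The following is a minimal sketch that applies the standard output-size formulas for convolution, pooling, and transposed convolution to the layer settings above. The variable names are illustrative and are not part of the example.

```matlab
inputSize = 32;

% 3-by-3 convolution, stride 1, padding 1: size is preserved.
conv1Out = floor((inputSize + 2*1 - 3)/1) + 1;

% 2-by-2 max pooling, stride 2: size is halved.
poolOut = floor((conv1Out - 2)/2) + 1;

% Second 3-by-3 convolution, stride 1, padding 1: size is preserved.
conv2Out = floor((poolOut + 2*1 - 3)/1) + 1;

% 4-by-4 transposed convolution, stride 2, cropping 1: size is doubled.
upOut = 2*(conv2Out - 1) + 4 - 2*1;

fprintf("%d -> %d -> %d -> %d -> %d\n", ...
    inputSize,conv1Out,poolOut,conv2Out,upOut)
```

The final 1-by-1 convolution changes only the number of channels (from numFilters to numClasses), so the softmax layer outputs class probabilities at the input resolution.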

Set up training options.

opts = trainingOptions("sgdm", ...
    InitialLearnRate=1e-3, ...
    MaxEpochs=100, ...
    MiniBatchSize=64);

Define a loss function suitable for pixel classification.

function loss = modelLoss(Y,T)
    mask = ~isnan(T);
    T(isnan(T)) = 0;
    loss = crossentropy(Y,T,Mask=mask,NormalizationFactor="mask-included");
end
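The masking idea in modelLoss can be illustrated with plain arrays. The following is a minimal sketch, assuming a hypothetical set of three pixels in which one pixel has no label, so its one-hot target entries are NaN. The sketch divides by the number of labeled pixels, which is an assumption made for illustration and may differ from how crossentropy applies NormalizationFactor="mask-included".

```matlab
% Hypothetical class probabilities (rows: classes, columns: pixels).
Y = [0.9 0.2 0.5;
     0.1 0.8 0.5];

% One-hot targets. The third pixel is unlabeled, so its entries are NaN.
T = [1 0 NaN;
     0 1 NaN];

% Exclude unlabeled entries, then replace NaN so the arithmetic stays finite.
mask = ~isnan(T);
T(isnan(T)) = 0;

% Cross-entropy contribution of each element; masked entries contribute 0.
perElement = -mask .* T .* log(Y);

% Assumption for this sketch: normalize by the number of labeled pixels.
numLabeledPixels = nnz(all(mask,1));
loss = sum(perElement,"all")/numLabeledPixels;
```

Without the mask, the NaN targets would propagate into the loss and make it NaN for the entire mini-batch.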

Train the network.

net = trainnet(trainingData,layers,@modelLoss,opts);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss
    _________    _____    ___________    _________    ____________
            1        1       00:00:06        0.001          41.892
           50       17       00:00:21        0.001         0.93931
          100       34       00:00:35        0.001          0.7432
          150       50       00:00:52        0.001          0.4558
          200       67       00:01:10        0.001         0.48874
          250       84       00:01:32        0.001         0.43741
          300      100       00:01:44        0.001         0.32055
Training stopped: Max epochs completed
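The Iteration and Epoch columns of the training log are related by the number of iterations per epoch. The following is a minimal sketch that recovers that number from the last row of the log above and checks it against the other rows. The variable names are illustrative.

```matlab
% Values copied from the Iteration and Epoch columns of the training log.
iterationsLogged = [1 50 100 150 200 250 300];
epochsLogged     = [1 17 34 50 67 84 100];

% 300 iterations over 100 epochs gives 3 iterations per epoch.
iterationsPerEpoch = iterationsLogged(end)/epochsLogged(end);

% The epoch containing a given iteration.
epochsComputed = ceil(iterationsLogged/iterationsPerEpoch);

disp(isequal(epochsComputed,epochsLogged))
```

With MiniBatchSize=64, three iterations per epoch means each epoch processes three mini-batches of training images.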

Read and display a test image.

testImage = imread("triangleTest.jpg");
imshow(testImage)


Segment the test image and display the results.

C = semanticseg(testImage,net,Classes=classNames);
B = labeloverlay(testImage,C);
imshow(B)
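The categorical output of semanticseg can also be compared against ground truth labels numerically. The following is a minimal sketch, assuming small hypothetical label maps (predicted and groundTruth are made up for illustration and are not results from this example). It computes the intersection over union (IoU) for the triangle class and the overall pixel accuracy using base MATLAB functions only.

```matlab
classNames = ["triangle" "background"];

% Hypothetical 2-by-3 predicted and ground truth label maps.
predicted = categorical( ...
    ["triangle" "triangle" "background"; ...
     "background" "background" "background"],classNames);
groundTruth = categorical( ...
    ["triangle" "background" "background"; ...
     "triangle" "background" "background"],classNames);

% Binary masks for the triangle class.
predMask  = predicted == "triangle";
truthMask = groundTruth == "triangle";

% IoU: pixels in both masks divided by pixels in either mask.
iou = nnz(predMask & truthMask)/nnz(predMask | truthMask);

% Fraction of pixels whose predicted label matches the ground truth.
pixelAccuracy = mean(predicted == groundTruth,"all");
```

IoU is often more informative than pixel accuracy when one class, such as the background here, covers most of the image.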

