
Semantic Segmentation Using Dilated Convolutions

Train a semantic segmentation network using dilated convolutions.

A semantic segmentation network classifies every pixel in an image, resulting in an image that is segmented by class. Applications for semantic segmentation include road segmentation for autonomous driving and cancer cell segmentation for medical diagnosis. To learn more, see Getting Started with Semantic Segmentation Using Deep Learning.

Semantic segmentation networks like DeepLab v3+ [1] make extensive use of dilated convolutions (also known as atrous convolutions) because they can increase the receptive field of a layer (the area of the input that the layer can see) without increasing the number of parameters or computations.
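To see why dilation enlarges the receptive field without adding parameters, the following sketch applies a 3-tap kernel with dilation factor 2 to a 1-D signal in two ways: directly, by sampling every d-th input, and as an ordinary correlation with a kernel that has d-1 zeros inserted between its taps. It uses only base MATLAB, and the signal x, kernel k, and dilation factor d are arbitrary illustrative values, not part of the example. Both computations should agree, and the kernel keeps its 3 learnable weights while spanning (3-1)*2+1 = 5 inputs.

```matlab
% 1-D illustration of a dilated (atrous) correlation; all values are arbitrary.
x = [3 1 4 1 5 9 2 6 5 3];    % input signal
k = [1 2 3];                  % 3 learnable weights
d = 2;                        % dilation factor

% Build the equivalent zero-inserted kernel [1 0 2 0 3]
span = (numel(k) - 1)*d + 1;  % effective kernel size
kd = zeros(1, span);
kd(1:d:end) = k;

% Direct dilated correlation over the "valid" region only
n = numel(x) - span + 1;
yDilated = zeros(1, n);
for i = 1:n
    for j = 1:numel(k)
        yDilated(i) = yDilated(i) + k(j)*x(i + (j-1)*d);
    end
end

% Ordinary correlation with the zero-inserted kernel
% (conv flips its second argument, so flip kd to get a correlation)
yZeroInserted = conv(x, fliplr(kd), 'valid');

disp(yDilated)
disp(yZeroInserted)
```

Only the spacing between taps changes with d, so the number of learnable weights stays at numel(k) for any dilation factor.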

Load Training Data

The example uses a simple dataset of 32-by-32 triangle images for illustration purposes. The dataset includes accompanying pixel label ground truth data. Load the training data using an imageDatastore and a pixelLabelDatastore.

dataFolder = fullfile(toolboxdir("vision"),"visiondata","triangleImages");
imageFolderTrain = fullfile(dataFolder,"trainingImages");
labelFolderTrain = fullfile(dataFolder,"trainingLabels");

Create an imageDatastore for the images.

imdsTrain = imageDatastore(imageFolderTrain);

Create a pixelLabelDatastore for the ground truth pixel labels.

classNames = ["triangle" "background"];
labels = [255 0];
pxdsTrain = pixelLabelDatastore(labelFolderTrain,classNames,labels)
pxdsTrain = 
  PixelLabelDatastore with properties:

                       Files: {200×1 cell}
                  ClassNames: {2×1 cell}
                    ReadSize: 1
                     ReadFcn: @readDatastoreImage
    AlternateFileSystemRoots: {}

Create Semantic Segmentation Network

This example uses a simple semantic segmentation network based on dilated convolutions.

Create a data source for training data and get the pixel counts for each label.

ds = combine(imdsTrain,pxdsTrain);
tbl = countEachLabel(pxdsTrain)
tbl=2×3 table
         Name         PixelCount    ImagePixelCount
    ______________    __________    _______________

    {'triangle'  }         10326       2.048e+05   
    {'background'}    1.9447e+05       2.048e+05   

Most pixel labels belong to the background class. This class imbalance biases the learning process in favor of the dominant class. To counter it, use class weighting to balance the classes. There are several ways to compute class weights. One common method is inverse frequency weighting, where each class weight is the inverse of the class frequency. This method gives more weight to underrepresented classes. Calculate the class weights using inverse frequency weighting.

numberPixels = sum(tbl.PixelCount);
frequency = tbl.PixelCount / numberPixels;
classWeights = dlarray(1 ./ frequency,"C");
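As a worked check, the following sketch plugs the pixel counts from the table above into the same inverse frequency formula using plain numeric arrays. It assumes that every pixel carries one of the two labels, so the background count is taken as 204800 - 10326 = 194474 (the table displays it rounded as 1.9447e+05), where 204800 is 200 training images of 32-by-32 pixels.

```matlab
% Pixel counts from countEachLabel (background assumed = total - triangle)
totalPixels = 200*32*32;                      % 200 images of 32-by-32 pixels
pixelCount = [10326; totalPixels - 10326];    % [triangle; background]

frequency = pixelCount / sum(pixelCount);
weights = 1 ./ frequency;                     % inverse frequency weighting

% The ratio of the two weights equals the inverse ratio of the counts
weightRatio = weights(1) / weights(2);

disp(weights)
disp(weightRatio)
```

Under this assumption, the triangle class receives a weight of roughly 19.8 versus roughly 1.05 for background, so an error on a triangle pixel is weighted about 18.8 times as heavily as the same error on a background pixel.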

Create a network for pixel classification by using an image input layer with an input size corresponding to the size of the input images. Next, specify three blocks of convolution, batch normalization, and ReLU layers. For each convolutional layer, specify 32 3-by-3 filters with increasing dilation factors (1, 2, and 4), and pad the inputs so that the outputs are the same size by setting the Padding name-value argument to "same". To classify the pixels, include a convolutional layer with K 1-by-1 convolutions, where K is the number of classes, followed by a softmax layer. The pixel classification loss is computed by a custom loss function that you pass to the built-in trainer, trainnet.

inputSize = [32 32 1];
filterSize = 3;
numFilters = 32;
numClasses = numel(classNames);

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(filterSize,numFilters,DilationFactor=1,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,DilationFactor=2,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,DilationFactor=4,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(1,numClasses)
    softmaxLayer];
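Because every layer uses a stride of 1, each 3-by-3 layer widens the receptive field by (filterSize-1)*dilationFactor pixels along each spatial dimension. The following sketch uses plain arithmetic, with the filter size, dilation factors, and channel counts taken from the layer array above, to compare the receptive field of the three dilated layers against the same stack without dilation. It also counts the convolution weights per layer, which do not depend on the dilation factor. The final 1-by-1 convolution and the batch normalization layers are not included.

```matlab
% Receptive field of the three 3-by-3 dilated layers (stride 1)
filterSize = 3;
dilationFactors = [1 2 4];

% Effective extent of each dilated filter along one spatial dimension
effectiveSize = (filterSize - 1)*dilationFactors + 1;    % [3 5 9]

% Each layer adds (effectiveSize - 1) pixels to the receptive field
receptiveField = 1 + sum(effectiveSize - 1);

% The same stack without dilation, for comparison
receptiveFieldUndilated = 1 + numel(dilationFactors)*(filterSize - 1);

% Weights plus biases per convolutional layer: no dilation term appears
inChannels = [1 32 32];    % grayscale input, then 32 channels
numFilters = 32;
weightsPerLayer = filterSize^2 * inChannels * numFilters + numFilters;

disp(receptiveField)
disp(receptiveFieldUndilated)
disp(weightsPerLayer)
```

Under these assumptions, the three dilated layers see a 15-by-15 window of the input, compared with 7-by-7 without dilation, using the same number of weights.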

Model Loss Function

You can train the semantic segmentation network using different loss functions. The built-in trainer trainnet (Deep Learning Toolbox) supports custom loss functions as well as standard loss functions such as "crossentropy" and "mse". A custom loss function computes the loss for each mini-batch of training data by comparing the network predictions to the ground truth or target values. Specify a custom loss function as a function handle with the syntax loss = f(Y1,...,Yn,T1,...,Tm), where Y1,...,Yn are dlarray objects that correspond to the n network predictions and T1,...,Tm are dlarray objects that correspond to the m targets.

This example enables you to select from two different loss functions that account for the class imbalance seen in the data. These loss functions are:

  1. Weighted cross-entropy loss, which uses the crossentropy (Deep Learning Toolbox) function. Weighted cross-entropy loss favors the underrepresented class by scaling the error of that class during training.

  2. A custom loss function called tverskyLoss that calculates the Tversky loss [2]. The Tversky loss is a more specialized loss for class imbalance.
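To illustrate how class weights scale the error, the following sketch computes a weighted cross-entropy on a tiny hypothetical 2-by-2 image with two classes, using plain arrays in base MATLAB. The probabilities, targets, and the weights 19.8 and 1.05 (approximately the inverse frequency weights for this data set) are illustrative. The sum is normalized here by the number of pixels; the built-in crossentropy function supports several NormalizationFactor options, and this sketch is not meant to reproduce its exact normalization.

```matlab
% Toy 2-by-2 image with two classes; dimension 3 is the class dimension
% (1 = triangle, 2 = background). All values are illustrative.
Y = zeros(2,2,2);                    % predicted probabilities (softmax output)
Y(:,:,1) = [0.9 0.2; 0.1 0.3];
Y(:,:,2) = 1 - Y(:,:,1);

T = zeros(2,2,2);                    % one-hot targets
T(:,:,1) = [1 1; 0 0];               % two triangle pixels
T(:,:,2) = 1 - T(:,:,1);

% Per-class weights expanded to the size of Y
W = cat(3, 19.8*ones(2,2), 1.05*ones(2,2));

numPixels = size(Y,1)*size(Y,2);

% Sum -w_c * T * log(Y) over classes and pixels, then divide by the pixel count
lossWeighted = -sum(sum(sum(W .* T .* log(Y)))) / numPixels;
lossUnweighted = -sum(sum(sum(T .* log(Y)))) / numPixels;

disp(lossWeighted)
disp(lossUnweighted)
```

The poorly predicted triangle pixel at (1,2) dominates the weighted loss, which is the effect the class weights are meant to produce.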

The Tversky loss is based on the Tversky index for measuring overlap between two segmented images. The Tversky index $TI_c$ for class c between one image Y and the corresponding ground truth T is given by

$$TI_c = \frac{\sum_{m=1}^{M} Y_{cm} T_{cm}}{\sum_{m=1}^{M} Y_{cm} T_{cm} + \alpha \sum_{m=1}^{M} Y_{cm} T_{\bar{c}m} + \beta \sum_{m=1}^{M} Y_{\bar{c}m} T_{cm}}$$

where

  • c corresponds to the class and $\bar{c}$ corresponds to not being in class c.

  • M is the number of elements along the first two dimensions of Y.

  • α and β are weighting factors that control how much false positives and false negatives, respectively, contribute to the loss for each class.

The loss L summed over the C classes is given by

$$L = \sum_{c=1}^{C} \left(1 - TI_c\right)$$
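The following sketch evaluates these formulas on a tiny hypothetical 2-by-2 image with two classes, using plain arrays in base MATLAB. It mirrors the tverskyLoss supporting function for a single image but omits the epsilon term and the averaging over the mini-batch. The values alpha = 0.7 and beta = 0.3 match the ones used in this example; the probabilities and targets are illustrative.

```matlab
% Toy 2-by-2 image with two classes; dimension 3 is the class dimension
Y = zeros(2,2,2);                  % predicted probabilities
Y(:,:,1) = [0.9 0.2; 0.1 0.3];
Y(:,:,2) = 1 - Y(:,:,1);
T = zeros(2,2,2);                  % one-hot ground truth
T(:,:,1) = [1 1; 0 0];
T(:,:,2) = 1 - T(:,:,1);

alpha = 0.7;                       % weight on false positives
beta = 0.3;                        % weight on false negatives

% Soft counts per class (1-by-1-by-C), summed over the first two dimensions
TP = sum(sum(Y.*T,1),2);           % sum of Y_cm * T_cm
FP = sum(sum(Y.*(1-T),1),2);       % sum of Y_cm * T_cbar,m
FN = sum(sum((1-Y).*T,1),2);       % sum of Y_cbar,m * T_cm

TI = TP ./ (TP + alpha*FP + beta*FN);   % Tversky index per class
L = sum(1 - TI);                        % loss summed over classes

% With alpha = beta = 0.5, the Tversky index reduces to the Dice coefficient
TIhalf = TP ./ (TP + 0.5*FP + 0.5*FN);
dice = 2*TP ./ (2*TP + FP + FN);

disp(squeeze(TI)')
disp(L)
```

Choosing alpha greater than beta, as this example does, penalizes false positives more heavily than false negatives.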

Select the loss function to use during training.

lossFunction = "tverskyLoss"
lossFunction = 
"tverskyLoss"
if strcmp(lossFunction,"tverskyLoss")
    % Declare Tversky loss weighting coefficients for false positives and
    % false negatives. These coefficients are set and passed to the
    % training loss function using trainnet.
    alpha = 0.7;
    beta = 0.3;
    lossFcn = @(Y,T) tverskyLoss(Y,T,alpha,beta);
else
    % Use weighted cross-entropy loss during training.
    lossFcn = @(Y,T) crossentropy(Y,T,classWeights,NormalizationFactor="all-elements");
end

Train Network

Specify the training options.

options = trainingOptions("sgdm",...
    MaxEpochs=100,...
    MiniBatchSize=64,...
    InitialLearnRate=1e-2,...
    Verbose=false);

Train the network using trainnet (Deep Learning Toolbox), specifying lossFcn as the loss function.

net = trainnet(ds,layers,lossFcn,options);

Test Network

Load the test data. Create an imageDatastore for the images. Create a pixelLabelDatastore for the ground truth pixel labels.

imageFolderTest = fullfile(dataFolder,"testImages");
imdsTest = imageDatastore(imageFolderTest);
labelFolderTest = fullfile(dataFolder,"testLabels");
pxdsTest = pixelLabelDatastore(labelFolderTest,classNames,labels);

Make predictions using the test data and trained network.

pxdsPred = semanticseg(imdsTest,net,...
    Classes=classNames,...
    MiniBatchSize=32,...
    WriteLocation=tempdir);
Running semantic segmentation network
-------------------------------------
* Processed 100 images.

Evaluate the prediction accuracy using evaluateSemanticSegmentation.

metrics = evaluateSemanticSegmentation(pxdsPred,pxdsTest);
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU, BF score.
* Processed 100 images.
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.99674          0.98562       0.96447      0.99362        0.92831  
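For reference, the following sketch computes the global accuracy, mean class accuracy, mean IoU, and weighted IoU from a confusion matrix for a tiny hypothetical pair of label maps, using only base MATLAB. It uses the standard definitions (class accuracy = TP/(TP+FN), IoU = TP/(TP+FP+FN), and weighted IoU as the per-class IoU weighted by each class's share of ground-truth pixels). We assume these correspond to the GlobalAccuracy, MeanAccuracy, MeanIoU, and WeightedIoU columns above; evaluateSemanticSegmentation accumulates the counts over the whole test set. The BF score is not covered.

```matlab
% Toy ground truth and prediction: 1 = triangle, 2 = background
T = [1 1 2; 2 2 2];
P = [1 2 2; 2 2 2];
numClasses = 2;

% Confusion matrix: rows = true class, columns = predicted class
cm = accumarray([T(:) P(:)], 1, [numClasses numClasses]);

TP = diag(cm);
FN = sum(cm,2) - TP;      % pixels of class c predicted as another class
FP = sum(cm,1)' - TP;     % pixels predicted as class c that belong to another class

globalAccuracy = sum(TP) / sum(cm(:));
classAccuracy = TP ./ (TP + FN);
IoU = TP ./ (TP + FP + FN);

meanAccuracy = mean(classAccuracy);
meanIoU = mean(IoU);
classShare = sum(cm,2) / sum(cm(:));    % fraction of ground-truth pixels per class
weightedIoU = sum(classShare .* IoU);

disp([globalAccuracy meanAccuracy meanIoU weightedIoU])
```

Because background pixels dominate, the weighted IoU (0.7) exceeds the mean IoU (0.65) in this toy case, just as WeightedIoU exceeds MeanIoU in the results above.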

Segment New Image

Read the test image triangleTest.jpg and segment the test image using semanticseg. Display the results using labeloverlay.

imgTest = imread("triangleTest.jpg");
[C,scores] = semanticseg(imgTest,net,Classes=classNames);

B = labeloverlay(imgTest,C);
montage({imgTest,B})


Supporting Functions

function loss = tverskyLoss(Y,T,alpha,beta)
    % loss = tverskyLoss(Y,T,alpha,beta) returns the Tversky loss
    % between the predictions Y and the training targets T.   
    
    Pcnot = 1-Y;
    Gcnot = 1-T;
    TP = sum(sum(Y.*T,1),2);
    FP = sum(sum(Y.*Gcnot,1),2);
    FN = sum(sum(Pcnot.*T,1),2); 
    
    epsilon = 1e-8;
    numer = TP + epsilon;
    denom = TP + alpha*FP + beta*FN + epsilon;
    
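    % Compute the Tversky loss for each class (1 minus the Tversky index).
    lossTIc = 1 - numer./denom;

    % Sum the per-class losses over the channel (class) dimension.
    lossTI = sum(lossTIc,3);
    
    % Return the Tversky loss averaged over the mini-batch.
    N = size(Y,4);
    loss = sum(lossTI)/N;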
end

References

[1] Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation." European Conference on Computer Vision (ECCV), 2018.

[2] Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour. "Tversky loss function for image segmentation using 3D fully convolutional deep networks." International Workshop on Machine Learning in Medical Imaging. Springer, Cham, 2017.