
Prune and Quantize Semantic Segmentation Network

This example shows how to reduce the memory footprint of a semantic segmentation network and speed up inference by compressing the network using pruning and quantization.

Download Pretrained Semantic Segmentation Network

Download a pretrained version of DeepLab v3+ trained on the CamVid data set [1, 2]. For more information about this semantic segmentation network and the data set, see Semantic Segmentation Using Deep Learning.

pretrainedURL = "https://ssd.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid_v2.zip";
pretrainedFolder = fullfile(tempdir,"pretrainedNetwork");
pretrainedNetworkZip = fullfile(pretrainedFolder,"deeplabv3plusResnet18CamVid_v2.zip"); 
if ~exist(pretrainedNetworkZip,"file")
    mkdir(pretrainedFolder);
    disp("Downloading pretrained network (58 MB)...")
    websave(pretrainedNetworkZip,pretrainedURL);
end
Downloading pretrained network (58 MB)...
unzip(pretrainedNetworkZip,pretrainedFolder)

Load the pretrained network.

pretrainedNetwork = fullfile(pretrainedFolder,"deeplabv3plusResnet18CamVid_v2.mat");  
data = load(pretrainedNetwork);
trainedNet = data.net;

List the classes that this network can classify.

classes = getClassNames()
classes = 11×1 string
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"

Download CamVid Data Set

Download the CamVid data set [2].

imageURL = "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip";
labelURL = "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip";
 
outputFolder = fullfile(tempdir,"CamVid"); 
labelsZip = fullfile(outputFolder,"labels.zip");
imagesZip = fullfile(outputFolder,"images.zip");

if ~exist(labelsZip,"file") || ~exist(imagesZip,"file")   
    mkdir(outputFolder)
       
    disp("Downloading 16 MB CamVid data set labels...")
    websave(labelsZip,labelURL);
    unzip(labelsZip,fullfile(outputFolder,"labels"));
    
    disp("Downloading 557 MB CamVid data set images...")
    websave(imagesZip,imageURL);       
    unzip(imagesZip,fullfile(outputFolder,"images"));    
end

Load CamVid Pixel-Labeled Images

Use a pixelLabelDatastore object to load the CamVid pixel label image data. A pixelLabelDatastore object encapsulates the pixel label data and the mapping from label IDs to class names.

To simplify training, group the 32 original classes in CamVid into 11 classes matching the classes that the pretrained DeepLab v3+ network can classify. For example, "Car" is a combination of the CamVid "Car", "SUVPickupTruck", "Truck_Bus", "Train", and "OtherMoving" classes. Return the grouped label IDs by using the camvidPixelLabelIDs helper function, which is defined at the end of this example.

labelIDs = camvidPixelLabelIDs;
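The pixelLabelDatastore object applies this grouping for you. As a rough illustration only, this sketch applies a grouped label ID cell array to an RGB label image in base MATLAB. Every pixel whose color matches any row of group k receives class index k, and unmatched (void) pixels receive 0. The two-class grouping groupIDs and the 2-by-2 image are made-up inputs, not CamVid data.

```matlab
% Hypothetical two-class grouping in the same format as camvidPixelLabelIDs:
% each cell holds an M-by-3 matrix of RGB colors that map to one class.
groupIDs = { ...
    [128 128 128];                  % class 1, for example "Sky"
    [064 000 128; 064 128 192]};    % class 2, for example "Car" + "SUVPickupTruck"

% Made-up 2-by-2 RGB label image (H-by-W-by-3, uint8).
rgb = uint8(cat(3, [128 64; 64 0], [128 0; 128 0], [128 128; 192 0]));

% One RGB triplet per row; assign class indices (0 means void/unlabeled).
pixels = double(reshape(rgb, [], 3));
labels = zeros(size(pixels,1), 1);
for k = 1:numel(groupIDs)
    labels(ismember(pixels, groupIDs{k}, 'rows')) = k;
end
labelImage = reshape(labels, size(rgb,1), size(rgb,2))
```

Grouping several source colors into one class index means the network never has to distinguish, for example, a car from an SUV.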

Create the pixelLabelDatastore using the classes and label IDs.

imgDir = fullfile(outputFolder,"images","701_StillsRaw_full");
imds = imageDatastore(imgDir);

labelDir = fullfile(outputFolder,"labels");
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Read image 559 and its pixel labels, and display the labels overlaid on the image.

I = readimage(imds,559);
C = readimage(pxds,559);
cmap = camvidColorMap;
B = labeloverlay(I,C,ColorMap=cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);

Prepare Data for Training

Randomly split the image and pixel label data into training, validation, and test sets. Allocate 70% of the images from the data set to train the DeepLab v3+ model. Allocate 10% of the data for validation and the remaining 20% for testing.

[imdsTrain,imdsVal,imdsTest,pxdsTrain,pxdsVal,pxdsTest] = partitionCamVidData(imds,pxds);
dsTrain = combine(imdsTrain,pxdsTrain);

Display the number of training, validation, and test images.

numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 491
numValImages = numel(imdsVal.Files)
numValImages = 70
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140
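These counts follow from the rounding in the partitionCamVidData helper function. This sketch reproduces the split sizes. It assumes the data set contains 701 images, which is the sum of the three counts above. The test set is simply the remainder, so it absorbs any rounding.

```matlab
% Assumed total number of CamVid images (491 + 70 + 140).
numFiles = 701;

numTrain = round(0.70*numFiles);           % 490.7 rounds to 491
numVal   = round(0.10*numFiles);           % 70.1 rounds to 70
numTest  = numFiles - numTrain - numVal;   % the remaining 140 images

fprintf("train %d, val %d, test %d\n", numTrain, numVal, numTest)
```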

Evaluate Network Before Compression

Evaluate the performance of the network by using the evaluateNet helper function, which is defined at the end of this example. First, the evaluateNet helper function performs semantic segmentation of the test images. Then, the function calculates metrics that evaluate the quality of the semantic segmentation results against the ground truth segmentation.

trainedNetMetrics = evaluateNet(trainedNet,imdsTest,pxdsTest,classes);
trainedNetMetrics.DataSetMetrics
ans=1×5 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.90852          0.88877       0.69871      0.85013        0.74106  
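The data set metrics summarize a pixel-level confusion matrix. This sketch computes the first four metrics from a made-up two-class confusion matrix confMat using the usual definitions. Rows are ground truth and columns are predictions, and weighted IoU weights each class IoU by that class's share of ground-truth pixels. We understand that evaluateSemanticSegmentation uses these definitions, but check its documentation to confirm. MeanBFScore needs the label images themselves, so the sketch omits it.

```matlab
% Made-up confusion matrix: confMat(i,j) = pixels of true class i predicted as class j.
confMat = [50 10;
            5 35];

truePerClass = sum(confMat, 2);       % ground-truth pixels per class
predPerClass = sum(confMat, 1)';      % predicted pixels per class
tp = diag(confMat);                   % correctly classified pixels per class
total = sum(confMat(:));

globalAccuracy = sum(tp)/total;                             % 0.85
classAccuracy  = tp ./ truePerClass;                        % [0.8333; 0.875]
meanAccuracy   = mean(classAccuracy);                       % about 0.8542
iou            = tp ./ (truePerClass + predPerClass - tp);  % [0.7692; 0.7]
meanIoU        = mean(iou);                                 % about 0.7346
weightedIoU    = sum(truePerClass/total .* iou);            % about 0.7415
```

Global accuracy and weighted IoU are dominated by classes that cover many pixels, such as Road and Sky. Mean accuracy and mean IoU weight every class equally. The two groups of metrics can therefore move in opposite directions after compression.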

Prune Network

Create Prunable Network

Create a prunable object based on first-order Taylor approximation [3, 4] by using the taylorPrunableNetwork (Deep Learning Toolbox) function. A TaylorPrunableNetwork object has properties and object functions similar to those of a dlnetwork object, plus pruning-specific properties and object functions. You can use a TaylorPrunableNetwork object instead of a dlnetwork object in a custom training loop. Pruning is iterative. Each time the pruning loop runs, it removes a small number of the least important convolution filters and updates the network architecture. This process continues until a stopping criterion is met.

prunableNet = taylorPrunableNetwork(trainedNet)
prunableNet = 
  TaylorPrunableNetwork with properties:

      Learnables: [118×3 table]
           State: [56×3 table]
      InputNames: {'data'}
     OutputNames: {'softmax-out'}
    NumPrunables: 5333
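To build intuition, this sketch computes a first-order Taylor importance score for each filter in the style of [3]. The score is the absolute value of the spatially averaged product of a feature map's activation and the gradient of the loss with respect to that activation, averaged over the mini-batch. Random arrays stand in for real activations and gradients. This is not the internal implementation of taylorPrunableNetwork or updateScore, whose details, such as normalization, may differ.

```matlab
% Hypothetical activations A and loss gradients dA for one layer,
% laid out as H-by-W-by-C-by-N (spatial, spatial, channel, batch).
A  = rand(8,8,4,2);
dA = randn(8,8,4,2);

% First-order Taylor estimate for each feature map and observation:
% |mean over spatial positions of A .* dA|.
perObs = abs(mean(A .* dA, [1 2]));    % 1-by-1-by-C-by-N

% Average over the mini-batch to get one score per filter (C-by-1).
scores = squeeze(mean(perObs, 4));

% Filters with the smallest scores are the first pruning candidates.
[~, order] = sort(scores, "ascend");
```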

Specify Pruning Options

Set the pruning options.

  • maxPruningIterations defines the maximum number of iterations in the pruning loop.

  • maxToPrune is the maximum number of filters to prune in each iteration of the pruning loop.

  • validationFrequency is the number of pruning iterations between evaluations of the pruned network on the validation data. The loop also evaluates the network after the first pruning iteration.

maxPruningIterations = 30;
maxToPrune = 64;
validationFrequency = 5;
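These settings bound how much the loop can prune. This sketch computes that bound. It assumes that each call to updatePrunables removes at most maxToPrune filters. It also reads the loop condition in the pruning loop below literally: pruningIteration starts at 1 and the loop runs while pruningIteration < maxPruningIterations. The name numPrunablesStart is introduced here for illustration.

```matlab
maxPruningIterations = 30;
maxToPrune = 64;
numPrunablesStart = 5333;   % NumPrunables of the unpruned network, shown above

% With pruningIteration starting at 1 and the condition
% pruningIteration < maxPruningIterations, the loop makes at most 29 passes.
numPasses = maxPruningIterations - 1;

maxRemoved = numPasses*maxToPrune               % 29*64 = 1856 filters
minRemaining = numPrunablesStart - maxRemoved   % at least 3477 filters remain
```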

Set the fine-tuning options.

  • maxMinibatchIterations defines the maximum number of iterations in the fine-tuning loop.

  • Specify the options for stochastic gradient descent with momentum (SGDM) optimization. Specify a learning rate of 0.001 and a momentum of 0.9.

  • Specify a mini-batch size of 8 to fine-tune the network.

maxMinibatchIterations = 40;
learnRate = 0.001;
momentum = 0.9;
miniBatchSize = 8;
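The learning rate and momentum control each parameter update. This minimal sketch applies the classical SGDM step in velocity form to a single hypothetical scalar parameter w with a constant hypothetical gradient g. The update is velocity = momentum*velocity - learnRate*gradient, followed by parameters = parameters + velocity. To our understanding, this is equivalent to the formulation in the sgdmupdate documentation. The sketch is not the sgdmupdate source.

```matlab
learnRate = 0.001;
momentum = 0.9;

w = 1;    % a single hypothetical parameter
v = 0;    % velocity starts at zero (sgdmupdate accepts [] for this)
g = 2;    % constant hypothetical gradient

for step = 1:2
    v = momentum*v - learnRate*g;   % decaying sum of past steps plus the new step
    w = w + v;                      % move the parameter along the velocity
end
% After two steps: v = -0.0038, w = 0.9942
```

Because the velocity keeps 90% of the previous step, repeated gradients in the same direction produce progressively larger steps.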

Create a figure to monitor the loss, validation accuracy (measured as weighted IoU), and number of prunable filters during pruning.

monitor = trainingProgressMonitor(Metrics=["Loss" "ValAccuracy" "NumPrunables"], ...
    XLabel="Iteration");

Create Mini-batch Queue

Use a minibatchqueue object to process and manage the mini-batches of images. For each mini-batch, perform these steps:

  • Concatenate the images and the labels along the fourth dimension, and return them as separate outputs by using the deal function.

  • Format the image and label data with the dimension labels "SSCB" (spatial, spatial, channel, batch).

  • Train on a GPU if one is available. Setting OutputEnvironment to ["auto" "cpu"] returns the images as a gpuArray when a GPU is available and keeps the labels on the CPU. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

mbqTrain = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(img,labels) deal(cat(4,img{:}),cat(4,labels{:})), ...
    OutputAsDlarray=[1 1], ...
    MiniBatchFormat=["SSCB" "SSCB"], ...
    OutputEnvironment=["auto" "cpu"]);
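This sketch shows what the mini-batch function produces. It calls the same anonymous function on two hypothetical observations. Random numeric arrays stand in for the images and for the categorical label maps that the combined datastore actually returns.

```matlab
% Two hypothetical observations: H-by-W-by-3 images and H-by-W label maps.
img    = {rand(4,5,3), rand(4,5,3)};
labels = {randi(11,4,5), randi(11,4,5)};

% Same anonymous function as MiniBatchFcn above: concatenate each
% variable along dimension 4 (the "B" in "SSCB") and return both.
batchFcn = @(img,labels) deal(cat(4,img{:}), cat(4,labels{:}));
[Xbatch, Tbatch] = batchFcn(img, labels);

size(Xbatch)   % 4 5 3 2
size(Tbatch)   % 4 5 1 2
```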

Prune Network Using Custom Training Loop

Prune the network by repeatedly fine-tuning the network and removing the lowest-scoring filters. For each pruning iteration, perform these operations:

  • Fine-tune the network and accumulate Taylor scores for convolution filters for maxMinibatchIterations iterations.

  • Remove a small number of the least important convolution filters and update the network architecture using the updatePrunables (Deep Learning Toolbox) function.

  • Display the training progress.
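Conceptually, the removal step ranks filters by their accumulated scores and drops the lowest-scoring ones. This sketch shows only that ranking logic, using made-up scores for six filters and a hypothetical budget numToPrune. It does not show how updatePrunables normalizes scores across layers or rewires the network, and that behavior may differ from this simple ranking.

```matlab
% Hypothetical per-mini-batch scores for 6 filters.
batchScores = [0.9 0.1 0.5 0.05 0.7 0.3;    % mini-batch 1
               0.8 0.2 0.4 0.15 0.6 0.3];   % mini-batch 2

% Accumulate scores across mini-batches.
accumulated = sum(batchScores, 1);

% Select at most numToPrune filters with the smallest accumulated scores.
numToPrune = 2;
[~, pruneIdx] = mink(accumulated, numToPrune);
keepIdx = setdiff(1:numel(accumulated), pruneIdx)
```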

To fine-tune the network, loop over the mini-batches of the training data. For each mini-batch in the fine-tuning iteration, perform these operations:

  • Calculate the pruning activations, gradients of the pruning activations, model gradients, state, and loss using the dlfeval (Deep Learning Toolbox) and modelGradients functions. The modelGradients function is a helper function that is defined at the end of the example.

  • Update the network state.

  • Update the network learnable parameters by using the sgdmupdate (Deep Learning Toolbox) function.

  • Calculate the first-order Taylor scores and accumulate the score across previous mini-batches of data by using the updateScore (Deep Learning Toolbox) function.

pruningIteration = 1;
metrics = 0;

% Pruning loop
while (prunableNet.NumPrunables > maxToPrune) && (pruningIteration < maxPruningIterations)
 
    % Reset and shuffle the mini-batch
    reset(mbqTrain);
    shuffle(mbqTrain);
 
    % Reset the parameters for the current pruning iteration
    velocity = [];
    localIteration = 0;

    % Fine-tuning loop
    while hasdata(mbqTrain)
 
        localIteration = localIteration + 1;
 
        [dlX,Y] = next(mbqTrain);
 
        % Calculate activations of masking layers and gradient of loss with respect to these activations
        [dLearnables,dGatingLayers,gatingLayerOuts,state,loss] = dlfeval( ...
            @modelGradients,prunableNet,dlX,Y);
        prunableNet.State = state;
        
        [prunableNet,velocity] = sgdmupdate(prunableNet,dLearnables,velocity,learnRate,momentum);
 
        prunableNet = updateScore(prunableNet,dGatingLayers,gatingLayerOuts);
 
        if (localIteration > maxMinibatchIterations)
            break
        end
 
    end % End fine-tuning loop
   
    prunableNet = updatePrunables(prunableNet,MaxToPrune=maxToPrune);

    if (mod(pruningIteration,validationFrequency) == 0 || pruningIteration==1)      
        metrics=evaluateNet(prunableNet,imdsVal,pxdsVal,classes);
    end

    recordMetrics(monitor,pruningIteration,Loss=loss, ...
        ValAccuracy=metrics.DataSetMetrics.WeightedIoU,NumPrunables=prunableNet.NumPrunables);

    pruningIteration = pruningIteration + 1;
    
end % End pruning loop

During each pruning iteration, the validation accuracy often decreases because pruning convolutional filters changes the network structure. To recover the lost accuracy, retrain the network after pruning.

When pruning is complete, convert the TaylorPrunableNetwork object back to a dlnetwork object for retraining and further analysis.

prunedNet = dlnetwork(prunableNet);
save("dlnet_pruned.mat","prunedNet");

Compare Filters in Original Network and Pruned Network

Determine the impact of pruning on each layer.

originalNetFilters = numConvLayerFilters(trainedNet);
prunedNetFilters = numConvLayerFilters(prunedNet);
convFilters = join(originalNetFilters,prunedNetFilters,Keys="Row");

Visualize the number of filters in the original network and in the pruned network.

figure(Position=[10,10,900,900])
bar([convFilters.(1),convFilters.(2)])
xlabel("Layer")
ylabel("Number of Filters")
title("Number of Filters Per Layer")
xticks(1:(numel(convFilters.Row)))
xticklabels(convFilters.Row)
xtickangle(90)
ax = gca;
ax.TickLabelInterpreter = "none";
legend("Original Network Filters","Pruned Network Filters",Location="southoutside")

Evaluate Pruned Network

Evaluate the pruned network using the test set. Compared with the original network, the mean accuracy of the pruned network drops from 0.889 to 0.758 and the mean IoU drops slightly. The global accuracy, weighted IoU, and mean BF score are slightly higher. You can recover part of the lost accuracy by retraining the pruned network.

prunedNetMetrics = evaluateNet(prunedNet,imdsTest,pxdsTest,classes);
prunedNetMetrics.DataSetMetrics
ans=1×5 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.92644          0.75779       0.68834      0.86613        0.75825  

Retrain Pruned Network

The pruning process can cause the prediction accuracy to decrease. Try to improve the prediction accuracy by retraining the network using a custom training loop.

Specify Training Options

Specify the options to use during retraining.

  • Specify the options for SGDM optimization. Specify a learning rate of 0.001 and a momentum of 0.9. Initialize the gradient velocity as [].

  • Validate the network every 50 iterations.

  • Train for 8 epochs.

learnRate = 0.001;
momentum = 0.9;
valFreq = 50;
numEpochs = 8;
velocity = [];
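This sketch estimates how long retraining runs. It assumes 491 training images, training that is not stopped early from the progress monitor, and a minibatchqueue that returns the final partial mini-batch. Returning the partial mini-batch is the minibatchqueue default, PartialMiniBatch="return".

```matlab
numTrainingImages = 491;
miniBatchSize = 8;
numEpochs = 8;
valFreq = 50;

itersPerEpoch = ceil(numTrainingImages/miniBatchSize)   % 61.375 rounds up to 62
totalIters = itersPerEpoch*numEpochs                    % 496

% Validation runs at iteration 1 and at every multiple of valFreq.
valIters = unique([1, valFreq:valFreq:totalIters]);
numValidations = numel(valIters)                        % 10
```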

Create a figure to monitor the loss and validation accuracy (measured as mean IoU) during training.

monitor = trainingProgressMonitor(Metrics=["Loss" "ValAccuracy"], ...
    Info=["Epoch" "LearnRate"],XLabel="Iteration");

Train Pruned Network Using Custom Training Loop

Train the network in a custom training loop. For each iteration:

  • Evaluate the model loss, gradients, and state using the dlfeval (Deep Learning Toolbox) function and the modelLoss helper function, which is defined at the end of the example.

  • Update the network state, and update the network parameters using the sgdmupdate (Deep Learning Toolbox) function.

  • On the first iteration and every valFreq iterations, evaluate the network on the validation data.

  • Display the progress.

epoch = 0;
iteration = 0;
metrics = [];

% Loop over epochs
while (epoch < numEpochs) && ~monitor.Stop
    
    epoch = epoch + 1;

    % Shuffle data
    shuffle(mbqTrain);
    
    % Loop over mini-batches
    while hasdata(mbqTrain) && ~monitor.Stop

        iteration = iteration + 1;
        
        % Read mini-batch of data
        [X,T] = next(mbqTrain);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function
        [loss,gradients,state] = dlfeval(@modelLoss,prunedNet,X,T);

        % Update the network state
        prunedNet.State = state;

        % Update the network parameters using the SGDM optimizer
        [prunedNet,velocity] = sgdmupdate(prunedNet,gradients,velocity,learnRate,momentum);

        if (mod(iteration,valFreq) == 0 || iteration == 1)      
            reset(imdsVal);
            reset(pxdsVal);
            metrics=evaluateNet(prunedNet,imdsVal,pxdsVal,classes,iteration);
        end

        % Update the training progress monitor
        recordMetrics(monitor,iteration,Loss=loss,ValAccuracy=metrics.DataSetMetrics.MeanIoU);
        updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);
        
    end
end

save("dlnet_pruned_retrained.mat","prunedNet");

Evaluate Retrained Pruned Network

Evaluate the metrics for the retrained pruned network.

prunedNetMetrics = evaluateNet(prunedNet,imdsTest,pxdsTest,classes);
prunedNetMetrics.DataSetMetrics
ans=1×5 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.93333          0.80252       0.72794      0.87854        0.79206  

Compare the memory footprint and accuracy of the original and pruned networks. The pruned network has about 54% fewer learnable parameters and uses about 54% less memory than the original network. Its mean accuracy decreases by about 9.7%.

statsPruned = compareNetworkMetrics(trainedNet,prunedNet, ...
    trainedNetMetrics.DataSetMetrics.MeanAccuracy,prunedNetMetrics.DataSetMetrics.MeanAccuracy, ...
    "Pruned Network")
statsPruned=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    MeanAccuracy
                         __________________    ___________________________    ____________

    Original Network         1.6402e+07                   62.568                0.88877   
    Pruned Network            7.556e+06                   28.824                0.80252   
    Percentage Change           -53.932                  -53.932                -9.7047   
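The percentage changes use the formula in the compareNetworkMetrics helper function, 100*(compressed - original)/original. This check uses the rounded values displayed in the table, so the last digit can differ slightly from the table.

```matlab
% Percentage change as computed in compareNetworkMetrics.
pctChange = @(original, compressed) 100*(compressed - original)./original;

memoryChange   = pctChange(62.568, 28.824)     % about -53.93
accuracyChange = pctChange(0.88877, 0.80252)   % about -9.70
```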

Quantize Network

Quantize the retrained pruned network for a GPU target. Quantization reduces the memory footprint of the network by converting weights, biases, and activations of convolution layers from floating-point data types to 8-bit scaled integer data types. After quantization, a network can perform inference more quickly.
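This sketch illustrates what an 8-bit scaled integer representation means. It quantizes a few made-up weights using a power-of-two scale chosen from their maximum magnitude, a MinMax-style choice, and then converts them back. This is a simplified assumption about the scheme, not the exact algorithm that dlquantizer uses.

```matlab
x = [0.5 -1.27 0.01];    % hypothetical single-precision weights

% Smallest power-of-two step 2^e with max(|x|)/2^e <= 127, so the
% largest magnitude still fits in the int8 range [-128, 127].
e = ceil(log2(max(abs(x))/127));
scale = 2^e;                   % here 2^-6 = 0.015625

q = int8(x/scale);             % int8() rounds to nearest and saturates
xHat = double(q)*scale;        % dequantized approximation of x

maxError = max(abs(x - xHat))  % rounding error is at most scale/2
```

Each weight is now stored in 1 byte instead of 4. The price is a rounding error bounded by half the quantization step.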

To improve the performance of the network after quantization, equalize the layer parameters of the retrained pruned network by using the equalizeLayers function.

eqNet = equalizeLayers(prunedNet);
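This sketch shows the cross-layer equalization idea on two small fully connected layers separated by a ReLU. Because relu(z/s) = relu(z)/s for s > 0, you can divide output channel i of the first layer by s(i) and multiply input channel i of the second layer by s(i) without changing the network output. Choosing s so that both channels end up with the same weight range helps a shared 8-bit grid fit both layers. The matrices are made up. We assume, without having checked its implementation, that equalizeLayers applies a rescaling of this kind to suitable layer pairs.

```matlab
W1 = randn(4,3);  b1 = randn(4,1);   % layer 1: 3 inputs -> 4 channels
W2 = randn(2,4);  b2 = randn(2,1);   % layer 2: 4 channels -> 2 outputs
relu = @(z) max(z, 0);
x = randn(3,1);

% Per-channel weight ranges: rows of W1 (outputs), columns of W2 (inputs).
r1 = max(abs(W1), [], 2);
r2 = max(abs(W2), [], 1)';
s = sqrt(r1 .* r2) ./ r2;            % equalizing scale for each channel

W1eq = W1 ./ s;   b1eq = b1 ./ s;    % divide channel i of layer 1 by s(i)
W2eq = W2 .* s';                     % multiply input i of layer 2 by s(i)

% The outputs agree, and both layers now have range sqrt(r1.*r2) per channel.
yOrig = W2*relu(W1*x + b1) + b2;
yEq   = W2eq*relu(W1eq*x + b1eq) + b2;
```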

Create a quantizable version of the equalized network by using a dlquantizer object. Specify a GPU target using the ExecutionEnvironment name-value argument.

quantizableNet = dlquantizer(eqNet,ExecutionEnvironment="GPU");

Calibrate the network with the training data by using the calibrate function. Calibration consists of exercising the network with sample inputs and collecting dynamic range information.

calibrate(quantizableNet,dsTrain,MiniBatchSize=8);
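Calibration records the dynamic range of each quantizable value over the calibration data. This minimal sketch shows that bookkeeping for a single value. Random mini-batches stand in for the activations that calibrate collects from dsTrain. It is not the actual implementation of calibrate.

```matlab
runningMin = Inf;
runningMax = -Inf;

for batch = 1:5
    act = 3*randn(16,16,4,8);   % hypothetical activations for one mini-batch
    runningMin = min(runningMin, min(act(:)));
    runningMax = max(runningMax, max(act(:)));
end

% The recorded range later determines the power-of-two scaling.
rangeFound = [runningMin runningMax]
```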

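Quantize the network object and return a simulatable quantized network by using the quantize function. Setting ExponentScheme to "Histogram" chooses the scaling exponents from the distribution of the values collected during calibration instead of from their minimum and maximum alone, which can reduce the influence of outliers.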

quantizedNet = quantize(quantizableNet,ExponentScheme="Histogram");

Display the details of the quantized network by using the quantizationDetails function.

qDetails = quantizationDetails(quantizedNet)
qDetails = struct with fields:
            IsQuantized: 1
          TargetLibrary: "cudnn"
    QuantizedLayerNames: [91×1 string]
    QuantizedLearnables: [51×3 table]

Compare Original and Quantized Networks

Calculate metrics for the quantized semantic segmentation network.

quantizedNetMetrics = evaluateNet(quantizedNet,imdsTest,pxdsTest,classes);
quantizedNetMetrics.DataSetMetrics
ans=1×5 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.92221          0.75959       0.68331      0.85938        0.76537  

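Compare the original network and the quantized network. The quantized network uses about 88.5% less memory than the original network. Its mean accuracy decreases by about 14.5%. Quantization does not change the number of learnable parameters, so that column matches the pruned network.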

statsQuantized = compareNetworkMetrics(trainedNet,quantizedNet, ...
    trainedNetMetrics.DataSetMetrics.MeanAccuracy,quantizedNetMetrics.DataSetMetrics.MeanAccuracy, ...
    "Quantized Network")
statsQuantized=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    MeanAccuracy
                         __________________    ___________________________    ____________

    Original Network         1.6402e+07                   62.568                0.88877   
    Quantized Network         7.556e+06                    7.206                0.75959   
    Percentage Change           -53.932                  -88.483                -14.535   
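The memory column is consistent with a simple bytes-per-parameter calculation. This sketch assumes 4 bytes per learnable before quantization (single precision), 1 byte after int8 quantization, and 1 MB = 2^20 bytes. The learnable counts are the rounded values from the tables above.

```matlab
learnables    = [1.6402e7; 7.556e6; 7.556e6];   % original, pruned, quantized
bytesPerParam = [4; 4; 1];                       % assumed storage per learnable

memoryMB  = learnables .* bytesPerParam / 2^20  % about 62.57, 28.82, 7.21
reduction = 100*(1 - memoryMB(3)/memoryMB(1))   % about 88.5 percent
```

Under these assumptions, pruning saves memory by removing parameters. Quantization then saves a further factor of four by shrinking each remaining parameter.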

Results Summary

Compare the number of learnables, memory footprint, and accuracy of the original, pruned, and quantized networks.

[statsPruned(1:2,:); statsQuantized(2,:)]
ans=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    MeanAccuracy
                         __________________    ___________________________    ____________

    Original Network         1.6402e+07                  62.568                 0.88877   
    Pruned Network            7.556e+06                  28.824                 0.80252   
    Quantized Network         7.556e+06                   7.206                 0.75959   

Helper Functions

function [loss,gradients,state] = modelLoss(net,X,T)
% Calculate semantic segmentation model loss.

% Forward data through network
[Y,state] = forward(net,X);

T = extractdata(T);
T_Onehotencode = onehotencode(T,3,ClassNames=1:11);
T_Onehotencode(isnan(T_Onehotencode)) = 0;

numObs = size(Y,1) * size(Y,2) * size(Y,4);

% Calculate cross-entropy loss
loss = crossentropy(Y,T_Onehotencode)/numObs;

% Calculate gradients of loss with respect to learnable parameters
gradients = dlgradient(loss,net.Learnables);

end
function [dLossdLearnables,pruningGradient,pruningActivations,state,loss] = modelGradients(networkPruner,dlX,Y)
% Calculate network pruning model gradients.

[networkAct,state,pruningActivations] = forward(networkPruner,dlX);

% Get the output of softmax and use it to compute loss
Y = extractdata(Y);
Y2 = onehotencode(Y,3,ClassNames=1:11);
Y2(isnan(Y2)) = 0;

Y2 = dlarray(Y2,"SSCB");
dims = size(networkAct);
bz = dims(end);
loss = crossentropy(networkAct,Y2)/(prod(dims(1:2))*bz);

% Retrieve outputs of gating layers and network outputs
% differentiate loss w.r.t learnables and gating layers
[dLossdLearnables,pruningGradient] = dlgradient(loss,networkPruner.Learnables,pruningActivations);

end
function convFilters = numConvLayerFilters(net)
% Return the number of filters in each convolution layer of a network. 

numLayers = numel(net.Layers);
convNames = [];
numFilters = [];
% Check for convolution layers and extract the number of filters.
for cnt = 1:numLayers
    if isa(net.Layers(cnt),"nnet.cnn.layer.Convolution2DLayer")
        sizeW = size(net.Layers(cnt).Weights);
        numFilters = [numFilters; sizeW(end)]; %#ok<AGROW>
        convNames = [convNames; string(net.Layers(cnt).Name)]; %#ok<AGROW>
    end
end
convFilters = table(numFilters,RowNames=convNames);
end
function ssm = evaluateNet(net,imds,pxdsTruth,classNames,itr)
% Apply semantic segmentation and evaluate the segmentation results.

dirname = tempdir;
if nargin==5
   dirname = dirname + "val_" + num2str(itr);
   mkdir(dirname);
end   

pxdsResults = semanticseg(imds,net,WriteLocation=dirname,Classes=classNames,Verbose=false,MiniBatchSize=8);
ssm = evaluateSemanticSegmentation(pxdsResults,pxdsTruth,Verbose=false);

end
function statistics = compareNetworkMetrics(originalNet,compressedNet, ...
    originalNetAccuracy,compressedNetAccuracy,compressedNetName)
% Return statistics comparing two networks: number of learnables, approximate memory, and accuracy.

originalNetMetrics = estimateNetworkMetrics(originalNet);
prunedNetMetrics = estimateNetworkMetrics(compressedNet);

% Accuracy of original network and compressed network
perChangeAccu = 100*(compressedNetAccuracy - originalNetAccuracy)/originalNetAccuracy;
accuracyForNetworks = [originalNetAccuracy;compressedNetAccuracy;perChangeAccu];

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
prunedNetLearnables = sum(prunedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(prunedNetLearnables - originalNetLearnables)/originalNetLearnables;
learnablesForNetwork = [originalNetLearnables;prunedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxPrunedMemory = sum(prunedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxPrunedMemory - approxOriginalMemory)/approxOriginalMemory;
networkMemory = [approxOriginalMemory; approxPrunedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(learnablesForNetwork,networkMemory,accuracyForNetworks, ...
    VariableNames=["Network Learnables","Approx. Network Memory (MB)","MeanAccuracy"], ...
    RowNames=["Original Network",compressedNetName,"Percentage Change"]);

end
function labelIDs = camvidPixelLabelIDs()
% Return the label IDs corresponding to each class.
%
% The CamVid data set has 32 classes. Group them into 11 classes following
% the original SegNet training methodology.
%
% The 11 classes are:
%   "Sky", "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol",
%   "Fence", "Car", "Pedestrian", and "Bicyclist".
%
% CamVid pixel label IDs are provided as RGB color values. Group them into
% 11 classes and return them as a cell array of M-by-3 matrices. The
% original CamVid class names are listed alongside each RGB value. Note
% that the Other/Void class is excluded below.
labelIDs = { ...
    
    % "Sky"
    [
    128 128 128; ... % "Sky"
    ]
    
    % "Building" 
    [
    000 128 064; ... % "Bridge"
    128 000 000; ... % "Building"
    064 192 000; ... % "Wall"
    064 000 064; ... % "Tunnel"
    192 000 128; ... % "Archway"
    ]
    
    % "Pole"
    [
    192 192 128; ... % "Column_Pole"
    000 000 064; ... % "TrafficCone"
    ]
    
    % Road
    [
    128 064 128; ... % "Road"
    128 000 192; ... % "LaneMkgsDriv"
    192 000 064; ... % "LaneMkgsNonDriv"
    ]
    
    % "Pavement"
    [
    000 000 192; ... % "Sidewalk" 
    064 192 128; ... % "ParkingBlock"
    128 128 192; ... % "RoadShoulder"
    ]
        
    % "Tree"
    [
    128 128 000; ... % "Tree"
    192 192 000; ... % "VegetationMisc"
    ]
    
    % "SignSymbol"
    [
    192 128 128; ... % "SignSymbol"
    128 128 064; ... % "Misc_Text"
    000 064 064; ... % "TrafficLight"
    ]
    
    % "Fence"
    [
    064 064 128; ... % "Fence"
    ]
    
    % "Car"
    [
    064 000 128; ... % "Car"
    064 128 192; ... % "SUVPickupTruck"
    192 128 192; ... % "Truck_Bus"
    192 064 128; ... % "Train"
    128 064 064; ... % "OtherMoving"
    ]
    
    % "Pedestrian"
    [
    064 064 000; ... % "Pedestrian"
    192 128 064; ... % "Child"
    064 000 192; ... % "CartLuggagePram"
    064 128 064; ... % "Animal"
    ]
    
    % "Bicyclist"
    [
    000 128 192; ... % "Bicyclist"
    192 000 192; ... % "MotorcycleScooter"
    ]
    
    };
end
function classes = getClassNames()
classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];
end
function pixelLabelColorbar(cmap,classNames)
% Add a colorbar to the current axis. The colorbar is formatted
% to display the class names with the color.

colormap(gca,cmap)

% Add colorbar to current figure.
c = colorbar(gca);

% Use class names for tick marks.
c.TickLabels = classNames;
numClasses = size(cmap,1);

% Center tick labels.
c.Ticks = 1/(numClasses*2):1/numClasses:1;

% Remove tick mark.
c.TickLength = 0;
end
function cmap = camvidColorMap()
% Define the colormap used by CamVid data set.

cmap = [
    128 128 128   % Sky
    128 0 0       % Building
    192 192 192   % Pole
    128 64 128    % Road
    60 40 222     % Pavement
    128 128 0     % Tree
    192 128 128   % SignSymbol
    64 64 128     % Fence
    64 0 128      % Car
    64 64 0       % Pedestrian
    0 128 192     % Bicyclist
    ];

% Normalize between [0 1].
cmap = cmap ./ 255;
end
function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds)
% Partition CamVid data by randomly selecting 70% of the data for training,
% 10% of the data for validation, and the rest for testing.
    
% Set initial random state for example reproducibility.
rng(0); 
numFiles = numpartitions(imds);
shuffledIndices = randperm(numFiles);

% Use 70% of the images for training.
numTrain = round(0.70 * numFiles);
trainingIdx = shuffledIndices(1:numTrain);

% Use 10% of the images for validation
numVal = round(0.10 * numFiles);
valIdx = shuffledIndices(numTrain+1:numTrain+numVal);

% Use the rest for testing.
testIdx = shuffledIndices(numTrain+numVal+1:end);

% Create image datastores for training, validation, and test.
imdsTrain = subset(imds,trainingIdx);
imdsVal = subset(imds,valIdx);
imdsTest = subset(imds,testIdx);

% Create pixel label datastores for training, validation, and test.
pxdsTrain = subset(pxds,trainingIdx);
pxdsVal = subset(pxds,valIdx);
pxdsTest = subset(pxds,testIdx);
end

References

[1] Chen, Liang-Chieh, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation." Preprint, submitted August 22, 2018. https://arxiv.org/abs/1802.02611.

[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. "Semantic Object Classes in Video: A High-Definition Ground Truth Database." Pattern Recognition Letters 30, no. 2 (January 2009): 88–97. https://doi.org/10.1016/j.patrec.2008.04.005.

[3] Molchanov, Pavlo, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. "Pruning Convolutional Neural Networks for Resource Efficient Inference." Preprint, submitted June 8, 2017. https://arxiv.org/abs/1611.06440.

[4] Molchanov, Pavlo, Arun Mallya, Stephen Tyree, Iuri Frosio, and Jan Kautz. "Importance Estimation for Neural Network Pruning." In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11256–64. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPR.2019.01152.
