Prune and Quantize Semantic Segmentation Network
This example shows how to reduce the memory footprint of a semantic segmentation network and speed up inference by compressing the network using pruning and quantization.
Download Pretrained Semantic Segmentation Network
Download a pretrained version of DeepLab v3+ trained on the CamVid data set [1, 2]. For more information about this semantic segmentation network and the data set, see Semantic Segmentation Using Deep Learning (Computer Vision Toolbox).
pretrainedURL = "https://ssd.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid_v2.zip";
pretrainedFolder = fullfile(tempdir,"pretrainedNetwork");
pretrainedNetworkZip = fullfile(pretrainedFolder,"deeplabv3plusResnet18CamVid_v2.zip");
if ~exist(pretrainedNetworkZip,"file")
    mkdir(pretrainedFolder);
    disp("Downloading pretrained network (58 MB)...")
    websave(pretrainedNetworkZip,pretrainedURL);
end
Downloading pretrained network (58 MB)...
unzip(pretrainedNetworkZip,pretrainedFolder)
Load the pretrained network.
pretrainedNetwork = fullfile(pretrainedFolder,"deeplabv3plusResnet18CamVid_v2.mat");
data = load(pretrainedNetwork);
trainedNet = data.net;
List the classes that this network can classify.
classes = getClassNames()
classes = 11×1 string
"Sky"
"Building"
"Pole"
"Road"
"Pavement"
"Tree"
"SignSymbol"
"Fence"
"Car"
"Pedestrian"
"Bicyclist"
Download CamVid Data Set
Download the CamVid data set [2].
imageURL = "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip";
labelURL = "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip";

outputFolder = fullfile(tempdir,"CamVid");
labelsZip = fullfile(outputFolder,"labels.zip");
imagesZip = fullfile(outputFolder,"images.zip");

if ~exist(labelsZip,"file") || ~exist(imagesZip,"file")
    mkdir(outputFolder)

    disp("Downloading 16 MB CamVid data set labels...")
    websave(labelsZip,labelURL);
    unzip(labelsZip,fullfile(outputFolder,"labels"));

    disp("Downloading 557 MB CamVid data set images...")
    websave(imagesZip,imageURL);
    unzip(imagesZip,fullfile(outputFolder,"images"));
end
Load CamVid Pixel-Labeled Images
Use a pixelLabelDatastore (Computer Vision Toolbox) object to load the CamVid pixel label image data. A pixelLabelDatastore object encapsulates the pixel label data and the mapping from label IDs to class names.
To simplify training, group the 32 original classes in CamVid into 11 classes matching the classes that the pretrained DeepLab v3+ network can classify. For example, "Car" is a combination of the CamVid "Car", "SUVPickupTruck", "Truck_Bus", "Train", and "OtherMoving" classes. Return the grouped label IDs by using the camvidPixelLabelIDs helper function, which is defined at the end of this example.
labelIDs = camvidPixelLabelIDs;
Create the pixelLabelDatastore using the classes and label IDs.
imgDir = fullfile(outputFolder,"images","701_StillsRaw_full");
imds = imageDatastore(imgDir);
labelDir = fullfile(outputFolder,"labels");
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Read and display one of the pixel-labeled images by overlaying it on top of an image.
I = readimage(imds,559);
C = readimage(pxds,559);
cmap = camvidColorMap;
B = labeloverlay(I,C,ColorMap=cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);
Prepare Data for Training
Randomly split the image and pixel label data into training, validation, and test sets. Allocate 70% of the images from the data set to train the DeepLab v3+ model. Allocate 10% of the data for validation and the remaining 20% for testing.
[imdsTrain,imdsVal,imdsTest,pxdsTrain,pxdsVal,pxdsTest] = partitionCamVidData(imds,pxds);
dsTrain = combine(imdsTrain,pxdsTrain);
Display the number of training, validation, and test images.
numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 491
numValImages = numel(imdsVal.Files)
numValImages = 70
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140
Evaluate Network Before Compression
Evaluate the performance of the network by using the evaluateNet helper function, which is defined at the end of this example. First, the evaluateNet helper function performs semantic segmentation of the test images. Then, the function calculates metrics that evaluate the quality of the semantic segmentation results against the ground truth segmentation.
trainedNetMetrics = evaluateNet(trainedNet,imdsTest,pxdsTest,classes);
trainedNetMetrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.90852 0.88877 0.69871 0.85013 0.74106
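The MeanIoU column reports the intersection over union (IoU) averaged over all classes. For reference, this minimal sketch shows how the IoU for a single class follows from predicted and ground-truth pixel masks. The masks here are random stand-ins for illustration only; the evaluateSemanticSegmentation function computes these metrics across all classes and images for you.

% Conceptual sketch only: IoU for one class from logical pixel masks.
predMask = rand(720,960) > 0.5;    % hypothetical predicted "Road" mask
truthMask = rand(720,960) > 0.5;   % hypothetical ground-truth "Road" mask
iou = nnz(predMask & truthMask) / nnz(predMask | truthMask)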
Prune Network
Create Prunable Network
Create a prunable object based on first-order Taylor approximation by using the taylorPrunableNetwork function. A TaylorPrunableNetwork object has similar properties and object functions to a dlnetwork object. However, a TaylorPrunableNetwork object also has pruning-specific properties and object functions. You can use a TaylorPrunableNetwork object instead of a dlnetwork object in a custom training loop. Pruning is iterative, which means that each time the loop runs, the function removes a small number of the least important convolution filters and updates the network architecture. This process continues until a stopping criterion is met.
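The Taylor scores that drive the pruning decisions estimate how much the loss would change if a filter were removed. Following [3] and [4], the first-order estimate for each filter is the magnitude of the product of its activation and the gradient of the loss with respect to that activation, accumulated over the data. The following is a conceptual sketch only, with random stand-ins for the activations and gradients; the updateScore function computes the actual scores internally.

% Conceptual sketch only: first-order Taylor importance of each filter
% in one convolution layer. a and dLda are hypothetical stand-ins for
% the activations and loss gradients that the pruning loop collects.
a = rand(32,32,16,8,"single");       % activations: H-by-W-by-numFilters-by-batch
dLda = randn(32,32,16,8,"single");   % gradients of the loss w.r.t. the activations
score = squeeze(abs(sum(a .* dLda,[1 2 4])))   % one importance score per filter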
prunableNet = taylorPrunableNetwork(trainedNet)
prunableNet = 
  TaylorPrunableNetwork with properties:

      Learnables: [118×3 table]
           State: [56×3 table]
      InputNames: {'data'}
     OutputNames: {'softmax-out'}
    NumPrunables: 5333
Specify Pruning Options
Set the pruning options.
- maxPruningIterations defines the maximum number of iterations in the pruning loop.
- maxToPrune is the maximum number of filters to prune in each iteration of the pruning loop.
- validationFrequency is the number of iterations to wait before validating the pruned network using the validation data.
maxPruningIterations = 30;
maxToPrune = 64;
validationFrequency = 5;
Set the fine-tuning options.
- maxMinibatchIterations defines the maximum number of iterations in the fine-tuning loop.
- Specify the options for stochastic gradient descent with momentum (SGDM) optimization. Specify an initial learning rate of 0.001 and a momentum of 0.9.
- Specify a mini-batch size of 8 to fine-tune the network.
maxMinibatchIterations = 40;
learnRate = 0.001;
momentum = 0.9;
miniBatchSize = 8;
Create a figure to monitor the loss, validation accuracy, and number of prunable filters during training.
monitor = trainingProgressMonitor(Metrics=["Loss" "ValAccuracy" "NumPrunables"], ...
    XLabel="Iteration");
Create Mini-batch Queue
Use a minibatchqueue object to process and manage the mini-batches of images. For each mini-batch, perform these steps:
- Separate the image and label data using the deal function.
- Format the image and label data with the dimension labels "SSCB" (spatial, spatial, channel, batch).
- Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
mbqTrain = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(img,labels) deal(cat(4,img{:}),cat(4,labels{:})), ...
    OutputAsDlarray=[1 1], ...
    MiniBatchFormat=["SSCB" "SSCB"], ...
    OutputEnvironment=["auto" "cpu"]);
Prune Network Using Custom Training Loop
Prune the network by repeatedly fine-tuning the network and removing the low-scoring filters. For each pruning iteration, perform these operations:
- Fine-tune the network and accumulate Taylor scores for the convolution filters for maxMinibatchIterations iterations.
- Remove a small number of the least important convolution filters and update the network architecture using the updatePrunables function.
- Display the training progress.
To fine-tune the network, loop over the mini-batches of the training data. For each mini-batch in the fine-tuning iteration, perform these operations:
- Calculate the pruning activations, gradients of the pruning activations, model gradients, state, and loss using the dlfeval and modelGradients functions. The modelGradients function is a helper function that is defined at the end of the example.
- Update the network state.
- Update the network learnable parameters by using the sgdmupdate function.
- Calculate the first-order Taylor scores and accumulate the scores across previous mini-batches of data by using the updateScore function.
pruningIteration = 1;
metrics = 0;

% Pruning loop
while (prunableNet.NumPrunables > maxToPrune) && (pruningIteration < maxPruningIterations)

    % Reset and shuffle the mini-batch queue
    reset(mbqTrain);
    shuffle(mbqTrain);

    % Reset the parameters for the current pruning iteration
    velocity = [];
    localIteration = 0;

    % Fine-tuning loop
    while hasdata(mbqTrain)
        localIteration = localIteration + 1;
        [dlX,Y] = next(mbqTrain);

        % Calculate activations of masking layers and gradients of the loss
        % with respect to these activations
        [dLearnables,dGatingLayers,gatingLayerOuts,state,loss] = dlfeval( ...
            @modelGradients,prunableNet,dlX,Y);
        prunableNet.State = state;

        [prunableNet,velocity] = sgdmupdate(prunableNet,dLearnables,velocity,learnRate,momentum);
        prunableNet = updateScore(prunableNet,dGatingLayers,gatingLayerOuts);

        if (localIteration > maxMinibatchIterations)
            break
        end
    end % End fine-tuning loop

    prunableNet = updatePrunables(prunableNet,MaxToPrune=maxToPrune);

    if (mod(pruningIteration,validationFrequency) == 0 || pruningIteration == 1)
        metrics = evaluateNet(prunableNet,imdsVal,pxdsVal,classes);
    end

    pruningIteration = pruningIteration + 1;
    recordMetrics(monitor,pruningIteration,Loss=loss, ...
        ValAccuracy=metrics.DataSetMetrics.WeightedIoU,NumPrunables=prunableNet.NumPrunables);
end % End pruning loop
During each pruning iteration, the validation accuracy often decreases because of changes in the network structure when the software prunes the convolutional filters. To minimize the loss in accuracy, retrain the network after pruning.
When pruning is complete, convert the TaylorPrunableNetwork object back to a dlnetwork object for retraining and further analysis.
prunedNet = dlnetwork(prunableNet);
save("dlnet_pruned.mat","prunedNet");
Compare Filters in Original Network and Pruned Network
Determine the impact of pruning on each layer.
originalNetFilters = numConvLayerFilters(trainedNet);
prunedNetFilters = numConvLayerFilters(prunedNet);
convFilters = join(originalNetFilters,prunedNetFilters,Keys="Row");
Visualize the number of filters in the original network and in the pruned network.
figure(Position=[10,10,900,900])
bar([convFilters.(1),convFilters.(2)])
xlabel("Layer")
ylabel("Number of Filters")
title("Number of Filters Per Layer")
xticks(1:(numel(convFilters.Row)))
xticklabels(convFilters.Row)
xtickangle(90)
ax = gca;
ax.TickLabelInterpreter = "none";
legend("Original Network Filters","Pruned Network Filters",Location="southoutside")
Evaluate Pruned Network
Evaluate the pruned network using the test set. Pruning lowers some of the accuracy scores, most noticeably the mean accuracy. You can recover the lost accuracy by retraining the pruned network.
prunedNetMetrics = evaluateNet(prunedNet,imdsTest,pxdsTest,classes);
prunedNetMetrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.92644 0.75779 0.68834 0.86613 0.75825
Retrain Pruned Network
The pruning process can cause the prediction accuracy to decrease. Try to improve the prediction accuracy by retraining the network using a custom training loop.
Specify Training Options
Specify the options to use during retraining.
Specify the options for SGDM optimization. Specify an initial learning rate of 0.001 and a momentum of 0.9. Initialize the gradient velocity as [].
learnRate = 0.001;
momentum = 0.9;
valFreq = 50;
numEpochs = 8;
velocity = [];
Create a figure to monitor the loss and validation accuracy during training.
monitor = trainingProgressMonitor(Metrics=["Loss" "ValAccuracy"], ...
    Info=["Epoch" "LearnRate"],XLabel="Iteration");
Train Pruned Network Using Custom Training Loop
Train the network in a custom training loop. For each iteration:
- Evaluate the model gradients using the dlfeval function and the modelLoss helper function, which is defined at the end of the example.
- Update the network parameters using the sgdmupdate function.
- Display the progress.
epoch = 0;
iteration = 0;
metrics = [];

% Loop over epochs
while (epoch < numEpochs) && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data
    shuffle(mbqTrain);

    % Loop over mini-batches
    while hasdata(mbqTrain) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data
        [X,T] = next(mbqTrain);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function
        [loss,gradients,state] = dlfeval(@modelLoss,prunedNet,X,T);

        % Update the network state
        prunedNet.State = state;

        % Update the network parameters using the SGDM optimizer
        [prunedNet,velocity] = sgdmupdate(prunedNet,gradients,velocity,learnRate,momentum);

        if (mod(iteration,valFreq) == 0 || iteration == 1)
            reset(imdsVal);
            reset(pxdsVal);
            metrics = evaluateNet(prunedNet,imdsVal,pxdsVal,classes,iteration);
        end

        % Update the training progress monitor
        recordMetrics(monitor,iteration,Loss=loss,ValAccuracy=metrics.DataSetMetrics.MeanIoU);
        updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);
    end
end
save("dlnet_pruned_retrained.mat","prunedNet");
Evaluate Retrained Pruned Network
Evaluate the metrics for the retrained pruned network.
prunedNetMetrics = evaluateNet(prunedNet,imdsTest,pxdsTest,classes);
prunedNetMetrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.93333 0.80252 0.72794 0.87854 0.79206
Compare the memory footprint and accuracy of the original and pruned networks. The pruned network uses 54% less memory than the original network, with a moderate decrease in the mean accuracy metric.
statsPruned = compareNetworkMetrics(trainedNet,prunedNet, ...
    trainedNetMetrics.DataSetMetrics.MeanAccuracy,prunedNetMetrics.DataSetMetrics.MeanAccuracy, ...
    "Pruned Network")
statsPruned=3×3 table
Network Learnables Approx. Network Memory (MB) MeanAccuracy
__________________ ___________________________ ____________
Original Network 1.6402e+07 62.568 0.88877
Pruned Network 7.556e+06 28.824 0.80252
Percentage Change -53.932 -53.932 -9.7047
Quantize Network
Quantize the retrained pruned network for a GPU target. Quantization reduces the memory footprint of the network by converting weights, biases, and activations of convolution layers from floating-point data types to 8-bit scaled integer data types. After quantization, a network can perform inference more quickly.
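For reference, this sketch shows the arithmetic behind scaled 8-bit integer quantization, under the assumption of a single power-of-two scale per tensor. The weight array is a random stand-in; the dlquantizer workflow below selects the scaling for you from calibration data.

% Conceptual sketch only: map a floating-point weight array to int8 with a
% power-of-two scale chosen to cover the observed dynamic range.
W = randn(3,3,16,32,"single");                         % hypothetical convolution weights
scaleExp = ceil(log2(max(abs(W(:)))/127));             % power-of-two exponent
Wq = int8(min(max(round(W ./ 2^scaleExp),-128),127));  % quantized 8-bit weights
Wr = single(Wq) .* 2^scaleExp;                         % dequantized approximation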
To improve the performance of the network after quantization, equalize the layer parameters of the retrained pruned network by using the equalizeLayers function.
eqNet = equalizeLayers(prunedNet);
Create a quantizable version of the equalized network by using a dlquantizer object. Specify a GPU target using the ExecutionEnvironment name-value argument.
quantizableNet = dlquantizer(eqNet,ExecutionEnvironment="GPU");
Calibrate the network with the training data by using the calibrate function. Calibration consists of exercising the network with sample inputs and collecting dynamic range information.
calibrate(quantizableNet,dsTrain,MiniBatchSize=8);
Quantize the network object and return a simulatable quantized network by using the quantize function.
quantizedNet = quantize(quantizableNet,ExponentScheme="Histogram");
Display the details of the quantized network by using the quantizationDetails function.
qDetails = quantizationDetails(quantizedNet)
qDetails = struct with fields:
IsQuantized: 1
TargetLibrary: "cudnn"
QuantizedLayerNames: [91×1 string]
QuantizedLearnables: [51×3 table]
Compare Original and Quantized Networks
Calculate metrics for the quantized semantic segmentation network.
quantizedNetMetrics = evaluateNet(quantizedNet,imdsTest,pxdsTest,classes);
quantizedNetMetrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.92221 0.75959 0.68331 0.85938 0.76537
Compare the original network and the quantized network. The quantized network uses 88.5% less memory than the original network, with a decrease in accuracy.
statsQuantized = compareNetworkMetrics(trainedNet,quantizedNet, ...
    trainedNetMetrics.DataSetMetrics.MeanAccuracy,quantizedNetMetrics.DataSetMetrics.MeanAccuracy, ...
    "Quantized Network")
statsQuantized=3×3 table
Network Learnables Approx. Network Memory (MB) MeanAccuracy
__________________ ___________________________ ____________
Original Network 1.6402e+07 62.568 0.88877
Quantized Network 7.556e+06 7.206 0.75959
Percentage Change -53.932 -88.483 -14.535
Results Summary
Compare the number of learnables, memory footprint, and accuracy of the original, pruned, and quantized networks.
[statsPruned(1:2,:); statsQuantized(2,:)]
ans=3×3 table
Network Learnables Approx. Network Memory (MB) MeanAccuracy
__________________ ___________________________ ____________
Original Network 1.6402e+07 62.568 0.88877
Pruned Network 7.556e+06 28.824 0.80252
Quantized Network 7.556e+06 7.206 0.75959
Helper Functions
function [loss,gradients,state] = modelLoss(net,X,T)
% Calculate semantic segmentation model loss.

% Forward data through the network
[Y,state] = forward(net,X);

T = extractdata(T);
T_Onehotencode = onehotencode(T,3,ClassNames=1:11);
T_Onehotencode(isnan(T_Onehotencode)) = 0;
numObs = size(Y,1) * size(Y,2) * size(Y,4);

% Calculate cross-entropy loss
loss = crossentropy(Y,T_Onehotencode)/numObs;

% Calculate gradients of loss with respect to learnable parameters
gradients = dlgradient(loss,net.Learnables);
end
function [dLossdLearnables,pruningGradient,pruningActivations,state,loss] = modelGradients(networkPruner,dlX,Y)
% Calculate network pruning model gradients.

[networkAct,state,pruningActivations] = forward(networkPruner,dlX);

% Get the output of the softmax layer and use it to compute the loss
Y = extractdata(Y);
Y2 = onehotencode(Y,3,ClassNames=1:11);
Y2(isnan(Y2)) = 0;
Y2 = dlarray(Y2,"SSCB");
dims = size(networkAct);
bz = dims(end);
loss = crossentropy(networkAct,Y2)/(prod(dims(1:2))*bz);

% Differentiate the loss with respect to the learnables and the outputs
% of the gating layers
[dLossdLearnables,pruningGradient] = dlgradient(loss,networkPruner.Learnables,pruningActivations);
end
function convFilters = numConvLayerFilters(net)
% Return the number of filters in each convolution layer of a network.

numLayers = numel(net.Layers);
convNames = [];
numFilters = [];

% Check for convolution layers and extract the number of filters.
for cnt = 1:numLayers
    if isa(net.Layers(cnt),"nnet.cnn.layer.Convolution2DLayer")
        sizeW = size(net.Layers(cnt).Weights);
        numFilters = [numFilters; sizeW(end)]; %#ok<AGROW>
        convNames = [convNames; string(net.Layers(cnt).Name)]; %#ok<AGROW>
    end
end
convFilters = table(numFilters,RowNames=convNames);
end
function ssm = evaluateNet(net,imds,pxdsTruth,classNames,itr)
% Apply semantic segmentation and evaluate the segmentation results.

dirname = tempdir;
if nargin == 5
    dirname = dirname + "val_" + num2str(itr);
    mkdir(dirname);
end
pxdsResults = semanticseg(imds,net,WriteLocation=dirname,Classes=classNames, ...
    Verbose=false,MiniBatchSize=8);
ssm = evaluateSemanticSegmentation(pxdsResults,pxdsTruth,Verbose=false);
end
function statistics = compareNetworkMetrics(originalNet,compressedNet, ...
    originalNetAccuracy,compressedNetAccuracy,compressedNetName)
% Return statistics about a network, including the number of learnables,
% approximate memory footprint, and accuracy.

originalNetMetrics = estimateNetworkMetrics(originalNet);
compressedNetMetrics = estimateNetworkMetrics(compressedNet);

% Accuracy of the original network and the compressed network
perChangeAccu = 100*(compressedNetAccuracy - originalNetAccuracy)/originalNetAccuracy;
accuracyForNetworks = [originalNetAccuracy;compressedNetAccuracy;perChangeAccu];

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
compressedNetLearnables = sum(compressedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(compressedNetLearnables - originalNetLearnables)/originalNetLearnables;
learnablesForNetwork = [originalNetLearnables;compressedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxCompressedMemory = sum(compressedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxCompressedMemory - approxOriginalMemory)/approxOriginalMemory;
networkMemory = [approxOriginalMemory; approxCompressedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(learnablesForNetwork,networkMemory,accuracyForNetworks, ...
    VariableNames=["Network Learnables","Approx. Network Memory (MB)","MeanAccuracy"], ...
    RowNames=["Original Network",compressedNetName,"Percentage Change"]);
end
function labelIDs = camvidPixelLabelIDs()
% Return the label IDs corresponding to each class.
%
% The CamVid data set has 32 classes. Group them into 11 classes following
% the original SegNet training methodology.
%
% The 11 classes are:
%   "Sky", "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol",
%   "Fence", "Car", "Pedestrian", and "Bicyclist".
%
% CamVid pixel label IDs are provided as RGB color values. Group them into
% 11 classes and return them as a cell array of M-by-3 matrices. The
% original CamVid class names are listed alongside each RGB value. Note
% that the Other/Void class is excluded below.
labelIDs = { ...

    % "Sky"
    [
    128 128 128; ... % "Sky"
    ]

    % "Building"
    [
    000 128 064; ... % "Bridge"
    128 000 000; ... % "Building"
    064 192 000; ... % "Wall"
    064 000 064; ... % "Tunnel"
    192 000 128; ... % "Archway"
    ]

    % "Pole"
    [
    192 192 128; ... % "Column_Pole"
    000 000 064; ... % "TrafficCone"
    ]

    % "Road"
    [
    128 064 128; ... % "Road"
    128 000 192; ... % "LaneMkgsDriv"
    192 000 064; ... % "LaneMkgsNonDriv"
    ]

    % "Pavement"
    [
    000 000 192; ... % "Sidewalk"
    064 192 128; ... % "ParkingBlock"
    128 128 192; ... % "RoadShoulder"
    ]

    % "Tree"
    [
    128 128 000; ... % "Tree"
    192 192 000; ... % "VegetationMisc"
    ]

    % "SignSymbol"
    [
    192 128 128; ... % "SignSymbol"
    128 128 064; ... % "Misc_Text"
    000 064 064; ... % "TrafficLight"
    ]

    % "Fence"
    [
    064 064 128; ... % "Fence"
    ]

    % "Car"
    [
    064 000 128; ... % "Car"
    064 128 192; ... % "SUVPickupTruck"
    192 128 192; ... % "Truck_Bus"
    192 064 128; ... % "Train"
    128 064 064; ... % "OtherMoving"
    ]

    % "Pedestrian"
    [
    064 064 000; ... % "Pedestrian"
    192 128 064; ... % "Child"
    064 000 192; ... % "CartLuggagePram"
    064 128 064; ... % "Animal"
    ]

    % "Bicyclist"
    [
    000 128 192; ... % "Bicyclist"
    192 000 192; ... % "MotorcycleScooter"
    ]

    };
end
function classes = getClassNames()
classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];
end
function pixelLabelColorbar(cmap,classNames)
% Add a colorbar to the current axis. The colorbar is formatted
% to display the class names with the color.

colormap(gca,cmap)

% Add a colorbar to the current figure.
c = colorbar("peer",gca);

% Use class names for tick marks.
c.TickLabels = classNames;
numClasses = size(cmap,1);

% Center tick labels.
c.Ticks = 1/(numClasses*2):1/numClasses:1;

% Remove tick marks.
c.TickLength = 0;
end
function cmap = camvidColorMap()
% Define the colormap used by the CamVid data set.

cmap = [
    128 128 128   % Sky
    128 0 0       % Building
    192 192 192   % Pole
    128 64 128    % Road
    60 40 222     % Pavement
    128 128 0     % Tree
    192 128 128   % SignSymbol
    64 64 128     % Fence
    64 0 128      % Car
    64 64 0       % Pedestrian
    0 128 192     % Bicyclist
    ];

% Normalize between [0 1].
cmap = cmap ./ 255;
end
function [imdsTrain,imdsVal,imdsTest,pxdsTrain,pxdsVal,pxdsTest] = partitionCamVidData(imds,pxds)
% Partition CamVid data by randomly selecting 70% of the data for training,
% 10% of the data for validation, and the rest for testing.

% Set the initial random state for example reproducibility.
rng(0);
numFiles = numpartitions(imds);
shuffledIndices = randperm(numFiles);

% Use 70% of the images for training.
numTrain = round(0.70 * numFiles);
trainingIdx = shuffledIndices(1:numTrain);

% Use 10% of the images for validation.
numVal = round(0.10 * numFiles);
valIdx = shuffledIndices(numTrain+1:numTrain+numVal);

% Use the rest for testing.
testIdx = shuffledIndices(numTrain+numVal+1:end);

% Create image datastores for training, validation, and testing.
imdsTrain = subset(imds,trainingIdx);
imdsVal = subset(imds,valIdx);
imdsTest = subset(imds,testIdx);

% Create pixel label datastores for training, validation, and testing.
pxdsTrain = subset(pxds,trainingIdx);
pxdsVal = subset(pxds,valIdx);
pxdsTest = subset(pxds,testIdx);
end
References
[1] Chen, Liang-Chieh, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam. “Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.” Preprint, submitted August 22, 2018. https://arxiv.org/abs/1802.02611.
[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. “Semantic Object Classes in Video: A High-Definition Ground Truth Database.” Pattern Recognition Letters 30, no. 2 (January 2009): 88–97. https://doi.org/10.1016/j.patrec.2008.04.005.
[3] Molchanov, Pavlo, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. "Pruning Convolutional Neural Networks for Resource Efficient Inference." Preprint, submitted June 8, 2017. https://arxiv.org/abs/1611.06440.
[4] Molchanov, Pavlo, Arun Mallya, Stephen Tyree, Iuri Frosio, and Jan Kautz. "Importance Estimation for Neural Network Pruning." In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11256–64. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPR.2019.01152.
See Also
pixelLabelDatastore (Computer Vision Toolbox) | semanticseg (Computer Vision Toolbox) | evaluateSemanticSegmentation (Computer Vision Toolbox) | taylorPrunableNetwork | dlquantizer