Object Detection Using YOLO v3 Deep Learning

This example shows how to detect objects in images by using a you only look once version 3 (YOLO v3) deep learning network. In this example, you will:

  • Configure a data set for training and testing of a YOLO v3 object detection network. You will also perform data augmentation on the training data set to improve network accuracy.

  • Create a YOLO v3 object detector by using the yolov3ObjectDetector (Computer Vision Toolbox) function, and train the detector by using the trainYOLOv3ObjectDetector (Computer Vision Toolbox) function.

This example also provides a pretrained YOLO v3 object detector that you can use to detect vehicles in an image. The pretrained network uses SqueezeNet as the backbone network and is trained on a vehicle data set. For information about the YOLO v3 object detection network, see Getting Started with YOLO v3 (Computer Vision Toolbox).

Load Data

This example uses a small labeled data set that contains 295 images. Many of these images come from the Caltech Cars 1999 and 2001 data sets, created by Pietro Perona and used with permission. Each image contains one or two labeled instances of a vehicle. A small data set is useful for exploring the YOLO v3 training procedure, but in practice, more labeled images are needed to train a robust network.

Unzip the vehicle images and load the vehicle ground truth data.

unzip vehicleDatasetImages.zip
data = load("vehicleDatasetGroundTruth.mat");
vehicleDataset = data.vehicleDataset;

The vehicle data is stored in a two-column table. The first column contains the image file paths, and the second column contains the bounding boxes.

Display the first few rows of the data set.

vehicleDataset(1:4,:)
ans=4×2 table
              imageFilename                   vehicle     
    _________________________________    _________________

    {'vehicleImages/image_00001.jpg'}    {[220 136 35 28]}
    {'vehicleImages/image_00002.jpg'}    {[175 126 61 45]}
    {'vehicleImages/image_00003.jpg'}    {[108 120 45 33]}
    {'vehicleImages/image_00004.jpg'}    {[124 112 38 36]}

Add the full path to the local vehicle data folder.

vehicleDataset.imageFilename = fullfile(pwd,vehicleDataset.imageFilename);

Split the data set into training, validation, and test sets. Select 60% of the data for training, 10% for validation, and the rest for testing the trained detector.

rng(0);
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices));

trainingIdx = 1:idx;
trainingDataTbl = vehicleDataset(shuffledIndices(trainingIdx),:);

validationIdx = idx+1 : idx+1+floor(0.1 * length(shuffledIndices));
validationDataTbl = vehicleDataset(shuffledIndices(validationIdx),:);

testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = vehicleDataset(shuffledIndices(testIdx),:);
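
As an optional sanity check, you can confirm that the three splits together cover the full data set:

% Illustrative check: the split sizes should sum to the full data set size.
fprintf("Training: %d, Validation: %d, Test: %d (total %d of %d)\n", ...
    height(trainingDataTbl),height(validationDataTbl),height(testDataTbl), ...
    height(trainingDataTbl)+height(validationDataTbl)+height(testDataTbl), ...
    height(vehicleDataset));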

Use imageDatastore and boxLabelDatastore to create datastores for loading the image and label data during training and evaluation.

imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,"vehicle"));

imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,"vehicle"));

imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,"vehicle"));

Combine image and box label datastores.

trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);

Use the validateInputData helper function to detect invalid images, bounding boxes, or labels, that is:

  • Samples with an invalid image format or containing NaNs

  • Bounding boxes containing zeros, NaNs, or Infs, or that are empty

  • Missing or noncategorical labels

The bounding box values must be finite positive integers, must not be NaN, and must define a region with positive height and width that lies within the image boundary. Discard or fix any invalid samples before training.

validateInputData(trainingData);
validateInputData(validationData);
validateInputData(testData);
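
The validateInputData function is attached to the example as a supporting file and is not listed here. As a rough, hypothetical sketch only (the function name, structure, and messages below are illustrative, not the actual implementation), the bounding box checks it describes could look like this:

function checkBoxesSketch(ds)
% Hypothetical sketch: scan a combined datastore and warn about samples
% whose boxes are empty, non-finite, non-positive, or out of bounds.
reset(ds);
k = 0;
while hasdata(ds)
    k = k + 1;
    sample = read(ds);    % {image, boxes, labels}
    I = sample{1};
    boxes = sample{2};
    sz = size(I);
    if isempty(boxes) || any(~isfinite(boxes(:))) || any(boxes(:,3:4) <= 0,"all")
        warning("Sample %d: empty, non-finite, or non-positive boxes.",k)
    elseif any(boxes(:,1)+boxes(:,3)-1 > sz(2)) || any(boxes(:,2)+boxes(:,4)-1 > sz(1))
        warning("Sample %d: boxes extend outside the image boundary.",k)
    end
end
reset(ds);
end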

Data Augmentation

Data augmentation is used to improve network accuracy by randomly transforming the original data during training. By using data augmentation, you can add more variety to the training data without actually having to increase the number of labeled training samples.

Use the transform function to apply custom data augmentations to the training data. The augmentData helper function, listed at the end of the example, applies the following augmentations to the input data.

  • Color jitter augmentation in HSV space

  • Random horizontal flip

  • Random scaling by up to 10 percent

augmentedTrainingData = transform(trainingData,@augmentData);

Read and augment the same training sample four times, and display the augmented training data.

augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1,1},"Rectangle",data{1,2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData,BorderSize=10)

Figure: montage of four augmented versions of the same training image, each with its transformed bounding box.

reset(augmentedTrainingData);

Define YOLO v3 Object Detector

The YOLO v3 detector in this example is based on SqueezeNet, and uses the feature extraction network in SqueezeNet with the addition of two detection heads at the end. The second detection head is twice the size of the first detection head, so it is better able to detect small objects. Note that you can specify any number of detection heads of different sizes, based on the size of the objects that you want to detect. The YOLO v3 detector uses anchor boxes estimated from the training data, which provide better initial priors for the data set and help the detector learn to predict the boxes accurately. For information about anchor boxes, see Anchor Boxes for Object Detection (Computer Vision Toolbox).

The following diagram illustrates the YOLO v3 network used by the detector.

You can use Deep Network Designer to create the network shown in the diagram.

Specify the network input size. When choosing the network input size, consider the minimum size required to run the network itself, the size of the training images, and the computational cost incurred by processing data at the selected size. When feasible, choose a network input size that is close to the size of the training image and larger than the input size required for the network. To reduce the computational cost of running the example, specify a network input size of [227 227 3].

networkInputSize = [227 227 3];

First, because the training images used in this example are bigger than 227-by-227 and vary in size, use transform to preprocess the training data for computing the anchor boxes. Specify the number of anchors as 6 to achieve a good tradeoff between the number of anchors and the mean IoU. Use the estimateAnchorBoxes function to estimate the anchor boxes. For details on estimating anchor boxes, see Estimate Anchor Boxes From Training Data (Computer Vision Toolbox). If you use a pretrained YOLO v3 object detector, you must instead specify the anchor boxes estimated on the data set that the detector was trained on. Note that the estimation process is not deterministic. To prevent the estimated anchor boxes from changing while tuning other hyperparameters, set the random seed prior to estimation using rng.

rng(0)
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,networkInputSize));
numAnchors = 6;
[anchors, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors)
anchors = 6×2

    42    37
   160   130
    96    92
   141   123
    35    25
    69    66

meanIoU = 0.8430

Specify the anchorBoxes to use in each of the detection heads. anchorBoxes is an M-by-1 cell array, where M denotes the number of detection heads. Each cell contains an N-by-2 matrix of anchors, where N is the number of anchors to use in that head. Select the anchor boxes for each detection head based on the feature map size: use larger anchors at the lower-resolution scale and smaller anchors at the higher-resolution scale. To do so, sort the anchors with the larger anchor boxes first, and assign the first three to the first detection head and the next three to the second detection head.

area = anchors(:, 1).*anchors(:, 2);
[~, idx] = sort(area,"descend");
anchors = anchors(idx, :);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    };

Load a SqueezeNet network pretrained on the ImageNet data set, and then specify the class names. You can instead load a different network pretrained on the COCO data set, such as tiny-yolov3-coco or darknet53-coco, or on the ImageNet data set, such as MobileNet-v2 or ResNet-18. YOLO v3 performs better and trains faster when you use a pretrained backbone network.

baseNetwork = imagePretrainedNetwork("squeezenet");
classNames = trainingDataTbl.Properties.VariableNames(2:end);

Next, create the yolov3ObjectDetector object by specifying the detection network sources. Choosing the optimal detection network sources requires trial and error; you can use analyzeNetwork to find the names of potential detection network sources within a network. For this example, use the fire9-concat and fire5-concat layers as the DetectionNetworkSource.

yolov3Detector = yolov3ObjectDetector(baseNetwork,classNames,anchorBoxes, ...
    DetectionNetworkSource=["fire9-concat","fire5-concat"],InputSize=networkInputSize);

Alternatively, instead of the SqueezeNet-based network created above, you can use another pretrained YOLO v3 architecture trained on a larger data set, such as MS-COCO, as the starting point for transfer learning on a custom object detection task. To adapt such a detector, specify your own classNames and anchorBoxes.
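
For instance, assuming the Computer Vision Toolbox Model for YOLO v3 Object Detection support package is installed (it provides the tiny-yolov3-coco network), a COCO-pretrained detector could be reconfigured for this vehicle data set along these lines:

% Sketch: configure a COCO-pretrained tiny YOLO v3 network for the vehicle
% class. tiny-yolov3-coco has two detection heads, so the two-element
% anchorBoxes cell array estimated above maps onto it directly.
cocoDetector = yolov3ObjectDetector("tiny-yolov3-coco",classNames,anchorBoxes);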

Specify Training Options

Use trainingOptions to specify the network training options. Train the object detector using the Adam solver for 80 epochs with a constant learning rate of 0.001. Set ValidationData to the validation data and ValidationFrequency to 1000. To validate the data more often, reduce the ValidationFrequency, which also increases the training time. Set the Metrics training option to an mAPObjectDetectionMetric object to track the mean average precision (mAP) metric while training YOLO v3. Use ExecutionEnvironment to determine what hardware resources to use to train the network. The default value of ExecutionEnvironment is "auto", which selects a GPU if one is available and otherwise selects the CPU. Set CheckpointPath to a temporary location to enable the saving of partially trained detectors during the training process. If training is interrupted, for instance by a power outage or system failure, you can resume training from the saved checkpoint.

options = trainingOptions("adam", ...
    GradientDecayFactor=0.9, ...
    SquaredGradientDecayFactor=0.999, ...
    InitialLearnRate=0.001, ...
    LearnRateSchedule="none", ...
    MiniBatchSize=8, ...
    L2Regularization=0.0005, ...
    MaxEpochs=80, ...
    DispatchInBackground=true, ...
    ResetInputNormalization=true, ...
    Shuffle="every-epoch", ...
    VerboseFrequency=20, ...
    ValidationFrequency=1000, ...
    CheckpointPath=tempdir, ...
    ValidationData=validationData);

% Set the Metrics option to track mAP during training.
options.Metrics = mAPObjectDetectionMetric(Name="map50");

Train YOLO v3 Object Detector

Use the trainYOLOv3ObjectDetector (Computer Vision Toolbox) function to train the YOLO v3 object detector. Instead of training the network, you can also use a pretrained YOLO v3 object detector.

If you want to train the network with a new set of data, set the doTraining variable to true. Otherwise, download a pretrained network by using the downloadPretrainedYOLOv3Detector helper function.

doTraining = false;
if doTraining       
    % Train the YOLO v3 detector.
    [yolov3Detector,info] = trainYOLOv3ObjectDetector(augmentedTrainingData,yolov3Detector,options);
else
    % Load pretrained detector for the example.
    yolov3Detector = downloadPretrainedYOLOv3Detector();
end
Downloading pretrained detector...

Evaluate Model

Run the detector on all the test images. Set the detection threshold to a low value to detect as many objects as possible. This helps you evaluate the detector precision across the full range of recall values.

results = detect(yolov3Detector,testData,MiniBatchSize=8,Threshold=0.01);
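
The detect function returns the results as a table with one row per test image. You can preview the first few rows of detections:

% Preview the detections for the first few test images.
results(1:4,:)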

Calculate the object detection performance metrics on the test set detection results using the evaluateObjectDetection (Computer Vision Toolbox) function.

metrics = evaluateObjectDetection(results,testData);
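
The returned objectDetectionMetrics object also summarizes results per class. For example, you can inspect its ClassMetrics property:

% Display the per-class metrics summary.
metrics.ClassMetrics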

The average precision (AP) provides a single number that incorporates the ability of the detector to make correct classifications (precision) and the ability of the detector to find all relevant objects (recall). The precision-recall (PR) curve shows how precise a detector is at varying levels of recall. Ideally, the precision is 1 at all recall levels. Plot the PR curve and also display the average precision.

[precision,recall] = precisionRecall(metrics);
AP = averagePrecision(metrics);

figure
plot(recall{:},precision{:})
xlabel("Recall")
ylabel("Precision")
grid on
title("Average Precision = "+AP)

Figure: precision-recall curve (x-axis Recall, y-axis Precision) with title "Average Precision = 0.91109".

Detect Objects Using YOLO v3

Use the detector for object detection.

% Read the datastore.
data = read(testData);

% Get the image.
I = data{1};

[bboxes,scores,labels] = detect(yolov3Detector,I);

% Display the detections on image.
I = insertObjectAnnotation(I,"rectangle",bboxes,scores);

figure
imshow(I)

Figure: test image with the detected vehicle annotated by its bounding box and detection score.

Supporting Functions

Helper function for performing data augmentation.

function data = augmentData(A)
% Apply random horizontal flipping, and random X/Y scaling. Boxes that get
% scaled outside the bounds are clipped if the overlap is above 0.25. Also,
% jitter image color.

data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    if numel(sz) == 3 && sz(3) == 3
        I = jitterColorHSV(I,...
            "Contrast",0.0,...
            "Hue",0.1,...
            "Saturation",0.2,...
            "Brightness",0.2);
    end
    
    % Randomly flip image.
    tform = randomAffine2d("XReflection",true,"Scale",[1 1.1]);
    rout = affineOutputView(sz,tform,"BoundsStyle","centerOutput");
    I = imwarp(I,tform,"OutputView",rout);
    
    % Apply same transform to boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,"OverlapThreshold",0.25);
    bboxes = round(bboxes);
    labels = labels(indices);
    
    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I,bboxes,labels};
    end
end
end

Helper function for data preprocessing.

function data = preprocessData(data,targetSize)
% Resize the images and scale the pixels to between 0 and 1. Also scale the
% corresponding bounding boxes.

for ii = 1:size(data,1)
    I = data{ii,1};
    imgSize = size(I);
    
    % Convert an input image with single channel to 3 channels.
    if numel(imgSize) < 3 
        I = repmat(I,1,1,3);
    end
    bboxes = data{ii,2};

    I = im2single(imresize(I,targetSize(1:2)));
    scale = targetSize(1:2)./imgSize(1:2);
    bboxes = bboxresize(bboxes,scale);
    
    data(ii, 1:2) = {I, bboxes};
end
end

Helper function for downloading the pretrained YOLO v3 object detector.

function detector = downloadPretrainedYOLOv3Detector()
% Download a pretrained yolov3 detector.
if ~exist("yolov3SqueezeNetVehicleExample_21aSPKG.mat", "file")
    if ~exist("yolov3SqueezeNetVehicleExample_21aSPKG.zip", "file")
        disp("Downloading pretrained detector...");
        pretrainedURL = "https://ssd.mathworks.com/supportfiles/vision/data/yolov3SqueezeNetVehicleExample_21aSPKG.zip";
        websave("yolov3SqueezeNetVehicleExample_21aSPKG.zip", pretrainedURL);
    end
    unzip("yolov3SqueezeNetVehicleExample_21aSPKG.zip");
end
pretrained = load("yolov3SqueezeNetVehicleExample_21aSPKG.mat");
detector = pretrained.detector;
end

References

[1] Redmon, Joseph, and Ali Farhadi. “YOLOv3: An Incremental Improvement.” Preprint, submitted April 8, 2018. https://arxiv.org/abs/1804.02767.
