
Detect Small Objects Using Tiled Training of YOLOX Network

Since R2024b

Detecting small objects in full-resolution images is a current challenge in object detection [1]. An object can be small relative to the overall dimensions of the image, or small in the absolute number of pixels it contains. Using full-resolution images to train a network for small object detection is often computationally intensive, and downsampling training data that contains objects spanning only a few pixels can cause significant information loss during training and detection. To address this issue, you can tile your training images. In this tiled training approach, you split the images in the data set into full-resolution segments, or tiles, at training time, which preserves the full spatial resolution in the trained detector [2].

In this example, you train a YOLOX object detection network using tiled images, and use this trained network to detect small cows in aerial images of agriculture, using one-shot inference at full image resolution. Whereas you typically perform inference at the resolution of the tile, requiring multiple inference calls per image, the one-shot inference approach in this example requires no additional post-processing, enabling you to deploy it directly on hardware using code generation. Full-resolution inference can be fast, even on relatively large input images, primarily constrained by the amount of GPU or CPU memory available for the computation.

Download Pretrained YOLOX Network

By default, this example downloads a pretrained version of the YOLOX object detector [3]. You can use the pretrained network to run the entire example without waiting for training to complete. The fully convolutional architecture of YOLOX enables tiled training and full-sized inference with one-shot small object detection.

unzip("https://ssd.mathworks.com/supportfiles/vision/data/aerialObjectDetectionModel.zip",tempdir);
load(fullfile(tempdir,"AerialObjectDetectionModel","trainedSmallCowDetectorYOLOX.mat"));

Download Aerial Cows Data Set

This example uses the aerial cows data set [4]. The data set contains 1,723 images of agricultural terrain photographed from an aerial perspective, many of which feature cows. The data contains 1,084 training images. Each of the images in the data set is approximately 1530-by-2720-by-3 pixels. The cows in the scene are very small relative to the image size. The median size of the cow bounding boxes in the ground truth data is 30-by-32 pixels.

Specify dataDir as the location of the data set. Download the data set and unzip the contents of the folder into the specified location.

unzip("https://ssd.mathworks.com/supportfiles/vision/data/aerialCowsDataset.zip",tempdir)
dataDir = fullfile(tempdir,"aerialCows");

Detect Small Objects Using Pretrained Network

Read a sample image from the data set.

I = imread("testAerialImage.jpg");

Predict the bounding boxes, confidence scores, and labels for each detected object by using the detect object function. Specify AutoResize as false to perform the inference at full image size.

[boxes,scores,labels] = detect(detector,I,Threshold=0.1,AutoResize=false);

Display the object annotations overlaid on the image using the insertObjectAnnotation function.

figure
imshow(insertObjectAnnotation(I,"rectangle",boxes,labels))

Prepare Data for Training

Partition Data into Training and Validation Sets

Prior to partitioning the data, set the global random state to 0 to ensure reproducibility.

rng(0)

Create an imageDatastore object that holds the training set, from the train folder of the downloaded aerial cow data set. Create a datastore, trainDS, that associates images with the corresponding ground truth box labels for cows in the training set.

labelMapping = load(fullfile(dataDir,"aerialCowLabels.mat"));
imds = imageDatastore(fullfile(dataDir,"train",labelMapping.trainTbl.filename));
blds = boxLabelDatastore(labelMapping.trainTbl(:,2:end));
trainDS = transform(imds,blds,@(img,lbl) {img lbl{1} lbl{2}});

Create an imageDatastore object that holds the validation set, from the valid folder of the downloaded aerial cow data set. Create a datastore, valDS, that associates images with the corresponding ground truth box labels for cows in the validation set.

imdsVal = imageDatastore(fullfile(dataDir,"valid",labelMapping.valTbl.filename));
bldsVal = boxLabelDatastore(labelMapping.valTbl(:,2:end));
valDS = transform(imdsVal,bldsVal,@(img,lbl) {img lbl{1} lbl{2}});

Determine Object Size Distribution

Display a scatter plot of the object sizes to determine their distribution. The median width and height of the boxes are 30 and 32 pixels, respectively. The minimum width and height of the boxes are 1 and 4 pixels, respectively.

trainingBoxes = readall(blds);
overallBoxVector = vertcat(trainingBoxes{:,1});
figure
plot(overallBoxVector(:,3),overallBoxVector(:,4),"x")
xlabel("Bounding Box Widths")
ylabel("Bounding Box Heights")
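The summary statistics quoted above can be computed directly from an [x y width height] box matrix. The following is a minimal sketch that uses a small hypothetical matrix, sampleBoxes, in place of overallBoxVector. The values are illustrative only and are not taken from the data set.

```matlab
% Hypothetical [x y width height] boxes standing in for overallBoxVector.
sampleBoxes = [ 10  10  30  32
                 5   5   1   4
               100  50  45  40];

% Columns 3 and 4 hold the box widths and heights.
medianWidthHeight = median(sampleBoxes(:,3:4),1);
minWidthHeight = min(sampleBoxes(:,3:4),[],1);
```

Running the same two lines on overallBoxVector should report the statistics for the actual training boxes.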

Split Training and Validation Data into Tiles

Create a blockedImage object from the images in the training set. The blockedImage object represents a very large image as a collection of smaller blocks.

bim = blockedImage(imds.Files);
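To get a feel for how many tiles one image produces, the tile grid can be sketched with plain arithmetic. This sketch assumes that tiles start at pixel (1,1), step by a fixed offset, and that incomplete tiles at the image edges are excluded. It uses the approximate image size and the 512-pixel tile size and 256-pixel offset from this example, and it is not the toolbox implementation of block selection.

```matlab
imageSize = [1530 2720];     % approximate [rows cols] of one aerial image
blockSize = [512 512];       % tile size, in [rows cols]
blockOffsets = [256 256];    % step between tile origins (50% overlap)

% Keep only origins for which the whole tile fits inside the image.
rowStarts = 1:blockOffsets(1):(imageSize(1)-blockSize(1)+1);
colStarts = 1:blockOffsets(2):(imageSize(2)-blockSize(2)+1);

% Enumerate every tile origin as an [x y] pair.
[colGrid,rowGrid] = meshgrid(colStarts,rowStarts);
tileOriginsXY = [colGrid(:) rowGrid(:)];
numTiles = size(tileOriginsXY,1);
```

Under these assumptions, each image yields a 4-by-9 grid of complete tiles.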

Select image tiles of size [512 512 3] from the data set, with a 50% overlap between adjacent tiles, by using the selectBlockLocationsUsingBoxes helper function. Some tiles contain box labels that are partially cut off at the tile edges. The function selects tiles that contain at least one box for which the ratio of the box area within the tile bounds to the total box area is at least 50%. Within each selected tile, the function keeps only the boxes that meet this threshold and crops them to the tile bounds.

[bls,boxLabelsOut] = selectBlockLocationsUsingBoxes(bim,readall(blds),BlockSize=[512 512 3],BlockOffsets=[256 256],OverlapThreshold=0.5);
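The overlap test for boxes that straddle a tile edge can be illustrated with the rectint function, which the helper function also uses. This is a minimal sketch with one tile and two hypothetical cow boxes near its right edge. The box coordinates are made up for illustration.

```matlab
tileBox = [1 1 512 512];          % tile as [x y width height]
cowBoxes = [490 100 30 32         % hypothetical box mostly inside the tile
            500 100 30 32];       % hypothetical box mostly outside the tile

% Area of each box that falls inside the tile (1-by-2 result).
intersectionArea = rectint(tileBox,cowBoxes);

% Fraction of each box that the tile contains.
boxArea = (cowBoxes(:,3).*cowBoxes(:,4))';
overlapRatio = intersectionArea./boxArea;

% Keep only the boxes that meet the 50% overlap threshold.
keepBox = overlapRatio >= 0.5;
```

Here the tile contains about 77% of the first box and about 43% of the second, so only the first box is kept for this tile.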

To ensure that the detector encounters a variety of background content, including features such as mountains or clouds that might be spatially far from where cows appear in the scene, select 30% as many background tiles as tiles that contain box labels.

[blsBackground,boxLabelsOutBackground] = selectBlockLocationsUsingBoxes(bim,readall(blds),BlockSize=[512 512 3],BlockOffsets=[512 512],SelectBackgroundTiles=true);
indicesToSelect = randperm(size(boxLabelsOutBackground,1));
indicesToSelect = indicesToSelect(1:round(0.3*size(boxLabelsOut,1)));
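The background sampling step can be tried in isolation on made-up tile counts. In this sketch, numLabeledTiles and numBackgroundTiles are hypothetical values, and the min guard is an addition that is not part of the example code. The guard avoids an indexing error if fewer background tiles are available than the 30% target.

```matlab
rng(0)                          % seed the generator so the selection is repeatable
numLabeledTiles = 200;          % hypothetical number of tiles that contain boxes
numBackgroundTiles = 500;       % hypothetical number of candidate background tiles

% Target 30% as many background tiles as labeled tiles, capped by availability.
numToSelect = min(round(0.3*numLabeledTiles),numBackgroundTiles);

% Shuffle the candidate indices and keep the first numToSelect of them.
shuffledIndices = randperm(numBackgroundTiles);
selectedIndices = shuffledIndices(1:numToSelect);
```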

Create a blockLocationSet object that contains the selected blocks and their locations.

blsOverall = blockLocationSet(vertcat(bls.ImageNumber,blsBackground.ImageNumber(indicesToSelect)), ...
    vertcat(bls.BlockOrigin,blsBackground.BlockOrigin(indicesToSelect,:)), ...
    [512 512 3]);

Create a blockedImageDatastore object for the tiled training images. Specify the blocks to include in the datastore as the new blockLocationSet object blsOverall.

bimds = blockedImageDatastore(bim,BlockLocationSet=blsOverall);

Create a boxLabelDatastore object for the bounding box label data corresponding to the tiled training images.

bldsForBlocks = boxLabelDatastore(vertcat(boxLabelsOut,boxLabelsOutBackground(indicesToSelect,:)));
trainingSetBlocked = transform(bimds,bldsForBlocks,@(im,boxes) {im{1} boxes{1} boxes{2}});

Display a sample image tile from the tiled training set.

x = preview(trainingSetBlocked);
figure
imshow(insertObjectAnnotation(x{1},"rectangle",x{2},x{3}))

Repeat the process of creating a tiled datastore of images and labels for the validation set.

bimVal = blockedImage(imdsVal.Files);
[blsVal,boxLabelsOutVal] = selectBlockLocationsUsingBoxes(bimVal,readall(bldsVal),BlockSize=[512 512],BlockOffsets=[256 256]);
bimdsVal = blockedImageDatastore(bimVal,BlockLocationSet=blsVal);
bldsForBlocksVal = boxLabelDatastore(boxLabelsOutVal);
valSetBlocked = transform(bimdsVal,bldsForBlocksVal,@(im,boxes) {im{1} boxes{1} boxes{2}});

Display a sample image tile from the tiled validation set.

x = preview(valSetBlocked);
figure
imshow(insertObjectAnnotation(x{1},"rectangle",x{2},x{3}))

Write Data to Disk to Improve Training Speed

To increase training speed, create an on-disk representation of the blocked image data. Write each tile with its corresponding box labels to disk so that training can read each tile directly, rather than extracting it from its full-resolution source image on every read. Before you write the tiled images to disk, ensure that you have 8.6 GB of additional disk space available. You can perform training without writing the images to disk, at the expense of significantly longer training time.

Write the image and corresponding box label information for each tile to disk. Create separate MAT files for the training and validation sets by using the writeAsMAT helper function.

outputLocationTrain = fullfile(dataDir,"trainingSetTiled");
writeAsMAT(trainingSetBlocked,outputLocationTrain)

outputLocationVal = fullfile(dataDir,"validationSetTiled");
writeAsMAT(valSetBlocked,outputLocationVal)

Configure Reading of Tiled Data

Read the training and validation tiled data set from disk at training time by using fileDatastore objects. The datastores read each MAT file, and return the labeled object detection data.

dsTrain = fileDatastore(outputLocationTrain,ReadFcn=@load);
dsTrain = transform(dsTrain,@(x) x.imageBoxesLabelsCell);

dsVal = fileDatastore(outputLocationVal,ReadFcn=@load);
dsVal = transform(dsVal,@(x) x.imageBoxesLabelsCell);

subsetIndices = randperm(numpartitions(dsVal));
dsVal = subset(dsVal,subsetIndices(1:round(0.5*numpartitions(dsVal)))); % Half of the data for speed
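The round trip between the MAT files and the datastores can be checked on a single toy sample. This is a minimal sketch that assumes the same variable name, imageBoxesLabelsCell, that the writeAsMAT helper function saves. The 4-by-4 image, the box, and the temporary folder are placeholders.

```matlab
% Write one toy {image, boxes, labels} sample to a temporary folder.
outDir = tempname;
mkdir(outDir)
imageBoxesLabelsCell = {zeros(4,4,3,"uint8"),[1 1 2 2],categorical("cow")};
save(fullfile(outDir,"1.mat"),"imageBoxesLabelsCell")

% Read the sample back the same way the training datastore does.
ds = fileDatastore(outDir,ReadFcn=@load);
ds = transform(ds,@(x) x.imageBoxesLabelsCell);
sample = read(ds);

% Clean up the temporary folder.
rmdir(outDir,"s")
```

Each read returns a 1-by-3 cell array that contains the image, the boxes, and the labels.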

Augment Training Data

Augment the training data by using the augmentData helper function, which randomly applies horizontal and vertical reflection and scales the input data by a factor in the range [1, 1.1]. To ensure the learned filters in the object detector generalize well to full-sized inference, use augmentation that does not require fill values to define out-of-bounds points when resampling.

dsTrain = transform(dsTrain,@augmentData);

Define YOLOX Object Detector Network Architecture

Create the YOLOX object detector by using the yoloxObjectDetector function. Specify the pretrained network as "small-coco", which uses CSP-DarkNet-53 as the base network and is trained on the COCO data set. Specify the class name and a network input size equal to the tile size used in the tiled training set to ensure that the network does not perform any resizing during training.

detector = yoloxObjectDetector("small-coco","cow",InputSize=[512 512 3]);

Specify Training Options

Specify network training options using the trainingOptions (Deep Learning Toolbox) function. Specify the Metrics argument using the mAPObjectDetectionMetric function to monitor the mean average precision (mAP) of the detector during training. Specify the OutputNetwork name-value argument as "best-validation" and the ObjectiveMetricName name-value argument as "mAP50" to return the trained network corresponding to the training iteration with the maximal mAP value.

options = trainingOptions("sgdm", ...
    InitialLearnRate=5e-4, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.98, ...
    LearnRateDropPeriod=1, ...
    MiniBatchSize=50, ...
    MaxEpochs=20, ...
    Shuffle="every-epoch", ...
    VerboseFrequency=1, ...
    ValidationFrequency=200, ...
    ValidationData=dsVal, ...
    ValidationPatience=5, ...
    OutputNetwork="best-validation", ...
    Metrics=mAPObjectDetectionMetric(Name="mAP50"), ...
    L2Regularization=5e-4, ...
    ObjectiveMetricName="mAP50");
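The effect of the piecewise schedule on the learning rate can be worked out by hand. This sketch assumes that the schedule multiplies the learning rate by the drop factor once every LearnRateDropPeriod epochs, and it reuses the values from the options above. It does not call trainingOptions.

```matlab
initialLearnRate = 5e-4;
dropFactor = 0.98;
dropPeriod = 1;      % epochs between drops
maxEpochs = 20;

% Learning rate in effect during each epoch.
epochs = 1:maxEpochs;
learnRate = initialLearnRate*dropFactor.^floor((epochs-1)/dropPeriod);

% Fraction of the initial learning rate that remains in the final epoch.
finalFraction = learnRate(end)/initialLearnRate;
```

With these settings the decay is gentle: the learning rate in the final epoch is still roughly two-thirds of its initial value.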

Train Detector

To train the detector, set doTraining to true. Train the detector by using the trainYOLOXObjectDetector function.

Train on one or more GPUs, if they are available. Using a GPU requires a Parallel Computing Toolbox™ license and a CUDA® enabled NVIDIA® GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox). Training takes about 1 hour on an NVIDIA Titan RTX™ with 24 GB of memory.

doTraining = false;

if doTraining
    detector = trainYOLOXObjectDetector(dsTrain,detector,options);
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(tempdir,"trainedCowDetectorYOLOX"+modelDateTime+".mat"), ...
        "detector");
else
    load(fullfile(tempdir,"AerialObjectDetectionModel","trainedSmallCowDetectorYOLOX.mat"));
end

Evaluate Detector

Create an imageDatastore of full-resolution images from the test set.

imdsTest = imageDatastore(fullfile(dataDir,"test",labelMapping.testTbl.filename));
bldsTest = boxLabelDatastore(labelMapping.testTbl(:,2:end));
dsTest = transform(imdsTest,bldsTest,@(im,lbl) {im,lbl{1},lbl{2}});

Ensure that the test set contains only images with objects (cows).

testIndicesWithNonemptyGT = cellfun(@(c) ~isempty(c),labelMapping.testTbl.labels);
dsTest = subset(dsTest,testIndicesWithNonemptyGT);

Compute Performance Metrics Across Data Set

Detect the bounding boxes for all test images using the detect object function. Specify the AutoResize name-value argument as false to keep the images at full resolution at inference.

detectorResults = detect(detector,dsTest,AutoResize=false,MiniBatchSize=1,Threshold=1e-4);

Compute object detection metrics on the test set detection results by using the evaluateObjectDetection function.

metrics = evaluateObjectDetection(detectorResults,dsTest);
metrics.DatasetMetrics
ans=1×3 table
    NumObjects    mAPOverlapAvg       mAP    
    __________    _____________    __________

       3181          0.85986       {[0.8599]}

Evaluate Detector Performance Across Object Size Ranges

Define the bounding box area ranges, and evaluate object detection metrics for the defined area ranges using the metricsByArea object function. Display the object area range, the number of objects in that size range, and the average precision (AP) averaged over all overlap thresholds.

objectLabels = readall(bldsTest);
boxes = objectLabels(:,1);
boxes = vertcat(boxes{:});
boxArea = prod(boxes(:,3:4),2);
boxPrctileBoundaries = prctile(boxArea,100*(0.1:0.1:0.9));
metricsByAreaSummary = metricsByArea(metrics,[0 boxPrctileBoundaries inf]);
disp(metricsByAreaSummary(:,1:3))
       AreaRange        NumObjects    APOverlapAvg
    ________________    __________    ____________

         0        99       314          0.25543   
        99       156       304          0.74481   
       156       252       328          0.84112   
       252     399.9       326          0.88992   
     399.9       616       315          0.86841   
       616       888       321          0.90227   
       888      1225       317          0.94585   
      1225    1667.8       320          0.96497   
    1667.8      2448       316          0.97864   
      2448       Inf       320          0.97099   

The detector performance, measured using the AP averaged over all overlap thresholds, generally increases with the size of the cow objects in the scene. There is a considerable performance decrease at the smallest object sizes of less than 100 pixels in area, or 0.0024% of the area of the full-resolution image. The detector performs adequately (AP > 0.9) for objects that have an area of only about 0.015% of the full-resolution image. To improve the detection of the smallest objects, you can customize the overlap tile strategy to ensure that small objects near tile borders are more likely to be fully contained within at least one tile, further augment the data by zooming into or scaling up images or using brightness and contrast adjustment, or train your model on tiles at multiple resolutions to help the model learn to detect objects across different scales. You can also further train your model on images or tiles where detection is poor, focusing on improving the detection of extremely small objects.
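The area percentages follow from the approximate 1530-by-2720 image size. As a short worked check, using the 100-pixel and 616-pixel area boundaries from the table above:

```matlab
imageArea = 1530*2720;             % approximate image area, in pixels
areaBoundaries = [100 616];        % object areas, in pixels, from the table

% Object area as a percentage of the full-resolution image area.
percentOfImage = 100*areaBoundaries/imageArea;
```

The two results are approximately 0.0024% and 0.015%, respectively.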

References

[1] Unel, F. Ozge, Burak O. Ozkalayci, and Cevahir Cigla. “The Power of Tiling for Small Object Detection.” In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 582–91. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPRW.2019.00084.

[2] Akyon, Fatih Cagatay, Sinan Onur Altinuc, and Alptekin Temizel. “Slicing Aided Hyper Inference and Fine-Tuning for Small Object Detection.” In 2022 IEEE International Conference on Image Processing (ICIP), 966–70. Bordeaux, France: IEEE, 2022. https://doi.org/10.1109/ICIP46576.2022.9897990.

[3] Ge, Zheng, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun. "YOLOX: Exceeding YOLO Series in 2021", arXiv, August 6, 2021. https://arxiv.org/abs/2107.08430.

[4] Aerial Cows Data Set. Accessed April 4, 2024. https://universe.roboflow.com/roboflow-100/aerial-cows.

Supporting Functions

writeAsMAT

function writeAsMAT(ds,outputLocation)
reset(ds)
count = 1;

% Ensure a clean start each time.
if exist(outputLocation,"dir")
    rmdir(outputLocation,"s");
end

mkdir(outputLocation)
while hasdata(ds)
    imageBoxesLabelsCell = read(ds);
    save(outputLocation+filesep+count,"imageBoxesLabelsCell")
    count = count + 1;
end

end

selectBlockLocationsUsingBoxes

function [blsOverall,boxes] = selectBlockLocationsUsingBoxes(bim,bboxesIn,NV)

arguments
    bim
    bboxesIn
    NV.BlockSize = bim(1).BlockSize
    NV.BlockOffsets = bim(1).BlockSize
    NV.ExcludeIncompleteBlocks = true
    NV.SelectBackgroundTiles = false
    NV.OverlapThreshold = 1.0;
end

overlapThreshold = NV.OverlapThreshold;

boxes = {};
labelsOut = {};
blockOriginsOverall = [];
imageIndexOverall = [];
cropRect = images.spatialref.Rectangle([0.5 NV.BlockSize(2)+0.5],[0.5 NV.BlockSize(1)+0.5]);

for idx = 1:length(bim)
    blsAll = selectBlockLocations(bim(idx),BlockSize=NV.BlockSize,BlockOffsets=NV.BlockOffsets,ExcludeIncompleteBlocks=NV.ExcludeIncompleteBlocks);
    bboxes = bboxesIn{idx,1};
    labels = bboxesIn{idx,2};

    % Convert each block into a bounding box (x,y,w,h)
    blockBboxes = blsAll.BlockOrigin(:,1:2); % Already in XY
    blockBboxes(:,3) = blsAll.BlockSize(2);  % BlockSize is in R/C
    blockBboxes(:,4) = blsAll.BlockSize(1);

    if ~isempty(bboxes)

        % Determine the pairwise intersection of blocks and boxes in the
        % full-resolution image.
        pairwiseIntersectionBlocksByBoxes = rectint(blockBboxes,bboxes);

        % Normalize the intersection of each box with the block
        % rectangle to determine which box rectangles partially overlap at
        % the block edges.
        areaOfEachBoxInScene = bboxes(:,3).*bboxes(:,4);
        overlapRatioBlocksByBoxes = pairwiseIntersectionBlocksByBoxes ./ areaOfEachBoxInScene';

        if ~NV.SelectBackgroundTiles
            % A valid block contains at least one box whose overlap
            % ratio with the block is >= threshold.
            selectBlockMask = any(overlapRatioBlocksByBoxes >= overlapThreshold,2);
        else
            % Blocks that do not intersect any of the ground truth boxes.
            selectBlockMask = ~any(pairwiseIntersectionBlocksByBoxes,2);
        end

        % Build a numValidBlocks-by-numBoxes logical matrix tracking the set of valid
        % boxes within each block
        validBoxSetInEachValidBlock = overlapRatioBlocksByBoxes(selectBlockMask,:) >= overlapThreshold;

        % Create new bls with blocks that intersect at least one object.
        blockOrigins = blsAll.BlockOrigin(selectBlockMask,:);
        imgIndex = idx*ones([size(blockOrigins,1) 1]);

        boxInds = cell([size(validBoxSetInEachValidBlock,1) 1]);
        for validBlockInd = 1:length(boxInds)
            boxInds{validBlockInd} = find(validBoxSetInEachValidBlock(validBlockInd,:));
        end

        boxesInThisImage = cellfun(@(c) bboxes(c,:),boxInds,UniformOutput=false);
        labelsInThisImage = cellfun(@(c) labels(c,:),boxInds,UniformOutput=false);

        boxesInThisImage = adjustBoxLocationsBasedOnBlockOrigin(boxesInThisImage,blockOrigins(:,1:2));

        % Since partial overlap of boxes with a block is present,
        % trim portions of boxes that lie outside the bounds.
        if (overlapThreshold < 1) && ~NV.SelectBackgroundTiles
            % Set the crop threshold to eps because boxes with overlap
            % smaller than overlapThreshold have already been discarded.
            boxesInThisImage = cellfun(@(c) bboxcrop(c,cropRect,OverlapThreshold=eps),boxesInThisImage,UniformOutput=false);
        end

        boxes = vertcat(boxes,boxesInThisImage);
        labelsOut = vertcat(labelsOut,labelsInThisImage);

        blockOriginsOverall = vertcat(blockOriginsOverall,blockOrigins);
        imageIndexOverall = vertcat(imageIndexOverall,imgIndex);

    end

end

% Include the boxes and labels together in boxLabelDatastore table form.
boxes = table(boxes,labelsOut);

blsOverall = blockLocationSet(imageIndexOverall,blockOriginsOverall,blsAll.BlockSize);

end

adjustBoxLocationsBasedOnBlockOrigin

function boxesNew = adjustBoxLocationsBasedOnBlockOrigin(boxes,blockUpperLeft)
boxesNew = boxes;
    for idx = 1:length(boxesNew)
        shift = blockUpperLeft(idx,:) - 1;
        newBoxes = boxes{idx} - [shift 0 0];
        boxesNew{idx} = newBoxes;
    end
end
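The coordinate shift that this function applies can be seen on a small example. This sketch uses a hypothetical tile origin and two hypothetical boxes, and it repeats the same arithmetic inline instead of calling the function.

```matlab
blockOriginXY = [513 257];            % hypothetical tile origin, as [x y]
boxesInImage = [600 300 30 32         % hypothetical boxes in full-image
                520 260 28 30];       % coordinates, as [x y width height]

% A tile origin of (1,1) means no shift, so subtract one from the origin.
shift = blockOriginXY - 1;

% Shift x and y only. Width and height do not change.
boxesInTile = boxesInImage - [shift 0 0];
```

The first box moves from (600,300) in the image to (88,44) in the tile, and its size stays 30-by-32.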

augmentData

The augmentData function randomly applies horizontal and vertical flipping and scaling to pairs of images and bounding boxes. The function clips boxes outside the bounds if the overlap is above 0.25.

function data = augmentData(A)
data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    % Randomly flip and scale the image.
    tform = randomAffine2d(XReflection=true,YReflection=true,Scale=[1.0 1.1]);
    rout = affineOutputView(sz,tform,BoundsStyle="centerOutput");
    I = imwarp(I,tform,OutputView=rout);

    % Apply the same transform to bounding boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,OverlapThreshold=0.25);
    labels = labels(indices);

    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I bboxes labels};
    end
end
end
