Detect Small Objects Using Tiled Training of YOLOX Network
Detecting small objects in full-resolution images is a common challenge in object detection [1]. An object can be small relative to the overall dimensions of the image, or small in absolute terms, containing only a few pixels. Training a network for small object detection on full-resolution images is often computationally intensive, while downsampling training data that contains objects spanning only a few pixels can cause significant information loss during training and detection. To address this issue, you can tile your training images. In this tiled training approach, you split the images in the data set into full-resolution segments, or tiles, at training time, which preserves the full spatial resolution in the trained detector [2].
In this example, you train a YOLOX object detection network using tiled images, and then use the trained network to detect small cows in aerial images of agricultural terrain using one-shot inference at full image resolution. Whereas tiled inference typically runs at the resolution of the tile and requires multiple inference calls per image, the one-shot inference approach in this example requires no additional postprocessing, so you can deploy it directly on hardware using code generation. Full-resolution inference can be fast, even on relatively large input images, and is constrained primarily by the amount of GPU or CPU memory available for the computation.
Download Pretrained YOLOX Network
By default, this example downloads a pretrained version of the YOLOX object detector [3]. You can use the pretrained network to run the entire example without waiting for training to complete. The fully convolutional architecture of YOLOX enables tiled training and full-sized inference with one-shot small object detection.
unzip("https://ssd.mathworks.com/supportfiles/vision/data/aerialObjectDetectionModel.zip",tempdir); load(fullfile(tempdir,"AerialObjectDetectionModel","trainedSmallCowDetectorYOLOX.mat"));
Download Aerial Cows Data Set
This example uses the aerial cows data set [4]. The data set contains 1,723 images of agricultural terrain photographed from an aerial perspective, many of which feature cows. The data contains 1,084 training images. Each of the images in the data set is approximately 1530-by-2720-by-3 pixels. The cows in the scene are very small relative to the image size. The median size of the cow bounding boxes in the ground truth data is 30-by-32 pixels.
Specify dataDir as the location of the data set. Download the data set and unzip the contents of the folder into the specified location.
unzip("https://ssd.mathworks.com/supportfiles/vision/data/aerialCowsDataset.zip",tempdir) dataDir = fullfile(tempdir,"aerialCows");
Detect Small Objects Using Pretrained Network
Read a sample image from the data set.
I = imread("testAerialImage.jpg");
Predict the bounding boxes, labels, and confidence scores for each object instance using the detect object function. Specify AutoResize as false to perform the inference at the full image size.
[boxes,scores,labels] = detect(detector,I,Threshold=0.1,AutoResize=false);
Display the object annotations overlaid on the image using the insertObjectAnnotation function.
figure
imshow(insertObjectAnnotation(I,"rectangle",boxes,labels))
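The detection threshold of 0.1 admits low-confidence detections. As an optional check, you can keep only the higher-confidence detections before display. This is a minimal sketch; the 0.5 cutoff is an illustrative value, not part of the original workflow.
% Keep only detections with confidence scores above an illustrative cutoff.
keep = scores > 0.5;
figure
imshow(insertObjectAnnotation(I,"rectangle",boxes(keep,:),labels(keep)))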
Prepare Data for Training
Partition Data into Training and Validation Sets
Prior to partitioning the data, set the global random state to 0 to ensure reproducibility.
rng(0)
Create an imageDatastore object that holds the training set, from the train folder of the downloaded aerial cows data set. Create a datastore, trainDS, that associates images with the corresponding ground truth box labels for cows in the training set.
labelMapping = load(fullfile(dataDir,"aerialCowLabels.mat"));
imds = imageDatastore(fullfile(dataDir,"train",labelMapping.trainTbl.filename));
blds = boxLabelDatastore(labelMapping.trainTbl(:,2:end));
trainDS = transform(imds,blds,@(img,lbl) {img lbl{1} lbl{2}});
Create an imageDatastore object that holds the validation set, from the valid folder of the downloaded aerial cows data set. Create a datastore, valDS, that associates images with the corresponding ground truth box labels for cows in the validation set.
imdsVal = imageDatastore(fullfile(dataDir,"valid",labelMapping.valTbl.filename));
bldsVal = boxLabelDatastore(labelMapping.valTbl(:,2:end));
valDS = transform(imdsVal,bldsVal,@(img,lbl) {img lbl{1} lbl{2}});
Determine Object Size Distribution
Display a scatter plot of the object sizes to determine their distribution. The median width and height of the boxes are 30 and 32 pixels, respectively. The minimum width and height of the boxes are 1 and 4 pixels, respectively.
trainingBoxes = readall(blds);
overallBoxVector = vertcat(trainingBoxes{:,1});
figure
plot(overallBoxVector(:,3),overallBoxVector(:,4),"x")
xlabel("Bounding Box Widths")
ylabel("Bounding Box Heights")
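To confirm the summary statistics quoted above, you can compute the median and minimum box dimensions directly from the concatenated box list. A minimal sketch:
% Widths and heights are the third and fourth box coordinates in [x y w h] form.
medianBoxSize = median(overallBoxVector(:,3:4)) % Approximately [30 32]
minBoxSize = min(overallBoxVector(:,3:4))       % Approximately [1 4]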
Split Training and Validation Data into Tiles
Create a blockedImage object from the images in the training set. The blockedImage object represents a very large image as a collection of smaller blocks.
bim = blockedImage(imds.Files);
Select image tiles of size [512 512 3] from the data set with a 50% overlap ratio using the selectBlockLocationsUsingBoxes helper function. Some tiles contain box labels that are partially cut off at the tile edges. For such partially contained box labels, the function selects tiles that contain at least one box for which the ratio of the box area within the tile bounds to the total box area is at least 50%, and it keeps only the boxes that meet this threshold.
[bls,boxLabelsOut] = selectBlockLocationsUsingBoxes(bim,readall(blds),BlockSize=[512 512 3],BlockOffsets=[256 256],OverlapThreshold=0.5);
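To see how the 50% criterion behaves for a single box that straddles a tile edge, you can compute the overlap ratio directly. This sketch uses hypothetical tile and box coordinates for illustration:
% Hypothetical tile and box in [x y w h] form. The box extends 16 pixels
% past the right edge of the tile, so 75% of its area lies inside.
tileRect = [1 1 512 512];
boxRect = [465 100 64 30];
overlapRatio = rectint(tileRect,boxRect)/prod(boxRect(3:4)) % 0.75 >= 0.5, so the tile is kept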
To ensure that the detector encounters a variety of background content, including features such as mountains or clouds that might be spatially far from where cows appear in the scene, select 30% as many background tiles as tiles that contain box labels.
[blsBackground,boxLabelsOutBackground] = selectBlockLocationsUsingBoxes(bim,readall(blds),BlockSize=[512 512 3],BlockOffsets=[512 512],SelectBackgroundTiles=true);
indicesToSelect = randperm(size(boxLabelsOutBackground,1));
indicesToSelect = indicesToSelect(1:round(0.3*size(boxLabelsOut,1)));
Create a blockLocationSet object that contains the specified blocks and their locations.
blsOverall = blockLocationSet(vertcat(bls.ImageNumber,blsBackground.ImageNumber(indicesToSelect)), ...
    vertcat(bls.BlockOrigin,blsBackground.BlockOrigin(indicesToSelect,:)), ...
    [512 512 3]);
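As a quick check of the tile mix, you can count the object tiles and background tiles in the combined block location set. A minimal sketch:
numObjectTiles = size(bls.BlockOrigin,1);
numBackgroundTiles = numel(indicesToSelect);
fprintf("Object tiles: %d, background tiles: %d (%.0f%% of object tiles)\n", ...
    numObjectTiles,numBackgroundTiles,100*numBackgroundTiles/numObjectTiles)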
Create a blockedImageDatastore object for the tiled training images. Specify the blocks to include in the datastore as the new blockLocationSet object, blsOverall.
bimds = blockedImageDatastore(bim,BlockLocationSet=blsOverall);
Create a boxLabelDatastore object for the bounding box label data corresponding to the tiled training images.
bldsForBlocks = boxLabelDatastore(vertcat(boxLabelsOut,boxLabelsOutBackground(indicesToSelect,:)));
trainingSetBlocked = transform(bimds,bldsForBlocks,@(im,boxes) {im{1} boxes{1} boxes{2}});
Display a sample image tile from the tiled training set.
x = preview(trainingSetBlocked);
figure
imshow(insertObjectAnnotation(x{1},"rectangle",x{2},x{3}))
Repeat the process of creating a tiled datastore of images and labels for the validation set.
bimVal = blockedImage(imdsVal.Files);
[blsVal,boxLabelsOutVal] = selectBlockLocationsUsingBoxes(bimVal,readall(bldsVal),BlockSize=[512 512],BlockOffsets=[256 256]);
bimdsVal = blockedImageDatastore(bimVal,BlockLocationSet=blsVal);
bldsForBlocksVal = boxLabelDatastore(boxLabelsOutVal);
valSetBlocked = transform(bimdsVal,bldsForBlocksVal,@(im,boxes) {im{1} boxes{1} boxes{2}});
Display a sample image tile from the tiled validation set.
x = preview(valSetBlocked);
figure
imshow(insertObjectAnnotation(x{1},"rectangle",x{2},x{3}))
Write Data to Disk to Improve Training Speed
To increase training speed, create an on-disk representation of the blocked image data. Rather than cropping each tile out of the full-resolution images at every epoch, write each tile with its corresponding box labels to disk so that you can read the tiled training set into memory as efficiently as possible. Before writing the large set of tiled images to disk, ensure that you have 8.6 GB of additional disk space available. You can perform training without writing the images to disk at the expense of significantly longer training time.
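To estimate the required space before writing, you can approximate the uncompressed footprint of the training tiles. This sketch counts only the uint8 image payload and ignores MAT-file compression and label overhead, so treat the result as a rough estimate:
% Approximate size of one uint8 RGB tile times the number of training tiles.
numTiles = numel(blsOverall.ImageNumber);
approxGB = numTiles*512*512*3/1e9;
fprintf("Approximate uncompressed tile payload: %.1f GB\n",approxGB)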
Write the image and corresponding box label information for each tile to disk. Create separate MAT files for the training and validation sets by using the writeAsMAT helper function.
outputLocationTrain = fullfile(dataDir,"trainingSetTiled");
writeAsMAT(trainingSetBlocked,outputLocationTrain)
outputLocationVal = fullfile(dataDir,"validationSetTiled");
writeAsMAT(valSetBlocked,outputLocationVal)
Configure Reading of Tiled Data
Read the training and validation tiled data sets from disk at training time by using fileDatastore objects. The datastores read each MAT file and return the labeled object detection data.
dsTrain = fileDatastore(outputLocationTrain,ReadFcn=@load);
dsTrain = transform(dsTrain,@(x) x.imageBoxesLabelsCell);
dsVal = fileDatastore(outputLocationVal,ReadFcn=@load);
dsVal = transform(dsVal,@(x) x.imageBoxesLabelsCell);
subsetIndices = randperm(numpartitions(dsVal));
dsVal = subset(dsVal,subsetIndices(1:round(0.5*numpartitions(dsVal)))); % Use half of the validation data for speed
Augment Training Data
Augment the training data by using the augmentData helper function, which applies random horizontal and vertical reflections and resizes the input data by a scale factor in the range [1, 1.1]. To ensure that the learned filters in the object detector generalize well to full-sized inference, use augmentations that do not require fill values to define out-of-bounds points when resampling.
dsTrain = transform(dsTrain,@augmentData);
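To verify that the boxes stay aligned with the objects after augmentation, you can preview one augmented sample and display its annotations. A minimal sketch:
augSample = preview(dsTrain);
figure
imshow(insertObjectAnnotation(augSample{1},"rectangle",augSample{2},augSample{3}))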
Define YOLOX Object Detector Network Architecture
Create the YOLOX object detector by using the yoloxObjectDetector function. Specify the pretrained network as "small-coco", which uses CSP-DarkNet-53 as the base network and is trained on the COCO data set. Specify the class name and a network input size equal to the tile size used in the tiled training set to ensure that the network does not perform any resizing during training.
detector = yoloxObjectDetector("small-coco","cow",InputSize=[512 512 3]);
Specify Training Options
Specify network training options using the trainingOptions (Deep Learning Toolbox) function. Specify the Metrics argument using the mAPObjectDetectionMetric function to monitor the mean average precision (mAP) of the detector during training. Specify the OutputNetwork name-value argument as "best-validation" and the ObjectiveMetricName name-value argument as "mAP50" to return the trained network corresponding to the training iteration with the maximal mAP value.
options = trainingOptions("sgdm", ...
    InitialLearnRate=5e-4, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.98, ...
    LearnRateDropPeriod=1, ...
    MiniBatchSize=50, ...
    MaxEpochs=20, ...
    Shuffle="every-epoch", ...
    VerboseFrequency=1, ...
    ValidationFrequency=200, ...
    ValidationData=dsVal, ...
    ValidationPatience=5, ...
    OutputNetwork="best-validation", ...
    Metrics=mAPObjectDetectionMetric(Name="mAP50"), ...
    L2Regularization=5e-4, ...
    ObjectiveMetricName="mAP50");
Train Detector
To train the detector, set doTraining to true. Train the detector by using the trainYOLOXObjectDetector function.
Train on one or more GPUs if they are available. Using a GPU requires a Parallel Computing Toolbox™ license and a CUDA®-enabled NVIDIA® GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox). Training takes about 1 hour on an NVIDIA Titan RTX™ with 24 GB of memory.
doTraining = false;
if doTraining
    detector = trainYOLOXObjectDetector(dsTrain,detector,options);
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(tempdir,"trainedCowDetectorYOLOX"+modelDateTime+".mat"), ...
        "detector");
else
    load(fullfile(tempdir,"AerialObjectDetectionModel","trainedSmallCowDetectorYOLOX.mat"));
end
Evaluate Detector
Create an imageDatastore of full-resolution images from the test set.
imdsTest = imageDatastore(fullfile(dataDir,"test",labelMapping.testTbl.filename));
bldsTest = boxLabelDatastore(labelMapping.testTbl(:,2:end));
dsTest = transform(imdsTest,bldsTest,@(im,lbl) {im,lbl{1},lbl{2}});
Ensure that the test set contains only images with objects (cows).
testIndicesWithNonemptyGT = cellfun(@(c) ~isempty(c),labelMapping.testTbl.labels);
dsTest = subset(dsTest,testIndicesWithNonemptyGT);
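As a check, you can report how many test images remain after removing those with empty ground truth:
fprintf("Test images with at least one cow: %d of %d\n", ...
    nnz(testIndicesWithNonemptyGT),numel(testIndicesWithNonemptyGT))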
Compute Performance Metrics Across Data Set
Detect the bounding boxes for all test images using the detect object function. Specify the AutoResize name-value argument as false to keep the images at full resolution during inference.
detectorResults = detect(detector,dsTest,AutoResize=false,MiniBatchSize=1,Threshold=1e-4);
Compute object detection metrics on the test set detection results by using the evaluateObjectDetection function.
metrics = evaluateObjectDetection(detectorResults,dsTest);
metrics.DatasetMetrics
ans=1×3 table
NumObjects mAPOverlapAvg mAP
__________ _____________ __________
3181 0.85986 {[0.8599]}
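Because this is a single-class problem, the class-level metrics mirror the data set metrics. You can optionally display them to confirm that the evaluation covers only the cow class. A minimal sketch, assuming the metrics object exposes a ClassMetrics table:
metrics.ClassMetrics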
Evaluate Detector Performance Across Object Size Ranges
Define the bounding box area ranges, and evaluate object detection metrics for the defined area ranges using the metricsByArea object function. Display the object area range, the number of objects in that size range, and the average precision (AP) averaged over all overlap thresholds.
objectLabels = readall(bldsTest);
boxes = objectLabels(:,1);
boxes = vertcat(boxes{:});
boxArea = prod(boxes(:,3:4),2);
boxPrctileBoundaries = prctile(boxArea,100*(0.1:0.1:0.9));
metricsByAreaSummary = metricsByArea(metrics,[0 boxPrctileBoundaries inf]);
disp(metricsByAreaSummary(:,1:3))
        AreaRange         NumObjects    APOverlapAvg
    _________________    __________    ____________

         0        99        314           0.25543
        99       156        304           0.74481
       156       252        328           0.84112
       252     399.9        326           0.88992
     399.9       616        315           0.86841
       616       888        321           0.90227
       888      1225        317           0.94585
      1225    1667.8        320           0.96497
    1667.8      2448        316           0.97864
      2448       Inf        320           0.97099
The detector performance, measured using the AP averaged over all overlap thresholds, improves as the size of the cow objects in the scene increases. There is a considerable performance decrease at the smallest object sizes of less than 100 pixels in area, or 0.0024% of the area of the full-resolution image. The detector performs adequately (AP > 0.9) for objects with an area of only 0.015% of the full-resolution image. To improve the detection of the smallest objects, you can customize the tile overlap strategy to ensure that small objects near tile borders are more likely to be fully contained within at least one tile, further augment the data by zooming into or scaling up images or by adjusting brightness and contrast, or train your model on tiles at multiple resolutions to help the model learn to detect objects across different scales. You can also further train your model on images or tiles where detection is poor, focusing on improving the detection of extremely small objects.
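To visualize the trend described above, you can plot the AP for each size bin from the metricsByArea summary. A minimal sketch that assumes the table variable names shown in the output above:
figure
bar(metricsByAreaSummary.APOverlapAvg)
xlabel("Object Area Bin (Smallest to Largest)")
ylabel("AP Averaged Over Overlap Thresholds")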
References
[1] Unel, F. Ozge, Burak O. Ozkalayci, and Cevahir Cigla. “The Power of Tiling for Small Object Detection.” In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 582–91. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPRW.2019.00084.
[2] Akyon, Fatih Cagatay, Sinan Onur Altinuc, and Alptekin Temizel. “Slicing Aided Hyper Inference and Fine-Tuning for Small Object Detection.” In 2022 IEEE International Conference on Image Processing (ICIP), 966–70. Bordeaux, France: IEEE, 2022. https://doi.org/10.1109/ICIP46576.2022.9897990.
[3] Ge, Zheng, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun. "YOLOX: Exceeding YOLO Series in 2021", arXiv, August 6, 2021. https://arxiv.org/abs/2107.08430.
[4] Aerial Cows Data Set. Accessed April 4, 2024. https://universe.roboflow.com/roboflow-100/aerial-cows.
Supporting Functions
writeAsMAT
function writeAsMAT(ds,outputLocation)
reset(ds)
count = 1;
% Ensure a clean start each time.
if exist(outputLocation,"dir")
    rmdir(outputLocation,"s");
end
mkdir(outputLocation)
while hasdata(ds)
    imageBoxesLabelsCell = read(ds);
    save(outputLocation+filesep+count,"imageBoxesLabelsCell")
    count = count + 1;
end
end
selectBlockLocationsUsingBoxes
function [blsOverall,boxes] = selectBlockLocationsUsingBoxes(bim,bboxesIn,NV)
arguments
    bim
    bboxesIn
    NV.BlockSize = bim(1).BlockSize
    NV.BlockOffsets = bim(1).BlockSize
    NV.ExcludeIncompleteBlocks = true
    NV.SelectBackgroundTiles = false
    NV.OverlapThreshold = 1.0
end

overlapThreshold = NV.OverlapThreshold;
boxes = {};
labelsOut = {};
blockOriginsOverall = [];
imageIndexOverall = [];
cropRect = images.spatialref.Rectangle([0.5 NV.BlockSize(2)+0.5],[0.5 NV.BlockSize(1)+0.5]);

for idx = 1:length(bim)
    blsAll = selectBlockLocations(bim(idx),BlockSize=NV.BlockSize, ...
        BlockOffsets=NV.BlockOffsets,ExcludeIncompleteBlocks=NV.ExcludeIncompleteBlocks);

    bboxes = bboxesIn{idx,1};
    labels = bboxesIn{idx,2};

    % Convert each block into a bounding box (x,y,w,h).
    blockBboxes = blsAll.BlockOrigin(:,1:2); % Already in XY
    blockBboxes(:,3) = blsAll.BlockSize(2);  % BlockSize is in row-column order
    blockBboxes(:,4) = blsAll.BlockSize(1);

    if ~isempty(bboxes)
        % Determine the pairwise intersection of blocks and boxes in the
        % full-resolution image.
        pairwiseIntersectionBlocksByBoxes = rectint(blockBboxes,bboxes);

        % Normalize the intersection of each box with the block rectangle
        % to determine which box rectangles partially overlap at the block
        % edges.
        areaOfEachBoxInScene = bboxes(:,3).*bboxes(:,4);
        overlapRatioBlocksByBoxes = pairwiseIntersectionBlocksByBoxes ./ areaOfEachBoxInScene';

        if ~NV.SelectBackgroundTiles
            % A valid block contains at least one box whose overlap ratio
            % meets the threshold.
            selectBlockMask = any(overlapRatioBlocksByBoxes >= overlapThreshold,2);
        else
            % A background block has no intersection with any ground truth box.
            selectBlockMask = ~any(pairwiseIntersectionBlocksByBoxes,2);
        end

        % Build a numValidBlocks-by-numBoxes logical matrix tracking the
        % set of valid boxes within each block.
        validBoxSetInEachValidBlock = overlapRatioBlocksByBoxes(selectBlockMask,:) >= overlapThreshold;

        % Create a new block location set with blocks that intersect at
        % least one object.
        blockOrigins = blsAll.BlockOrigin(selectBlockMask,:);
        imgIndex = idx*ones([size(blockOrigins,1) 1]);

        boxInds = cell([size(validBoxSetInEachValidBlock,1) 1]);
        for validBlockInd = 1:length(boxInds)
            boxInds{validBlockInd} = find(validBoxSetInEachValidBlock(validBlockInd,:));
        end

        boxesInThisImage = cellfun(@(c) bboxes(c,:),boxInds,UniformOutput=false);
        labelsInThisImage = cellfun(@(c) labels(c,:),boxInds,UniformOutput=false);
        boxesInThisImage = adjustBoxLocationsBasedOnBlockOrigin(boxesInThisImage,blockOrigins(:,1:2));

        % Because boxes can partially overlap a block, trim the portions
        % of boxes that lie outside the block bounds.
        if (overlapThreshold < 1) && ~NV.SelectBackgroundTiles
            % Use a near-zero threshold because boxes with overlap smaller
            % than overlapThreshold have already been discarded.
            boxesInThisImage = cellfun(@(c) bboxcrop(c,cropRect,OverlapThreshold=eps), ...
                boxesInThisImage,UniformOutput=false);
        end

        boxes = vertcat(boxes,boxesInThisImage);
        labelsOut = vertcat(labelsOut,labelsInThisImage);
        blockOriginsOverall = vertcat(blockOriginsOverall,blockOrigins);
        imageIndexOverall = vertcat(imageIndexOverall,imgIndex);
    end
end

% Include the boxes and labels together in boxLabelDatastore table form.
boxes = table(boxes,labelsOut);
blsOverall = blockLocationSet(imageIndexOverall,blockOriginsOverall,blsAll.BlockSize);
end
adjustBoxLocationsBasedOnBlockOrigin
function boxesNew = adjustBoxLocationsBasedOnBlockOrigin(boxes,blockUpperLeft)
% Shift box coordinates from image coordinates to tile-local coordinates.
boxesNew = boxes;
for idx = 1:length(boxesNew)
    shift = blockUpperLeft(idx,:) - 1;
    newBoxes = boxes{idx} - [shift 0 0];
    boxesNew{idx} = newBoxes;
end
end
augmentData
The augmentData function randomly applies horizontal and vertical flipping and scaling to pairs of images and bounding boxes. The function discards boxes whose overlap with the warped image bounds is below 0.25 and clips the remaining boxes to the image borders.
function data = augmentData(A)
data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    % Randomly flip and scale the image.
    tform = randomAffine2d(XReflection=true,YReflection=true,Scale=[1.0 1.1]);
    rout = affineOutputView(sz,tform,BoundsStyle="centerOutput");
    I = imwarp(I,tform,OutputView=rout);

    % Apply the same transform to the bounding boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,OverlapThreshold=0.25);
    labels = labels(indices);

    % Return the original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I bboxes labels};
    end
end
end
See Also
yoloxObjectDetector | trainYOLOXObjectDetector | detect | evaluateObjectDetection | trainingOptions (Deep Learning Toolbox) | transform
Related Topics
- Getting Started with YOLOX for Object Detection
- Choose an Object Detector
- Deep Learning in MATLAB (Deep Learning Toolbox)
- Pretrained Deep Neural Networks (Deep Learning Toolbox)