
trainYOLOXObjectDetector

Train YOLOX object detector

Since R2023b

Description

detector = trainYOLOXObjectDetector(trainingData,detectorIn,options) trains a YOLOX object detector using the untrained or pretrained YOLOX network, specified by detectorIn. The options input specifies the network training parameters.

You can also use this syntax for fine-tuning a pretrained YOLOX object detector.

Note

This functionality requires Deep Learning Toolbox™ and the Automated Visual Inspection Library for Computer Vision Toolbox™. You can install the Automated Visual Inspection Library for Computer Vision Toolbox from Add-On Explorer. For more information about installing add-ons, see Get and Manage Add-Ons.


[detector,info] = trainYOLOXObjectDetector(___) returns information on training progress, such as training loss, for each iteration using the input arguments from the previous syntax.

___ = trainYOLOXObjectDetector(___,Name=Value) specifies options using one or more name-value arguments. For example, trainYOLOXObjectDetector(trainingData,detectorIn,options,ExperimentMonitor=[]) specifies not to track metrics with Experiment Manager.

Examples


Prepare Training Data

Load a MAT file containing information about a vehicle data set to use for training into the workspace. The MAT file stores the information as a table. The first column contains the training images, and the remaining columns contain the labeled bounding boxes for those images.

data = load("vehicleTrainingData.mat");
trainingData = data.vehicleTrainingData;

Specify the directory that contains the training sample image data files and box labels. Add the full path to the filenames in the training data.

dataDir = fullfile(toolboxdir("vision"),"visiondata");
trainingData.imageFilename = fullfile(dataDir,trainingData.imageFilename);

Create an imageDatastore object using the files from the table.

imds = imageDatastore(trainingData.imageFilename);

Create a boxLabelDatastore object using the label columns from the table.

blds = boxLabelDatastore(trainingData(:,2:end));

Combine the image datastore and box label datastore.

ds = combine(imds,blds);

Specify the input size to use for resizing the training images.

inputSize = [128 228 3];

Partition Data

Split the data set into training and validation sets. Allocate approximately 80% of the data for training, and the rest for validation.

numImages = numpartitions(ds);
shuffledIndices = randperm(numImages);
dsTrain = subset(ds,shuffledIndices(1:round(0.8*numImages)));
dsVal = subset(ds,shuffledIndices(numpartitions(dsTrain)+1:end));
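The partition step shuffles the observation indices once, assigns the first round(0.8*numImages) of them to training, and assigns the rest to validation, so no image appears in both sets. This base-MATLAB sketch reproduces only that index arithmetic, using a hypothetical count of 100 observations in place of numpartitions(ds). The rng(0) call is an addition for repeatability and is not part of the example above.

```matlab
% Hypothetical number of observations; in the example this comes from numpartitions(ds).
numImages = 100;

% Fix the random seed so that the shuffle is repeatable (an addition, not in the example).
rng(0);
shuffledIndices = randperm(numImages);

% The first ~80% of the shuffled indices go to training, the rest to validation.
numTrain = round(0.8*numImages);
trainIdx = shuffledIndices(1:numTrain);
valIdx = shuffledIndices(numTrain+1:end);

fprintf("Training: %d, validation: %d\n",numel(trainIdx),numel(valIdx));
```

In the example, numpartitions(dsTrain) equals numTrain, so subsetting from numpartitions(dsTrain)+1 selects exactly the validation indices shown here.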

Define YOLOX Object Detector Network Architecture

Create a YOLOX object detector by using the yoloxObjectDetector function. Specify "tiny-coco" as the pretrained network to use as the base network. Specify the class names and the network input size.

classes = {'vehicle'};
detector = yoloxObjectDetector("tiny-coco",classes,InputSize=inputSize);

Train YOLOX Object Detection Network

Specify network training options using the trainingOptions (Deep Learning Toolbox) function. Use the mAPObjectDetectionMetric object to track the mean average precision (mAP) metric when you train the detector.

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MiniBatchSize=16,...
    MaxEpochs=4, ...
    ResetInputNormalization=false, ...
    Metrics=mAPObjectDetectionMetric(Name="mAP50"), ...
    ObjectiveMetricName="mAP50", ...
    ValidationData=dsVal, ...
    ValidationFrequency=20, ...
    VerboseFrequency=2);

Train the pretrained YOLOX network on the new data set by using the trainYOLOXObjectDetector function.

trainedDetector = trainYOLOXObjectDetector(dsTrain,detector,options);
 
    Epoch    Iteration    TimeElapsed    LearnRate    TrainingYoloXLoss    ValidationYoloXLoss    Validationmap50
    _____    _________    ___________    _________    _________________    ___________________    _______________
      1          2         00:00:20        0.001           5.1181                                                
      1          4         00:00:25        0.001           5.0637                                                
      1          6         00:00:29        0.001           2.9549                                                
      1          8         00:00:33        0.001           4.5492                                                
      1         10         00:00:37        0.001           3.3555                                                
      1         12         00:00:41        0.001           3.1534                                                
      1         14         00:00:46        0.001            2.812                                                
      2         16         00:00:50        0.001           2.4134                                                
      2         18         00:00:55        0.001            2.203                                                
      2         20         00:00:59        0.001           2.1435                2.4607               0.94962    
      2         22         00:01:20        0.001            2.276                                                
      2         24         00:01:25        0.001           2.4464                                                
      2         26         00:01:31        0.001           2.3203                                                
      2         28         00:01:36        0.001           2.0818                                                
      3         30         00:01:42        0.001           1.9096                                                
      3         32         00:01:48        0.001           2.1849                                                
      3         34         00:01:54        0.001            1.888                                                
      3         36         00:01:59        0.001           1.9287                                                
      3         38         00:02:04        0.001            1.944                                                
      3         40         00:02:09        0.001           2.3037                2.1339               0.96241    
      3         42         00:02:32        0.001           1.7938                                                
      4         44         00:02:37        0.001           2.3005                                                
      4         46         00:02:42        0.001           1.9379                                                
      4         48         00:02:46        0.001           2.0413                                                
      4         50         00:02:50        0.001           1.9138                                                
      4         52         00:02:54        0.001           1.9997                                                
      4         54         00:02:58        0.001           1.9646                                                
      4         56         00:03:03        0.001           2.0116                2.1982               0.96451    

Detect Vehicles in Test Image

Read a test image.

img = imread("highway.png");

Use the fine-tuned YOLOX object detector to detect vehicles in the test image and display the detection results.

[boxes,scores,labels] = detect(trainedDetector,img,Threshold=0.5);
detectedImage = insertObjectAnnotation(img,"rectangle",boxes,labels);
figure
imshow(detectedImage)

Input Arguments


trainingData

Labeled ground truth images, specified as a datastore. You must configure the datastore so that calling the read and readall functions on it returns a cell array or table with these three columns:

Column 1, image data: Single-cell cell array, which contains the input image.

Column 2, bounding boxes: Single-cell cell array, which contains the bounding boxes, defined in spatial coordinates as an M-by-4 numeric matrix with rows of the form [x y w h], where:

  • M is the number of axis-aligned rectangles.

  • x and y specify the upper-left corner of the rectangle.

  • w specifies the width of the rectangle, which is its length along the x-axis.

  • h specifies the height of the rectangle, which is its length along the y-axis.

Column 3, bounding box labels: Single-cell cell array, which contains an M-by-1 categorical vector of object class names. All the categorical data returned by the datastore must use the same categories.

Use the combine function on two datastores to create a datastore that returns these three data columns using read. The datastore that returns the first column of data should be an ImageDatastore, while the datastore that returns the second and third columns of data should be a boxLabelDatastore.
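To make the required layout concrete, this sketch builds one observation by hand in the 1-by-3 cell form that read returns for the combined datastore: an image, an M-by-4 matrix of [x y w h] boxes, and an M-by-1 categorical label vector. The image is a blank array and the box values are made up for illustration. The sketch also converts the boxes to corner form, which follows directly from the [x y w h] definition (the right edge is x + w and the bottom edge is y + h); the training data itself must stay in [x y w h] form.

```matlab
% Hypothetical 128-by-228 RGB image (all zeros) standing in for a real training image.
img = zeros(128,228,3,'uint8');

% Two made-up boxes, one per row, in [x y w h] form (upper-left corner, width, height).
boxes = [10 20 50 30;
         100 40 60 45];

% One label per box; every observation must use the same set of categories.
labels = categorical({'vehicle';'vehicle'},{'vehicle'});

% One observation laid out as read returns it: {image, boxes, labels}.
sample = {img,boxes,labels};

% Check the layout described above.
M = size(sample{2},1);
assert(size(sample{2},2) == 4,"Boxes must be an M-by-4 matrix.");
assert(iscategorical(sample{3}) && isequal(size(sample{3}),[M 1]), ...
    "Labels must be an M-by-1 categorical vector.");

% Corner form [xmin ymin xmax ymax]: right edge = x + w, bottom edge = y + h.
corners = [boxes(:,1:2), boxes(:,1:2) + boxes(:,3:4)];
disp(corners)
```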

detectorIn

Pretrained or untrained YOLOX object detector, specified as a yoloxObjectDetector object.

options

Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions (Deep Learning Toolbox) function. Use trainingOptions to specify the solver name and the other network training parameters.

Name-Value Arguments

Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

Example: ExperimentMonitor=[] specifies not to track metrics with Experiment Manager.

ExperimentMonitor

Detector training experiment monitoring, specified as an experiments.Monitor (Deep Learning Toolbox) object for use with the Experiment Manager (Deep Learning Toolbox) app. You can use this object to track the progress of training, update information fields in the training results table, record the values of the metrics used during training, or produce training plots. For more information on using this app, see the Train Object Detectors in Experiment Manager example.

The app monitors this information during training:

  • Training loss at each iteration

  • Learning rate at each iteration

  • Validation loss at each iteration, if the options input contains validation data

Subnetwork to freeze during training, specified as one of these values:

  • "none" — Do not freeze a subnetwork.

  • "backbone" — Freeze the feature extraction (backbone) subnetwork.

The weights of layers in frozen subnetworks do not change during training.

Output Arguments


detector

Trained YOLOX object detector, returned as a yoloxObjectDetector object. You can train a YOLOX object detector to detect multiple object classes.

info

Training progress information, returned as a structure array with these fields:

  • TrainingLoss — Training loss at each iteration. The trainYOLOXObjectDetector function uses mean square error for computing bounding box regression loss and cross-entropy for computing classification loss.

  • BaseLearnRate — Learning rate at each iteration.

  • OutputNetworkIteration — Iteration number of the returned network.

  • ValidationLoss — Validation loss at each iteration.

  • FinalValidationLoss — Final validation loss at the end of the training.

The TrainingLoss, BaseLearnRate, and ValidationLoss fields are numeric vectors with one element per training iteration, and OutputNetworkIteration and FinalValidationLoss are scalars. The function returns a value of NaN for iterations at which it does not calculate a value. The structure contains the ValidationLoss and FinalValidationLoss fields only when options specifies validation data.
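Because iterations without a validation pass hold NaN, the validation values have to be picked out of the ValidationLoss vector before you plot or compare them. This sketch does that on a synthetic stand-in for info whose numbers are invented for illustration; with a real run, info comes from a call such as [trainedDetector,info] = trainYOLOXObjectDetector(dsTrain,detector,options).

```matlab
% Synthetic stand-in for the info output: 6 iterations, validation on iterations 3 and 6.
% The numbers are invented for illustration only.
info.TrainingLoss   = [5.1 4.2 3.3 2.9 2.5 2.2];
info.BaseLearnRate  = 0.001*ones(1,6);
info.ValidationLoss = [NaN NaN 3.0 NaN NaN 2.4];

% Iterations at which a validation loss was actually computed.
valIter = find(~isnan(info.ValidationLoss));
valLoss = info.ValidationLoss(valIter);

% Training loss at the same iterations, for a side-by-side comparison.
trainAtVal = info.TrainingLoss(valIter);

disp(table(valIter(:),trainAtVal(:),valLoss(:), ...
    'VariableNames',{'Iteration','TrainingLoss','ValidationLoss'}))
```

Indexing with (:) makes the table columns work whether the fields are stored as row or column vectors.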

Tips

  • To generate the labeled ground truth image data, use the Image Labeler or Video Labeler app. To create a table of training data from the generated ground truth, use the objectDetectorTrainingData function.

  • To help improve prediction accuracy, increase the number of images you use to train the network. You can expand the training data set using data augmentation. For information on how to apply data augmentation for preprocessing, see Preprocess Images for Deep Learning (Deep Learning Toolbox).
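As one example of augmentation, this sketch flips an observation horizontally and mirrors its boxes to match. It is a hand-written helper, not a toolbox function, and the name flipSample is hypothetical. It assumes the 1-by-3 cell layout {image, boxes, labels} described under trainingData and MATLAB spatial coordinates, in which an image of width W spans x = 0.5 to x = W + 0.5. A point x therefore maps to W + 1 - x, so the right edge x + w of a box becomes its new left edge W + 1 - (x + w).

```matlab
% Hypothetical helper (not a toolbox function): flip one {image, boxes, labels}
% observation left to right. Boxes are M-by-4 [x y w h] in spatial coordinates,
% so for an image of width W the new left edge of a box is W + 1 - (x + w).
flipSample = @(data) { ...
    flip(data{1},2), ...  % mirror the pixel columns (dimension 2)
    [(size(data{1},2) + 1 - data{2}(:,1) - data{2}(:,3)), data{2}(:,2:4)], ... % mirror x; keep y, w, h
    data{3}};             % labels are unchanged

% Made-up 4-by-100 grayscale image with a marker pixel in the top-left corner
% and one box whose left edge is at x = 1 and whose width is 10.
img = zeros(4,100,'uint8');
img(1,1) = 255;
sample = {img, [1 1 10 2], categorical({'vehicle'})};

flipped = flipSample(sample);
disp(flipped{2})   % mirrored box
```

To apply the helper to training data, one option is dsAug = transform(dsTrain,flipSample), because transform passes each read result to the function. As written, this flips every observation; flipping only a random subset, or combining flipped and original images, is a design choice left open here.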

Version History

Introduced in R2023b