
trainHRNetObjectKeypointDetector

Train HRNet object keypoint detector

Since R2024a

    Description

    trainedKeypointDetector = trainHRNetObjectKeypointDetector(trainingData,keypointDetector,options) trains an object keypoint detector using the specified high-resolution deep learning network (HRNet) keypointDetector. You can specify keypointDetector as a pretrained or custom HRNet object keypoint detector. The options input specifies training parameters for the keypoint detection network.

    You can also use this syntax to fine-tune a pretrained HRNet object keypoint detector.


    trainedKeypointDetector = trainHRNetObjectKeypointDetector(trainingData,checkpoint,options) resumes training from the saved detector checkpoint.

    You can use this syntax to:

    • Add more training data and continue training a keypoint detector.

    • Improve training accuracy by increasing the maximum number of iterations.

    [trainedKeypointDetector,info] = trainHRNetObjectKeypointDetector(___) returns information on the training progress, such as the training loss and learning rate for each iteration, using any combination of input arguments from previous syntaxes.

    [___] = trainHRNetObjectKeypointDetector(___,Name=Value) specifies options using one or more name-value arguments. For example, ExperimentMonitor=[] specifies not to track metrics using Experiment Manager.

    Note

    This functionality requires Deep Learning Toolbox™ and the Computer Vision Toolbox™ Model for Object Keypoint Detection. You can install the Computer Vision Toolbox Model for Object Keypoint Detection from the Add-On Explorer. For more information about installing add-ons, see Get and Manage Add-Ons.

    Examples


    Fine-tune a pretrained HRNet object keypoint detector to detect the keypoints of a human hand in an image. The keypoint detector uses an HRNet-W32 deep learning network, trained on the COCO keypoint detection data set.

    Download and load the hand pose data set using the helperDownloadHandPoseDataset helper function, which is defined in the Supporting Function section of this example. The hand pose data set is a labeled data set that contains 2500 images from the Large-Scale Multiview Hand Pose Dataset [1]. Extract a subset that contains the first 100 images. Each image in the data set contains a human hand with 21 annotated keypoints.

    downloadFolder = tempdir;
    dataset = helperDownloadHandPoseDataset(downloadFolder);
    data = load(dataset);
    handPoseDataset = data.handPoseDataset(1:100,:);

    The training data is a table with three columns, in which the first, second, and third columns contain the image filenames, keypoint locations, and hand bounding boxes, respectively. Each entry in the keypoint locations column is an N-by-2 matrix, where N is the number of keypoints of the hand in the corresponding image. Each image contains only one hand, which is one object, so each row represents one object in an image. Prepend the path of the locally stored hand pose image folder to each image filename.

    handPoseDataset.imageFilename = fullfile(downloadFolder,"2DHandPoseDataAndGroundTruth","2DHandPoseImages",handPoseDataset.imageFilename);

    Create an ImageDatastore object for loading the image data.

    handPoseImds = imageDatastore(handPoseDataset.imageFilename);

    Create an ArrayDatastore object for loading the ground truth keypoint location data.

    handPoseArrds = arrayDatastore(handPoseDataset(:,2));

    Create a boxLabelDatastore object for loading the bounding box locations.

    handPoseBlds = boxLabelDatastore(handPoseDataset(:,3));

    Combine the image, array, and box label datastores into a single datastore.

    trainingData = combine(handPoseImds,handPoseArrds,handPoseBlds);

    Specify the keypoint classes in a human hand.

    keypointClasses = ["forefinger3","forefinger4","forefinger2","forefinger1", ...
        "middleFinger3","middleFinger4","middleFinger2","middleFinger1", ...
        "pinkyFinger3","pinkyFinger4","pinkyFinger2","pinkyFinger1", ...
        "ringFinger3","ringFinger4","ringFinger2","ringFinger1", ...
        "thumb3","thumb4","thumb2","thumb1","wrist"]';
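Because each class name is assumed to correspond to one row of the N-by-2 keypoint matrix, a quick sanity check can catch a missing or duplicated name before training. This is a minimal sketch: the expected count of 21 comes from the data set description, and the variable numAnnotatedKeypoints is a hypothetical name introduced only for this check; trainHRNetObjectKeypointDetector does not require you to run it.

```matlab
% Keypoint class names from this example, assumed to follow the row order
% of each N-by-2 keypoint matrix in the training data
keypointClasses = ["forefinger3","forefinger4","forefinger2","forefinger1", ...
    "middleFinger3","middleFinger4","middleFinger2","middleFinger1", ...
    "pinkyFinger3","pinkyFinger4","pinkyFinger2","pinkyFinger1", ...
    "ringFinger3","ringFinger4","ringFinger2","ringFinger1", ...
    "thumb3","thumb4","thumb2","thumb1","wrist"]';

numAnnotatedKeypoints = 21;   % hypothetical name: keypoints annotated per hand

% One class name per annotated keypoint
assert(numel(keypointClasses) == numAnnotatedKeypoints, ...
    "Expected one class name per annotated keypoint.")

% A duplicated name is most likely a typo
assert(numel(unique(keypointClasses)) == numel(keypointClasses), ...
    "Keypoint class names must be unique.")
```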

    Create an hrnetObjectKeypointDetector object, and configure it to detect the specified keypoint classes.

    handKeypointDetector = hrnetObjectKeypointDetector("human-full-body-w32",keypointClasses);
    handKeypointDetector.Network
    ans = 
      dlnetwork with properties:
    
             Layers: [1036×1 nnet.cnn.layer.Layer]
        Connections: [1196×2 table]
         Learnables: [1170×3 table]
              State: [584×3 table]
         InputNames: {'input_1'}
        OutputNames: {'finallayer'}
        Initialized: 1
    
      View summary with summary.
    
    

    Specify training options for the hand keypoint detector.

    options = trainingOptions("adam", ...
        MaxEpochs=20, ...
        InitialLearnRate=0.001, ...
        MiniBatchSize=16, ...
        LearnRateSchedule="piecewise", ...
        LearnRateDropFactor=0.1, ...
        LearnRateDropPeriod=12, ...
        VerboseFrequency=25, ...
        BatchNormalizationStatistics="moving", ...
        ResetInputNormalization=false);
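With LearnRateSchedule="piecewise", the learning rate is multiplied by LearnRateDropFactor each time LearnRateDropPeriod epochs pass. The sketch below reproduces that schedule for these options using two hypothetical anonymous functions, learnRateAtEpoch and epochOfIteration. It assumes that a final partial mini-batch still counts as one iteration, so 100 images with MiniBatchSize=16 give ceil(100/16) = 7 iterations per epoch. That assumption is consistent with the epoch and iteration numbers in the training log of this example, but treat it as an inference rather than documented behavior.

```matlab
% Values copied from the trainingOptions call in this example
initialLearnRate = 0.001;
dropFactor = 0.1;
dropPeriod = 12;       % epochs between learning rate drops
miniBatchSize = 16;
numImages = 100;       % size of the hand pose subset used for training

% Piecewise schedule: multiply the rate by dropFactor after every dropPeriod epochs
learnRateAtEpoch = @(epoch) initialLearnRate .* dropFactor.^floor((epoch - 1) ./ dropPeriod);

% Assumption: a final partial mini-batch counts as one iteration
iterationsPerEpoch = ceil(numImages / miniBatchSize);
epochOfIteration = @(iteration) ceil(iteration ./ iterationsPerEpoch);

% Iterations printed in the verbose training log (VerboseFrequency=25)
loggedIterations = [25 50 75 100 125];
loggedEpochs = epochOfIteration(loggedIterations);
loggedRates = learnRateAtEpoch(loggedEpochs);
disp(table(loggedIterations',loggedEpochs',loggedRates', ...
    VariableNames=["Iteration","Epoch","LearnRate"]))
```

Under the same assumption, the rate stays at 0.001 through epoch 12 and drops to 0.0001 from epoch 13 onward, which matches the LearnRate column of the training log.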

    Fine-tune the pretrained HRNet object keypoint detector on the new data set by using the trainHRNetObjectKeypointDetector function.

    [trainedHandKeypointDetector,info] = trainHRNetObjectKeypointDetector(trainingData,handKeypointDetector,options);
    *************************************************************************
    Training a HRNet Object Keypoint Detector for the following keypoint classes:
    
    * forefinger3
    * forefinger4
    * forefinger2
    * forefinger1
    * middleFinger3
    * middleFinger4
    * middleFinger2
    * middleFinger1
    * pinkyFinger3
    * pinkyFinger4
    * pinkyFinger2
    * pinkyFinger1
    * ringFinger3
    * ringFinger4
    * ringFinger2
    * ringFinger1
    * thumb3
    * thumb4
    * thumb2
    * thumb1
    * wrist
    
     
        Epoch    Iteration    TimeElapsed    LearnRate    TrainingLoss
        _____    _________    ___________    _________    ____________
          4         25         00:08:40        0.001       0.0018743  
          8         50         00:16:21        0.001       0.0017478  
         11         75         00:22:27        0.001       0.0014061  
         15         100        00:28:27       0.0001       0.0012995  
         18         125        00:34:44       0.0001       0.0013525  
    
    *************************************************************************
    Keypoint detector training complete.
    *************************************************************************
    

    Read a test image. Use the trained HRNet hand keypoint detector to detect hand keypoints and display the detection results.

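    I = imread("test.jpg");
    % Bounding box of the hand in [x y w h] format
    bbox = [185 156 249 211];
    % Detect the keypoints of the object inside the bounding box
    predictedKeypoints = detect(trainedHandKeypointDetector,I,bbox);
    % Overlay the detected keypoints and the bounding box on the image
    outputImg = insertObjectKeypoints(I,predictedKeypoints, ...
        KeypointColor="yellow",KeypointSize=3,LineWidth=3);
    outputImg = insertShape(outputImg,"rectangle",bbox);
    figure
    imshow(outputImg)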

    Supporting Function

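    function dataset = helperDownloadHandPoseDataset(downloadFolder)
    % Download the hand pose ZIP file once, extract it, and return the path
    % to the MAT file that contains the ground truth table.
    dataFilename = "2DHandPoseDataAndGroundTruth.zip";
    dataAndImageUrl = "https://ssd.mathworks.com/supportfiles/vision/data/2DHandPose/" + dataFilename;
    zipFile = fullfile(downloadFolder,dataFilename);
    % Skip the download if the ZIP file already exists in downloadFolder
    if ~exist(zipFile,"file")
        disp("Downloading hand pose dataset (98 MB)...")
        websave(zipFile,dataAndImageUrl);
    end
    % Extract the images and ground truth into downloadFolder
    unzip(zipFile,downloadFolder)
    dataset = fullfile(downloadFolder,"2DHandPoseDataAndGroundTruth","2DHandPoseGroundTruth.mat");
    end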

    References

    [1] Gomez-Donoso, Francisco, Sergio Orts-Escolano, and Miguel Cazorla. "Large-Scale Multiview 3D Hand Pose Dataset." Image and Vision Computing 81 (January 2019): 25–33. https://doi.org/10.1016/j.imavis.2018.12.001.

    Input Arguments


    trainingData

    Labeled ground truth images, specified as a datastore. Your data must be set up so that using the read or readall function on the datastore returns a cell array with three columns in the order {ImageData,Keypoint,BoundingBox}. Each cell in a row has this format.

    ImageData

    Grayscale, RGB, or multichannel image that serves as a network input, specified as an H-by-W numeric matrix, H-by-W-by-3 numeric array, or H-by-W-by-C numeric array, respectively, where:

    • H is the height of the image.

    • W is the width of the image.

    • C is the number of channels in the image.

    Keypoint

    Keypoint locations, defined in spatial coordinates as an N-by-2 or N-by-3 numeric matrix with rows of the form [x y] or [x y v], respectively, where:

    • N is the number of keypoint classes.

    • x and y specify the spatial coordinates of a keypoint.

    • v specifies the visibility of a keypoint.

    BoundingBox

    Bounding box, defined in spatial coordinates as a 1-by-4 numeric vector of the form [x y w h], where:

    • x and y specify the upper-left corner of the rectangle.

    • w specifies the width of the rectangle, which is its length along the x-axis.

    • h specifies the height of the rectangle, which is its length along the y-axis.

    To create a combined datastore that returns these three data columns using read, use the combine function on an ImageDatastore, an ArrayDatastore, and a boxLabelDatastore object, in that order.
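To see what one row of training data should look like, the sketch below builds a synthetic sample and checks it against the ImageData, Keypoint, and BoundingBox formats, then converts the [x y w h] box to corner form. It is a minimal illustration using base MATLAB only: the image size and keypoint coordinates are made-up placeholders, the box reuses the values from the example above, and the in-box count is an optional sanity check that this page does not state as a requirement.

```matlab
% Hypothetical synthetic sample shaped like one row returned by read(trainingData)
I = zeros(480,640,3,'uint8');                                 % H-by-W-by-3 RGB image
keypoints = [linspace(200,400,21)' linspace(180,340,21)'];    % N-by-2 [x y], N = 21
bbox = [185 156 249 211];                                     % 1-by-4 [x y w h]
sample = {I,keypoints,bbox};

% ImageData: H-by-W (grayscale) or H-by-W-by-C (RGB or multichannel)
[H,W,C] = size(sample{1});
imageOK = isnumeric(sample{1}) && ndims(sample{1}) <= 3;

% Keypoint: N-by-2 [x y] or N-by-3 [x y v]
keypointsOK = isnumeric(sample{2}) && ismatrix(sample{2}) && any(size(sample{2},2) == [2 3]);

% BoundingBox: 1-by-4 [x y w h] with positive width and height
box = sample{3};
boxOK = isnumeric(box) && isequal(size(box),[1 4]) && all(box(3:4) > 0);

% Corner form [xmin ymin xmax ymax] of the same rectangle
corners = [box(1) box(2) box(1)+box(3) box(2)+box(4)];

% Optional check: count [x y] keypoints that fall inside the box
% (keypoints marked as not visible through v may legitimately lie elsewhere)
xy = sample{2}(:,1:2);
inBox = xy(:,1) >= corners(1) & xy(:,1) <= corners(3) & ...
        xy(:,2) >= corners(2) & xy(:,2) <= corners(4);
fprintf("Image %dx%dx%d, %d keypoints, %d inside the box\n",H,W,C,size(xy,1),nnz(inBox))
```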

    keypointDetector

    Pretrained or untrained HRNet object keypoint detector, specified as an hrnetObjectKeypointDetector object.

    options

    Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions (Deep Learning Toolbox) function. To specify the solver name and other options for network training, use the trainingOptions function. You must set the BatchNormalizationStatistics property of the object to "moving".

    checkpoint

    Saved detector checkpoint, specified as an hrnetObjectKeypointDetector object. To periodically save a detector checkpoint as a MAT file during training, specify a location for the file by using the CheckpointPath property of the training options object options. To control how frequently the detector saves checkpoints, use the CheckpointFrequency and CheckpointFrequencyUnit properties of the training options object.
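As a sketch of the checkpoint settings, these training options save a checkpoint every 5 epochs to a "checkpath" folder in the current working directory, the same folder name used in the loading example that follows. The frequency of 5 and the folder variable checkpointFolder are illustrative choices, not recommendations. The code requires Deep Learning Toolbox, and it creates the folder first because trainingOptions expects CheckpointPath to point to an existing folder.

```matlab
% Hypothetical folder for checkpoint MAT files; create it if it does not exist
checkpointFolder = fullfile(pwd,"checkpath");
if ~isfolder(checkpointFolder)
    mkdir(checkpointFolder);
end

options = trainingOptions("adam", ...
    MaxEpochs=20, ...
    MiniBatchSize=16, ...
    BatchNormalizationStatistics="moving", ...   % required by trainHRNetObjectKeypointDetector
    CheckpointPath=checkpointFolder, ...         % where checkpoint MAT files are written
    CheckpointFrequency=5, ...                   % save a checkpoint every 5 ...
    CheckpointFrequencyUnit="epoch");            % ... epochs
```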

    To load a checkpoint for a previously trained detector, load the MAT file from the checkpoint path. For example, this code loads a checkpoint MAT file from the "checkpath" folder in the current working directory, which is where the detector saved checkpoints during training.

    data = load("checkpath/net_checkpoint__6__2023_11_17__16_03_08.mat");
    checkpoint = data.net;

    The name of each MAT file includes the iteration number and the timestamp at which the detector saved the checkpoint. The file stores the detector in the net variable. To continue training, pass the detector extracted from the file to the trainHRNetObjectKeypointDetector function.

    HRNetDetector = trainHRNetObjectKeypointDetector(trainingData,checkpoint,options);
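When a folder holds several checkpoints, you can pick the most recent one by parsing the iteration number from each file name. This sketch assumes that names follow the pattern of the example file net_checkpoint__6__2023_11_17__16_03_08.mat (iteration, then a yyyy_MM_dd__HH_mm_ss timestamp); that pattern is inferred from this single example, and the second file name is hypothetical. In practice you could build the names list with dir("checkpath/net_checkpoint__*.mat").

```matlab
% Checkpoint file names: the first is the example on this page, the second is hypothetical
names = ["net_checkpoint__6__2023_11_17__16_03_08.mat"; ...
         "net_checkpoint__12__2023_11_17__16_09_41.mat"];

% Assumed naming pattern: iteration number, then timestamp
pattern = '^net_checkpoint__(\d+)__(\d{4}_\d{2}_\d{2}__\d{2}_\d{2}_\d{2})\.mat$';
iterations = NaN(numel(names),1);
savedAt = NaT(numel(names),1);
for k = 1:numel(names)
    tok = regexp(char(names(k)),pattern,'tokens','once');
    if isempty(tok)
        continue   % skip files that do not follow the assumed pattern
    end
    iterations(k) = str2double(tok{1});
    savedAt(k) = datetime(tok{2},InputFormat="yyyy_MM_dd__HH_mm_ss");
end

% max ignores NaN entries left by skipped files
[~,latest] = max(iterations);
latestFile = names(latest);
disp("Latest checkpoint: " + latestFile)
% Then, for example: data = load(fullfile("checkpath",latestFile)); checkpoint = data.net;
```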

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: trainHRNetObjectKeypointDetector(trainingData,keypointDetector,options,ExperimentMonitor=[]) specifies not to track metrics with Experiment Manager.

    ExperimentMonitor

    Detector training experiment monitoring, specified as an experiments.Monitor (Deep Learning Toolbox) object for use with the Experiment Manager (Deep Learning Toolbox) app. You can use this object to track the progress of training, update information fields in the training results table, record values of the metrics used by the training, or produce training plots. For more information on using this app, see the Train Object Detectors in Experiment Manager example.

    The app monitors this information during training:

    • Training loss at each iteration

    • Learning rate at each iteration

    • Validation loss at each iteration, if the options input contains validation data

    Output Arguments


    trainedKeypointDetector

    Trained HRNet object keypoint detector, returned as an hrnetObjectKeypointDetector object.

    info

    Training progress information, returned as a structure array with these fields. Each field corresponds to a stage of training.

    • TrainingLoss — Training loss at each iteration. The trainHRNetObjectKeypointDetector function uses mean square error to compute bounding box regression loss and cross-entropy to compute classification loss.

    • BaseLearnRate — Learning rate at each iteration.

    • OutputNetworkIteration — Iteration number of the returned network.

    Each field is a numeric vector with one element per training iteration. If the function does not calculate a metric at a specific iteration, the corresponding element of that vector has a value of NaN.
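Because uncomputed entries are NaN, analyze only the valid elements of each field. The sketch below uses a stand-in for info whose TrainingLoss is NaN everywhere except the five iterations printed in this example's training log; a real info.TrainingLoss may contain values at many more iterations, so the structure and its length of 125 are illustrative assumptions.

```matlab
% Stand-in for the info output, filled only with the losses printed in the example log
numIterations = 125;   % assumed length for this illustration
info.TrainingLoss = NaN(1,numIterations);
info.TrainingLoss([25 50 75 100 125]) = [0.0018743 0.0017478 0.0014061 0.0012995 0.0013525];

% Keep only iterations where the metric was computed
iterations = 1:numel(info.TrainingLoss);
hasLoss = ~isnan(info.TrainingLoss);

% min ignores NaN values by default
[minLoss,idx] = min(info.TrainingLoss);
fprintf("Lowest training loss %.7f at iteration %d\n",minLoss,iterations(idx))

% To visualize, plot only the computed values, for example:
% plot(iterations(hasLoss),info.TrainingLoss(hasLoss),"-o")
```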

    Version History

    Introduced in R2024a