classify

Classify documents using BERT document classifier

Since R2023b

    Description

    Y = classify(mdl,documents) classifies the specified documents using the BERT document classifier mdl.

    Y = classify(mdl,documents,Name=Value) specifies additional options using one or more name-value arguments.

    Examples

    Read the training data from the factoryReports CSV file. The file contains factory reports, including a text description and a categorical label for each report.

    filename = "factoryReports.csv";
    data = readtable(filename,TextType="string");
    head(data)
                                     Description                                       Category          Urgency          Resolution         Cost 
        _____________________________________________________________________    ____________________    ________    ____________________    _____
    
        "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
        "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
        "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
        "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
        "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
        "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
        "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
        "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38
    

    Convert the labels in the Category column of the table to categorical values.

    data.Category = categorical(data.Category);

    Partition the data into a training set and a test set. Specify the holdout percentage as 10%.

    cvp = cvpartition(data.Category,Holdout=0.1);
    dataTrain = data(cvp.training,:);
    dataTest = data(cvp.test,:);
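    Because the labels are passed as the first input, cvpartition stratifies the holdout split, so each class keeps roughly the same proportion in the training and test sets. The following sketch checks this on a small synthetic label vector; the classes "A" and "B" and their counts are hypothetical and not taken from factoryReports.csv. It assumes Statistics and Machine Learning Toolbox, which provides cvpartition.

```matlab
% Hypothetical, imbalanced labels: 80 documents of class "A", 20 of class "B".
labels = categorical([repmat("A",80,1); repmat("B",20,1)]);

rng(0)  % make the random partition reproducible
cvp = cvpartition(labels,Holdout=0.1);

% Count the documents of each class in the training and test sets.
trainCounts = countcats(labels(cvp.training));
testCounts = countcats(labels(cvp.test));

% With stratification, each class makes up a similar fraction of both sets.
disp(table(categories(labels),trainCounts,testCounts, ...
    VariableNames=["Class" "Train" "Test"]))
```

    To inspect the real split, you can apply the same countcats calls to TTrain and TTest.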

    Extract the text data and labels from the tables.

    textDataTrain = dataTrain.Description;
    textDataTest = dataTest.Description;
    TTrain = dataTrain.Category;
    TTest = dataTest.Category;

    Create a BERT document classifier from the pretrained BERT-Base model using the bertDocumentClassifier function. Specify the class names as the categories of the training labels.

    classNames = categories(data.Category);
    mdl = bertDocumentClassifier(ClassNames=classNames)
    mdl = 
      bertDocumentClassifier with properties:
    
           Network: [1×1 dlnetwork]
         Tokenizer: [1×1 bertTokenizer]
        ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
    
    

    Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager (Deep Learning Toolbox) app.

    • Train using the Adam optimizer.

    • Train for eight epochs.

    • For fine-tuning, lower the learning rate. Train using a learning rate of 0.0001.

    • Shuffle the data every epoch.

    • Monitor the training progress in a plot and monitor the accuracy metric.

    • Disable the verbose output.

    options = trainingOptions("adam", ...
        MaxEpochs=8, ...
        InitialLearnRate=1e-4, ...
        Shuffle="every-epoch", ...  
        Plots="training-progress", ...
        Metrics="accuracy", ...
        Verbose=false);

    Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainBERTDocumentClassifier function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

    mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

    Make predictions using the test data.

    YTest = classify(mdl,textDataTest);

    Calculate the classification accuracy of the test predictions.

    accuracy = mean(TTest == YTest)
    accuracy = 0.9375
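    The overall accuracy can hide differences between classes. As a minimal sketch in base MATLAB, the following computes the overall accuracy and the per-class accuracy, that is, the fraction of each class's documents that are classified correctly. The short label vectors T and Y are hypothetical stand-ins for TTest and YTest, not results from the example above.

```matlab
% Hypothetical true and predicted labels standing in for TTest and YTest.
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];
T = categorical(["Leak" "Leak" "Mechanical Failure" "Software Failure"],classNames);
Y = categorical(["Leak" "Mechanical Failure" "Mechanical Failure" "Software Failure"],classNames);

% Overall accuracy: fraction of documents whose predicted class matches the true class.
accuracy = mean(T == Y);

% Per-class accuracy: for each class, fraction of its documents classified correctly.
% Classes with no documents in T keep the value NaN.
perClassAccuracy = nan(1,numel(classNames));
for k = 1:numel(classNames)
    idx = T == classNames(k);
    if any(idx)
        perClassAccuracy(k) = mean(Y(idx) == classNames(k));
    end
end
```

    With these labels, accuracy is 0.75 and perClassAccuracy is [NaN 0.5 1 1]. The same loop applies unchanged to TTest and YTest.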
    

    Input Arguments

    mdl

    BERT document classifier model, specified as a bertDocumentClassifier object.

    documents

    Input documents, specified as a string array, a cell array of character vectors, or a tokenizedDocument array.

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: classify(mdl,documents,MiniBatchSize=64) classifies the specified documents using mini-batches of size 64.

    MiniBatchSize

    Mini-batch size to use for prediction, specified as a positive integer. Larger mini-batch sizes require more memory, but can lead to faster predictions.

    Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
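    Conceptually, the documents are processed in consecutive groups of at most MiniBatchSize documents, and the last group can be smaller. The base MATLAB sketch below only illustrates that arithmetic for 10 hypothetical documents and a mini-batch size of 4; it is not the internal implementation of classify, which handles batching for you.

```matlab
numDocuments = 10;   % hypothetical number of input documents
miniBatchSize = 4;   % hypothetical MiniBatchSize value

numBatches = ceil(numDocuments/miniBatchSize);
batchSizes = zeros(1,numBatches);
for i = 1:numBatches
    % Range of document indices in mini-batch i; the last batch can be smaller.
    idxStart = (i-1)*miniBatchSize + 1;
    idxEnd = min(i*miniBatchSize,numDocuments);
    batchSizes(i) = idxEnd - idxStart + 1;
end
disp(batchSizes)   % displays 4 4 2
```

    A larger mini-batch size means fewer, larger batches, which is why it trades memory for speed.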

    Acceleration

    Performance optimization, specified as one of these values:

    • "auto" — Automatically apply a number of optimizations that are suitable for the input network and hardware resources.

    • "mex" — Compile and execute a MEX function. This option is available only when you use a GPU. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.

    • "none" — Disable all acceleration.

    When you use the "auto" or "mex" option, the software can offer performance benefits at the expense of an increased initial run time. Subsequent calls to the function are typically faster. Use performance optimization when you call the function multiple times using different input data.

    When Acceleration is "mex", the software generates and executes a MEX function based on the model and parameters you specify in the function call. A single model can have several associated MEX functions at one time. Clearing the model variable also clears any MEX functions associated with that model.

    When Acceleration is "auto", the software does not generate a MEX function.

    To use the "mex" option, you must have a GPU, a C/C++ compiler, and the GPU Coder™ Interface for Deep Learning support package. Install the support package using the Add-On Explorer in MATLAB®. For setup instructions, see MEX Setup (GPU Coder). GPU Coder itself is not required.

    MATLAB Compiler™ software does not support compiling models when you use the "mex" option.
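    To see the effect of Acceleration="auto", you can time two calls that use different input data. The following sketch assumes Text Analytics Toolbox with the pretrained BERT-Base model installed. The class names and documents are hypothetical, and because the classification head is not trained, the predicted labels are arbitrary. Only the relative timings matter, and they depend on your hardware, so no output is shown.

```matlab
% Untrained classification head: predictions are arbitrary, only timings matter.
mdl = bertDocumentClassifier(ClassNames=["Leak" "Electronic Failure"]);

% Hypothetical documents; the two calls use different input data.
documentsA = ["Coolant is leaking from the pipe." "The fuse blew again."];
documentsB = ["Water is dripping under the mixer." "The power cut out."];

tic
YA = classify(mdl,documentsA,Acceleration="auto");  % first call includes optimization overhead
tFirst = toc;

tic
YB = classify(mdl,documentsB,Acceleration="auto");  % later call can reuse the optimizations
tSecond = toc;

fprintf("First call: %.3f s, second call: %.3f s\n",tFirst,tSecond)
```

    With very small inputs like these, the difference may be small; it is typically more visible when you classify many batches in a loop.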

    ExecutionEnvironment

    Hardware resource, specified as one of these values:

    • "auto" — Use a GPU if one is available. Otherwise, use the CPU.

    • "gpu" — Use the GPU. Using a GPU requires a Parallel Computing Toolbox license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.

    • "cpu" — Use the CPU.
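    If you want to choose the hardware resource explicitly, for example to record which one was used, you can check for a usable GPU with the canUseGPU function. This sketch assumes that the hardware-resource argument of classify is named ExecutionEnvironment, like the corresponding training option. The variable name executionEnvironment is illustrative.

```matlab
% canUseGPU returns true only when Parallel Computing Toolbox
% and a supported GPU device are both available.
if canUseGPU
    executionEnvironment = "gpu";
else
    executionEnvironment = "cpu";
end
disp(executionEnvironment)
```

    You can then call classify(mdl,documents,ExecutionEnvironment=executionEnvironment). The "auto" option already applies the same fallback, so an explicit check is mainly useful when later code depends on which resource was chosen.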

    Output Arguments

    Y

    Predicted classes, returned as a categorical array.
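    Because the predictions are categorical, base MATLAB functions for categorical arrays apply directly. In this sketch, Y is a hypothetical vector of predictions over the four factory-report classes. countcats counts the predictions per class, including classes that were never predicted, and string converts the labels, for example for export to a file.

```matlab
% Hypothetical predictions standing in for the output of classify.
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];
Y = categorical(["Leak" "Leak" "Mechanical Failure" "Leak"],classNames);

% Number of predictions per class, in the order given by categories(Y).
counts = countcats(Y);

% Plain string labels, for example to write to a table or text file.
labels = string(Y);
```

    Here counts holds 0, 3, 1, and 0 for the four classes in order.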

    Version History

    Introduced in R2023b