
Train BERT Document Classifier

Since R2023b

This example shows how to train a BERT neural network for document classification.

A Bidirectional Encoder Representations from Transformers (BERT) model is a transformer neural network that can be fine-tuned for natural language processing tasks such as document classification and sentiment analysis. The network uses attention layers to analyze text in context and capture long-range dependencies between words.

This example fine-tunes a pretrained BERT-Base neural network to predict the category of factory reports using text descriptions.

Load Training Data

Read the training data from the factoryReports CSV file. The file contains factory reports, including a text description and a categorical label for each report.

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Convert the labels in the Category column of the table to categorical values and view the distribution of the classes in the data using a histogram.

data.Category = categorical(data.Category);
figure
histogram(data.Category)
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")
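The histogram shows the class balance visually. To get the exact number of reports per class, you can use countcats, which returns one count per category in the order given by categories. This is a minimal sketch that uses a small, made-up label array in place of data.Category so that it runs on its own. With the real data, pass data.Category instead.

```matlab
% Hypothetical labels standing in for data.Category (not the real factory data)
labels = categorical(["Leak"; "Mechanical Failure"; "Leak"; ...
    "Electronic Failure"; "Mechanical Failure"; "Leak"]);

% countcats returns one count per category, in the order of categories(labels)
classNames = categories(labels);
counts = countcats(labels);

% Display the counts next to the class names
classCounts = table(classNames,counts,VariableNames=["Class" "Count"])
```

For these labels, the categories are sorted alphabetically, so the counts are 1, 3, and 2 for "Electronic Failure", "Leak", and "Mechanical Failure", respectively.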

View the number of classes.

classNames = categories(data.Category);
numClasses = numel(classNames)
numClasses = 4

Partition the data into a training set and a test set. Specify the holdout percentage as 10%.

cvp = cvpartition(data.Category,Holdout=0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);
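Because the labels are the first input to cvpartition, the holdout split is stratified: each class appears in the test set in roughly the same proportion as in the full data set. The split is also random, so without a fixed seed the partition, and therefore the test results, can vary between runs. This sketch checks the stratification on synthetic labels (a hypothetical 80/20 split between two made-up classes, "A" and "B") and fixes the seed with rng so the partition is repeatable. The cvpartition function requires Statistics and Machine Learning Toolbox.

```matlab
% Fix the random seed so the partition is the same on every run
rng(0)

% Synthetic, hypothetical labels: 80 of class "A" and 20 of class "B"
labels = categorical(repelem(["A"; "B"],[80; 20]));

% Passing the labels as the first input stratifies the holdout split by class
cvp = cvpartition(labels,Holdout=0.1);
labelsTrain = labels(training(cvp));
labelsTest = labels(test(cvp));

% Compare the fraction of each class in the training and test sets
propTrain = countcats(labelsTrain) / numel(labelsTrain)
propTest = countcats(labelsTest) / numel(labelsTest)
```

Both proportion vectors should be close to [0.8; 0.2]. Calling rng before cvpartition in the main example makes its partition repeatable in the same way.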

Extract the text data and labels from the tables.

textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
TTrain = dataTrain.Category;
TTest = dataTest.Category;

Load Pretrained BERT Document Classifier

Load a pretrained BERT-Base document classifier using the bertDocumentClassifier function. If the Text Analytics Toolbox™ Model for BERT-Base Network support package is not installed, then the function provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.

mdl = bertDocumentClassifier(ClassNames=classNames)
mdl = 
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]

Specify Training Options

Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.

  • Train using the Adam optimizer.

  • Train for eight epochs.

  • For fine-tuning, use a learning rate lower than the Adam default of 0.001. Train using a learning rate of 0.0001.

  • Shuffle the data every epoch.

  • Monitor the training progress in a plot and monitor the accuracy metric.

  • Disable the verbose output.

options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...  
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Train Neural Network

Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available; otherwise, it uses the CPU. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.
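For example, to train on the CPU even when a GPU is available, you can add ExecutionEnvironment="cpu" to the training options. The sketch below recreates the options from the Specify Training Options section with that one extra setting; other accepted values include "auto" (the default) and "gpu". It also calls canUseGPU, which reports whether a supported GPU and Parallel Computing Toolbox are available in the current session.

```matlab
% Same options as before, plus ExecutionEnvironment to force CPU training
optionsCPU = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false, ...
    ExecutionEnvironment="cpu");

% Check whether GPU training would have been possible in this session
gpuAvailable = canUseGPU()
```

To use these options, pass optionsCPU instead of options to trainBERTDocumentClassifier. Expect CPU training of a BERT-Base model to take noticeably longer than GPU training.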

mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

Test Neural Network

Make predictions using the test data.

YTest = classify(mdl,textDataTest);

Visualize the predictions in a confusion matrix.

figure
confusionchart(TTest,YTest)

Calculate the classification accuracy of the test predictions.

accuracy = mean(TTest == YTest)
accuracy = 0.9375
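Overall accuracy can hide weak performance on a less frequent class. As a sketch, you can compute per-class recall (the fraction of each true class that is predicted correctly) and precision (the fraction of predictions for a class that are correct) from the output of confusionmat, which requires Statistics and Machine Learning Toolbox. The labels below are small, made-up arrays standing in for TTest and YTest; with real results, pass TTest and YTest directly.

```matlab
% Hypothetical true and predicted labels standing in for TTest and YTest
TTrue = categorical(["Leak"; "Leak"; "Mechanical Failure"; "Electronic Failure"]);
YPred = categorical(["Leak"; "Mechanical Failure"; "Mechanical Failure"; "Electronic Failure"]);

% Overall accuracy: fraction of matching labels (3 of 4 here)
accuracy = mean(TTrue == YPred)

% Confusion matrix: rows are true classes, columns are predicted classes
[C,order] = confusionmat(TTrue,YPred);

% Per-class recall (correct / row total) and precision (correct / column total)
recall = diag(C) ./ sum(C,2);
precision = diag(C) ./ sum(C,1)';
metrics = table(order,recall,precision)
```

Here, one "Leak" report is misclassified as "Mechanical Failure", so the recall for "Leak" and the precision for "Mechanical Failure" are both 0.5, while the other values are 1.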

Make Predictions Using New Data

Classify the event type of new factory reports. Create a string array containing the new factory reports.

strNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
labelsNew = classify(mdl,strNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 


Related Topics