
Out-of-Distribution Detection for BERT Document Classifier

Since R2024b

This example shows how to detect out-of-distribution (OOD) data in a BERT document classifier.

OOD data detection is the process of identifying inputs to a deep neural network that might yield unreliable predictions. OOD data refers to data that is different from the data used to train the model, for example, data collected in a different way, under different conditions, or for a different task than the data on which the model was originally trained.

You can classify data as in-distribution (ID) or OOD by assigning confidence scores to the predictions of a network. You can then choose how you treat OOD data. For example, you can choose to reject the prediction of a neural network if the network detects OOD data.

In this example, you fine-tune a pretrained BERT classification model to predict the type of maintenance work done on traffic signals using text descriptions. You then construct a discriminator to classify the text descriptions as ID or OOD.

The workflow consists of five steps:

  1. Import and preprocess the data.

  2. Separate the ID and OOD data.

  3. Fine-tune a pretrained BERT model using the ID data.

  4. Create a BERT mini-batch queue.

  5. Construct and calibrate a distribution discriminator and compare the distribution scores of the ID and OOD data.

Import and Preprocess Data

This example uses a large data set that contains records of work completed by traffic signal technicians in the city of Austin, TX, United States [1]. This data set is a table containing approximately 36,000 reports with various attributes, including a plain text description in the JobDescription variable and a categorical label in the WorkNeeded variable.

Load the example data.

zipFile = matlab.internal.examples.downloadSupportFile("textanalytics","data/Traffic_Signal_Work_Orders.zip");
filepath = fileparts(zipFile);
dataFolder = fullfile(filepath,"Traffic_Signal_Work_Orders");
unzip(zipFile,dataFolder);
filename = "Traffic_Signal_Work_Orders.csv";
data = readtable(fullfile(dataFolder,filename),TextType="string", VariableNamingRule="preserve");
data.Properties.VariableNames = matlab.lang.makeValidName(data.Properties.VariableNames);
head(data)
     WorkOrderID       Status        AssetType        AssetID      LocationID               CreatedDate                       ModifiedDate                     SubmittedDate                       ClosedDate              FiscalYear        WorkType                                  WorkNeeded                                                           WorkTypeOther                                       WorkRequestedBy                                          JobDescription                                                                                              ProblemFound                                                                                                                                                                                 ActionTaken                                                                                                                        Follow_UpNeeded    ChildWorkOrder    ParentWorkOrder    IsFollow_Up      TMCIssueID      ServiceRequest_    DamageReport                             LocationName                             Latitude    Longitude               Location           
    ______________    ________    ________________    _______    ______________    ______________________________    ______________________________    ______________________________    ______________________________    __________    ________________    _______________________________________________________________    _____________________________________________________________________    _____________________________    _____________________________________________________________________________    ________________________________________________________________________________________________________________________________    __________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________    _______________    ______________    _______________    ___________    ______________    _______________    ____________    ______________________________________________________________    ________    _________    ______________________________

    "WRK17-001685"    "Closed"    "School Flasher"      NaN      <missing>         "08/19/2017 08:55:00 PM +0000"    "09/14/2017 06:27:00 PM +0000"    "08/19/2017 09:00:00 PM +0000"    "09/14/2017 06:27:00 PM +0000"       2017       "Scheduled Work"    "Call-Back (Test Monitors and Cabinets)"                           <missing>                                                                "Austin Transportation Staff"    "HAVE AUSTIN ENERGY TIE IN NEW SOURCE DROP OVERHEAD @ CIMA SERENA WB FLASHER"    "N/A."                                                                                                                              "AUSTIN ENERGY TECHNICIANS DISPATCHED TO LOCATION. AE TECHS COULD NOT DO WORK BECAUSE OF LACK OF METER ON POLE/SOURCE. AE TECHS SAID TO CONTACT "WORK MANAGMENT NORTH" 5125057179 FOR FURTHER ACTION. INFORMATION WILL BE RELAYED TO SUPERVISOR. "        "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      <missing>                                                           NaN          NaN       <missing>                     
    "WRK17-001865"    "Closed"    "Signal"              317      "LOC16-001550"    "08/24/2017 03:28:00 PM +0000"    "09/14/2017 06:42:00 PM +0000"    "08/24/2017 03:56:00 PM +0000"    "09/14/2017 06:42:00 PM +0000"       2017       "Scheduled Work"    "Installation - Other"                                             <missing>                                                                "Austin Transportation Staff"    <missing>                                                                        "bad cable for nb in the conduits"                                                                                                  "pulled in 20 conductor cable for nb signals and peds . installed a new 332 cabinet , respliced all signals and peds for 2 way project ."                                                                                                                 "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      "5TH ST / TRINITY ST"                                               NaN          NaN       "POINT (-97.739677 30.266132)"
    "WRK17-001875"    "Closed"    "Signal"              319      "LOC16-001560"    "08/24/2017 03:45:00 PM +0000"    "09/14/2017 06:54:00 PM +0000"    "08/24/2017 04:03:00 PM +0000"    "09/14/2017 06:54:00 PM +0000"       2017       "Scheduled Work"    "Installation - Other"                                             <missing>                                                                "MMC"                            "install wb standard and splice in signals and peds"                             <missing>                                                                                                                           "install wb mast arm, remove street light pole, splice signal cables and peds"                                                                                                                                                                            "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      "5TH ST / RED RIVER ST"                                             NaN          NaN       "POINT (-97.737488 30.265535)"
    "WRK17-001890"    "Closed"    "School Flasher"      NaN      <missing>         "08/24/2017 08:23:00 PM +0000"    "08/24/2017 08:31:00 PM +0000"    "08/24/2017 08:31:00 PM +0000"    "08/28/2017 03:08:00 PM +0000"       2017       "Trouble Call"      "OtherDay-Call (Deliver Timing sheets to intersections and PM)"    "SOMMERS ELEMENTARY - NOT FLASHING↵↵SR #17-00242843↵#17-00244051↵"    "Austin Transportation Staff"    "SOMMERS ELEMENTARY - NOT FLASHING↵SR #17-00242843, #17-00244051"               "NO PROBLEMS FOUND AT SCHOOL FLASHERS.  BOTH PEDESTRIAN FLASHERS NEED SCHEDULE."                                                    "BOTH SCHOOL CLOCKS CHECKED FOR TIME, DATE, SCHEDULE, FLASHERS OPERATION AND COMMUNICATION.↵BOTH PEDESTRIAN FLASHER CLOCKS CHECKED FOR TIME, DATE, SCHEDULE, OPERATION, AND COMM.↵TIME, DATE AND SCHEDULE UPDATED IN PEDESTRIAN FLASHER CLOCKS."         "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      <missing>                                                           NaN          NaN       <missing>                     
    "WRK17-003185"    "Closed"    "Signal"               25      "LOC16-000120"    "10/09/2017 07:46:00 PM +0000"    "01/23/2023 04:47:00 PM +0000"    "10/09/2017 07:49:00 PM +0000"    "10/10/2017 04:45:00 PM +0000"       2018       "Scheduled Work"    "Installation - Camera"                                            <missing>                                                                "MMC"                            "replace the avidia cctv with a pelco repaired unit"                             <missing>                                                                                                                           "replaced the avidia cctv with a repaired pelco task # 2423015000"                                                                                                                                                                                        "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      "MARTIN LUTHER KING JR BLVD / CONGRESS AVE (MLK/Capitol Mall)"      NaN          NaN       "POINT (-97.738106 30.280687)"
    "WRK17-003430"    "Closed"    "Signal"              185      "LOC16-000915"    "10/18/2017 08:43:00 PM +0000"    "10/26/2017 07:30:00 PM +0000"    "10/18/2017 08:49:00 PM +0000"    "10/26/2017 07:30:00 PM +0000"       2018       "Trouble Call"      "Visibility Issue"                                                 <missing>                                                                "MMC"                            "Tree limbs blocking WB signal direction."                                       "Tree limbs blocking WB signal direction."                                                                                          "Cut limbs blocking WB signal direction to make visible for ongoing traffic."                                                                                                                                                                             "True"           <missing>          <missing>        <missing>     "TMC17-006530"     "17-00311041"      <missing>      "LAMAR BLVD / PANTHER TRL"                                          NaN          NaN       "POINT (-97.789284 30.23867)" 
    "WRK17-001895"    "Closed"    "Signal"              NaN      <missing>         "08/24/2017 08:32:00 PM +0000"    "08/24/2017 08:40:00 PM +0000"    "08/24/2017 08:40:00 PM +0000"    "08/28/2017 03:06:00 PM +0000"       2017       "Trouble Call"      "OtherDay-Call (Deliver Timing sheets to intersections and PM)"    "DOSS/MURCHISON COMBO WB NOT FLASHING"                                   "Austin Transportation Staff"    "DOSS/MURCHISON COMBO WB NOT FLASHING"                                           "WB FLASHER ON GREYSTONE  DOES NOT HAVE COMMUNICATION. CLOCK HAD NO SCHEDULE.  EB FLASHER ON N HILLS DR. HAS LIMBS OBSTRUCTION."    "DATE, TIME, SCHEDULE, AND FLASHER OPERATION CHECKED FOR ALL CLCOKS. WB CLOCK ON GREYSTONE PROGRAMMED WITH 2017/2018 SCHEDULE. LIMBS REMOVED FROM EB FLASHER ON N HILLS DR."                                                                              "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      <missing>                                                           NaN          NaN       <missing>                     
    "WRK17-002010"    "Closed"    "Signal"              779      "LOC16-003835"    "08/29/2017 07:58:00 PM +0000"    "09/14/2017 07:04:00 PM +0000"    "08/30/2017 11:02:00 AM +0000"    "09/14/2017 07:04:00 PM +0000"       2017       "Trouble Call"      "Detection Failure"                                                <missing>                                                                "MMC"                            "fisheye camera turned"                                                          "gridsmart camera out of alignment"                                                                                                 "with assistance from the TMC - realigned camera and tightened"                                                                                                                                                                                           "False"          <missing>          <missing>        <missing>     <missing>          <missing>          <missing>      "MC KINNEY FALLS PKWY / WILLIAM CANNON DR"                          NaN          NaN       "POINT (-97.72583 30.163218)" 

The goal of this example is to classify maintenance visits by the label in the WorkNeeded column. To divide the data into classes, convert these labels to categorical.

data.WorkNeeded = categorical(data.WorkNeeded);

Remove observations that belong to rare categories by using the removeRareCategories function, defined at the end of this example.

data = removeRareCategories(data);

Remove observations with a missing job description.

data(ismissing(data.JobDescription),:) = [];
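The ismissing function flags string elements that are <missing>, which is how the empty text fields appear in the head output above. A string that is present but empty ("") is not flagged. The following sketch, using hypothetical strings, shows one way to also remove empty or whitespace-only descriptions if your data contains them.

```matlab
% Hypothetical job descriptions: one valid, one <missing>, one empty, one whitespace only
descriptions = ["Tree limbs blocking WB signal"; missing; ""; "   "];

% ismissing flags only the <missing> element
tfMissing = ismissing(descriptions);

% Also flag empty or whitespace-only text; strip removes leading and trailing whitespace
tfEmpty = strlength(strip(descriptions)) == 0;

% Remove descriptions that are missing or empty
tfRemove = tfMissing | tfEmpty;
descriptions(tfRemove) = [];
```

For a missing element, strlength returns NaN, so tfEmpty is false there and the ismissing check is still needed.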

Separate ID and OOD Data

The data set includes two types of work: scheduled work and trouble calls. Convert the WorkType variable to categorical and plot a histogram of the work types.

data.WorkType = categorical(data.WorkType);
figure
histogram(data.WorkType)

Figure: histogram of the number of reports for each WorkType category.

In this example, you train a document classifier on the JobDescription fields of the reports from work resulting from trouble calls. This data comprises the ID data. The reports resulting from scheduled work comprise the OOD data.

dataID = data(data.WorkType=="Trouble Call",:);
dataOOD = data(data.WorkType=="Scheduled Work",:);

Remove any categories that are no longer used from the ID and OOD data.

dataID.WorkNeeded = removecats(dataID.WorkNeeded);
dataOOD.WorkNeeded = removecats(dataOOD.WorkNeeded);
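Indexing into a categorical array keeps the full list of categories, including categories that no longer have any observations. The following sketch, using hypothetical labels, shows why the removecats calls matter here: classNames is later taken from the categories of the ID data, so unused categories would otherwise become classes that never appear in training.

```matlab
% Hypothetical labels from two work types
labels = categorical(["LED Out"; "Knockdown"; "Installation - Camera"]);
workType = ["Trouble Call"; "Trouble Call"; "Scheduled Work"];

% Select the trouble-call subset; all three categories are still listed
subset = labels(workType == "Trouble Call");
catsBefore = categories(subset);   % {'Installation - Camera'; 'Knockdown'; 'LED Out'}

% removecats with one input removes categories that have no observations
subset = removecats(subset);
catsAfter = categories(subset);    % {'Knockdown'; 'LED Out'}
```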

Compare the JobDescription fields of both ID and OOD data using word clouds.

figure
tiledlayout("horizontal")
nexttile
wordcloud(dataID.JobDescription);
title("In-distribution")
nexttile
wordcloud(dataOOD.JobDescription);
title("Out-of-distribution")

Figure: word clouds of the in-distribution and out-of-distribution job descriptions.

Prepare Data for Training

Next, partition the ID data into sets for training, validation, and testing. Partition the data into a training set containing 80% of the ID data, a validation set containing 10% of the ID data, and a test set containing the remaining 10% of the ID data. To partition the data, use the trainingPartitions function, attached to this example as a supporting file. To access this file, open the example as a live script.

numReports = size(dataID,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numReports,[0.8 0.1 0.1]); % attached to this example as a supporting file

dataTrain = dataID(idxTrain,:);
dataValidation = dataID(idxValidation,:);
dataTest = dataID(idxTest,:);

classNames = categories(dataID.WorkNeeded);
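The trainingPartitions supporting file is not listed here. As a rough idea of what such a function does, the following sketch randomly splits observation indices into training, validation, and test sets with an 80/10/10 ratio. It is a hypothetical stand-in, not the code of the supporting file, and the variable names are chosen so that they do not overwrite idxTrain, idxValidation, and idxTest.

```matlab
numObservations = 20;          % hypothetical number of reports
splits = [0.8 0.1 0.1];        % training, validation, and test fractions

rng(0)                         % make the shuffle reproducible
idx = randperm(numObservations);

% Round the training and validation sizes down; the remainder goes to the test set
numTrain = floor(splits(1)*numObservations);
numValidation = floor(splits(2)*numObservations);

trainIdx = idx(1:numTrain);
valIdx = idx(numTrain+1:numTrain+numValidation);
testIdx = idx(numTrain+numValidation+1:end);
```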

To avoid having two copies of the ID data in memory, remove dataID.

clear("dataID");

Extract the text data and labels from the partitioned tables and the OOD data.

documentsTrain = dataTrain.JobDescription;
documentsValidation = dataValidation.JobDescription;
documentsTest = dataTest.JobDescription;
documentsOOD = dataOOD.JobDescription;

YTrain = dataTrain.WorkNeeded;
YValidation = dataValidation.WorkNeeded;
YTest = dataTest.WorkNeeded;
YOOD = dataOOD.WorkNeeded;

Load Pretrained BERT Document Classifier

Load a pretrained BERT-Tiny document classifier using the bertDocumentClassifier function. If the Text Analytics Toolbox™ Model for BERT-Tiny Network support package is not installed, then the function provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.

mdl = bertDocumentClassifier(Model="tiny",ClassNames=classNames)
mdl = 
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Communication Failure"    "Detection Failure"    "Knockdown"    "LED Out"    "Push Button Not Working"    "Signal Out or on Flash"    "Timing Issue"    "Visibility Issue"]

Specify Training Options

Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager (Deep Learning Toolbox) app.

  • Train using the Adam optimizer.

  • Train for eight epochs.

  • For fine-tuning, lower the learning rate. Train using a learning rate of 0.0001.

  • Shuffle the data every epoch.

  • Monitor the training progress in a plot and monitor the accuracy metric.

  • Disable the verbose output.

options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    ValidationData={documentsValidation,YValidation}, ...
    Shuffle="every-epoch", ...  
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Train Neural Network

Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainBERTDocumentClassifier function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

mdl = trainBERTDocumentClassifier(documentsTrain,YTrain,mdl,options);

Test Neural Network

Make predictions using the test data.

YPred = classify(mdl,documentsTest);

Compare the true and predicted labels.

figure
confusionchart(YTest,YPred)

Figure: confusion chart of the true and predicted test labels.
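To summarize the confusion chart in a single number, you can compute the classification accuracy as the fraction of test observations whose predicted label matches the true label. The sketch below uses hypothetical labels; with the variables in this example, and assuming YPred and YTest have the same orientation, the same expression is mean(YPred == YTest).

```matlab
% Hypothetical true and predicted labels with the same set of categories
YTrue = categorical(["LED Out"; "Knockdown"; "LED Out"; "Timing Issue"]);
YPredicted = categorical(["LED Out"; "Knockdown"; "Timing Issue"; "Timing Issue"]);

% Accuracy is the fraction of matching labels
accuracy = mean(YPredicted == YTrue);   % 0.75
```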

Detect OOD Data

You can assign confidence scores to network predictions by computing a distribution confidence score for each observation. ID data usually has a higher confidence score than OOD data. You can then apply a threshold to the scores to determine whether an input is ID or OOD.
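As a minimal illustration of thresholding, the following sketch marks observations whose confidence score is at or above a threshold as ID and rejects the predictions for the rest. The scores, threshold, and predictions are hypothetical, and whether a score exactly equal to the threshold counts as ID is an assumption; the discriminator created later in this example determines its own threshold.

```matlab
% Hypothetical distribution confidence scores and predicted labels
scores = [4.2 0.7 3.5 1.9];
predictions = ["LED Out" "Knockdown" "Timing Issue" "Detection Failure"];
threshold = 2.5;              % hypothetical threshold

% Assumption: a score at or above the threshold counts as in-distribution
isID = scores >= threshold;   % [true false true false]

% Reject (mark as missing) the predictions for observations flagged as OOD
predictions(~isID) = missing;
```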

Create a discriminator that separates ID and OOD data by using the networkDistributionDiscriminator (Deep Learning Toolbox) function. The function returns a discriminator containing a threshold for separating data into ID and OOD using their distribution scores.

The networkDistributionDiscriminator function requires four input arguments:

  1. A dlnetwork object.

  2. ID data, for example, in the form of a dlarray object or a minibatchqueue object.

  3. OOD data in the same form, or [] if you do not have OOD data.

  4. The method used to compute the distribution confidence scores, specified as "baseline", "odin", "energy", or "hbos". Depending on the method, the function returns a BaselineDistributionDiscriminator (Deep Learning Toolbox), ODINDistributionDiscriminator (Deep Learning Toolbox), EnergyDistributionDiscriminator (Deep Learning Toolbox), or HBOSDistributionDiscriminator (Deep Learning Toolbox) object.

Create BERT Mini-Batch Queue

First, create a mini-batch queue for BERT using the bertMiniBatchQueue function defined in this example.

miniBatchSize = 128;
mbqTrain = bertMiniBatchQueue(mdl,documentsTrain,miniBatchSize);
mbqValidation = bertMiniBatchQueue(mdl,documentsValidation,miniBatchSize);
mbqTest = bertMiniBatchQueue(mdl,documentsTest,miniBatchSize);
mbqOOD = bertMiniBatchQueue(mdl,documentsOOD,miniBatchSize);

The bertMiniBatchQueue function encodes the documents using the tokenizer of the BERT model and stores the resulting token codes and segment indices in datastores. It then creates a minibatchqueue object that preprocesses each mini-batch using the preprocessPredictors function. The preprocessPredictors function is attached to this example as a supporting file. It truncates and pads the sequences to the same length, equal to the context size of the BERT tokenizer, and ensures that the sequences always end with an end-of-sentence token.

function mbq = bertMiniBatchQueue(mdl,documents,miniBatchSize)
% Tokenizer of the BERT model
tokenizer = mdl.Tokenizer;

% Encode the documents as token codes and segment indices
[inputID, segmentID] = encode(tokenizer, documents);

% Store the encoded data in datastores and combine them
inputIDDS = arrayDatastore(inputID, OutputType="same");
segmentIDDS = arrayDatastore(segmentID, OutputType="same");
combinedDS = combine(inputIDDS, segmentIDDS);

% Three outputs per mini-batch: inputID, mask, segmentID.
% preprocessPredictors (supporting file) pads and truncates each mini-batch.
mbq = minibatchqueue(combinedDS,3,...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(inputID,segmentID) preprocessPredictors(inputID,segmentID,tokenizer), ...
    MiniBatchFormat=["CTB" "CTB" "CTB"], ...
    OutputEnvironment="auto");
end
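The preprocessPredictors supporting file is not listed here. The following sketch illustrates the kind of truncation and padding it is described as performing, applied to hypothetical token codes. The context size, padding code, and separator code are made-up values, not properties read from the BERT tokenizer, and the sketch is not the supporting file's implementation. It returns the token codes and an attention mask as 1-by-contextSize-by-batch arrays, matching the "CTB" layout used by the mini-batch queue.

```matlab
% Hypothetical token codes for three documents; 102 and 103 stand in for start and separator codes
seqs = {[102 7 8 9 103], [102 11 103], [102 21 22 23 24 25 26 103]};
contextSize = 6;     % hypothetical context size
paddingCode = 1;     % hypothetical padding code
sepCode = 103;       % hypothetical separator (end) code

numObs = numel(seqs);
X = paddingCode*ones(1,contextSize,numObs);   % token codes, C-by-T-by-B
mask = zeros(1,contextSize,numObs);           % 1 for real tokens, 0 for padding

for n = 1:numObs
    s = seqs{n};
    if numel(s) > contextSize
        % Truncate, keeping the separator code as the final token
        s = [s(1:contextSize-1) sepCode];
    end
    X(1,1:numel(s),n) = s;
    mask(1,1:numel(s),n) = 1;
end
```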

Compute Confidence Scores to Detect OOD Data

Extract the dlnetwork object from the BERT model mdl to pass to the networkDistributionDiscriminator function.

net = mdl.Network;

Calibrate Discriminator

Create a distribution discriminator using the energy OOD discrimination algorithm. The energy method computes distribution confidence scores from the network logits, that is, the inputs to the softmax operation, scaled by a temperature. For more information, see Distribution Confidence Scores (Deep Learning Toolbox). Set the Temperature name-value argument to 1.

discriminator = networkDistributionDiscriminator(net,mbqTrain,mbqOOD,"energy",Temperature=1);
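For intuition about the energy method, the following sketch computes energy-based confidence scores directly from a matrix of hypothetical logits, using the usual definition of the negative energy: T*log(sum(exp(z/T))), where z contains the logits of one observation and T is the temperature. Observations with one dominant logit receive higher scores. This illustrates the formula only; it is not the implementation inside networkDistributionDiscriminator.

```matlab
% Hypothetical logits: rows are observations, columns are classes
Z = [4.0  0.5 0.2 0.1;    % one clearly dominant class
     1.0  0.9 1.1 0.8;    % no dominant class
     6.0 -1.0 0.0 0.5];   % strongly dominant class
T = 1;                    % temperature, as in Temperature=1 above

% Confidence score = T*log(sum(exp(Z/T))), the negative energy.
% Subtract the row maximum before exponentiating for numerical stability.
m = max(Z,[],2);
scores = m + T*log(sum(exp((Z - m)/T),2));   % approximately [4.07; 2.34; 6.01]
```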

To check whether the discriminator is well calibrated, calculate the distribution confidence scores of the training and validation data by passing the discriminator object to the distributionScores (Deep Learning Toolbox) function. Plot a histogram of the distribution scores using the plotDistributionScores function, defined at the end of this example.

If the distribution discriminator is well calibrated, then the histograms of the two data sets are similar. If the histograms do not look similar, increase or decrease the value of the Temperature hyperparameter.

scoresTrain = distributionScores(discriminator,mbqTrain);
scoresValidation = distributionScores(discriminator,mbqValidation);
figure
plotDistributionScores(discriminator,scoresTrain,scoresValidation,"Training Data","Validation Data")

Figure: histograms of the distribution confidence scores of the training and validation data, with the discriminator threshold.
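To complement the visual comparison, you can compare simple summaries of the two score distributions, such as the fraction of scores at or above the threshold and the median score. For a well-calibrated discriminator, these summaries are expected to be close for training and validation data. The scores and threshold below are hypothetical, and this check is a suggested heuristic rather than a documented calibration criterion.

```matlab
% Hypothetical distribution confidence scores and threshold
scoresTrainSketch = [5.0 4.6 3.8 4.1 2.2 4.9 3.3 4.4];
scoresValidationSketch = [4.8 3.9 2.4 4.3 4.0 3.6 1.9 4.7];
threshold = 3.0;

% Fraction of each set at or above the threshold
fracTrain = mean(scoresTrainSketch >= threshold);            % 0.875
fracValidation = mean(scoresValidationSketch >= threshold);  % 0.75

% Median scores give a second, threshold-free comparison
medianGap = abs(median(scoresTrainSketch) - median(scoresValidationSketch));   % 0.3
```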

Detect OOD Data

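Once you are satisfied with your distribution discriminator, pass the discriminator object to the isInNetworkDistribution (Deep Learning Toolbox) function along with the test data, which is ID data, and the OOD data. To assess the performance of the discriminator on this set of OOD data, calculate the true positive rate (TPR) and false positive rate (FPR), treating ID as the positive class. The TPR is the fraction of ID test observations that the discriminator classifies as ID, and the FPR is the fraction of OOD observations that it incorrectly classifies as ID.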

tfOOD = isInNetworkDistribution(discriminator,mbqOOD);
tfID = isInNetworkDistribution(discriminator,mbqTest);
tpr = nnz(tfID)/numel(tfID)
tpr = 
0.7906
fpr = nnz(tfOOD)/numel(tfOOD)
fpr = 
0.0355
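The TPR and FPR above depend on the single threshold stored in the discriminator. To see how they trade off as the threshold changes, you can sweep the threshold over the observed scores and compute the area under the resulting ROC curve (AUROC), a threshold-independent summary. The sketch below uses hypothetical scores, assumes that scores at or above the threshold count as ID, and uses new variable names so that it does not overwrite tpr, fpr, scoresID, or scoresOOD.

```matlab
% Hypothetical confidence scores for ID and OOD observations
idScores  = [5.1 4.7 3.9 4.4 2.8];
oodScores = [1.2 2.9 0.8 3.1 1.5];

% Use every observed score as a candidate threshold, from highest to lowest
thresholds = sort([idScores oodScores],"descend");
tprCurve = zeros(size(thresholds));
fprCurve = zeros(size(thresholds));
for k = 1:numel(thresholds)
    tprCurve(k) = mean(idScores >= thresholds(k));   % ID correctly kept
    fprCurve(k) = mean(oodScores >= thresholds(k));  % OOD wrongly kept
end

% Add the (0,0) end point and integrate with the trapezoidal rule;
% the lowest threshold already gives the (1,1) end point
auroc = trapz([0 fprCurve],[0 tprCurve]);   % 0.92
```

An AUROC of 1 means that some threshold separates the two sets perfectly, and 0.5 corresponds to chance.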

To compare the distribution scores of the ID and OOD data, pass the discriminator object to the distributionScores (Deep Learning Toolbox) function. Plot a histogram of the scores, together with the distribution threshold stored in the Threshold property of the discriminator, using the plotDistributionScores function, defined at the end of this example.

scoresID = distributionScores(discriminator,mbqTest);
scoresOOD = distributionScores(discriminator,mbqOOD);
figure
plotDistributionScores(discriminator,scoresID,scoresOOD,"In-distribution scores","Out-of-distribution scores")

Figure: histograms of the in-distribution and out-of-distribution confidence scores, with the discriminator threshold.

Helper Functions

The plotDistributionScores function takes as input a distribution discriminator object and distribution confidence scores for ID and OOD data. The function plots a histogram of the two confidence scores and overlays the distribution threshold.

function plotDistributionScores(discriminator,scoresID,scoresOOD,labelID,labelOOD)
hID = histogram(scoresID,Normalization="percentage");
hold on
hOOD = histogram(scoresOOD,Normalization="percentage");
xl = xlim;
hID.BinWidth = (xl(2)-xl(1))/25;
hOOD.BinWidth = (xl(2)-xl(1))/25;
xline(discriminator.Threshold)
l = legend([labelID labelOOD "Threshold"],Location="best");
title(l,discriminator.Method+" distribution discriminator")
xlabel("Distribution Confidence Scores")
ylabel("Frequency")
hold off
end

The removeRareCategories function removes observations that belong to WorkNeeded categories with 500 or fewer observations, as well as miscellaneous categories whose names contain "Other", which do not share many features.

function commonData = removeRareCategories(data)
workNeededCategories = categories(data.WorkNeeded);
categoryFrequencies = countcats(data.WorkNeeded);

commonCategories = workNeededCategories(categoryFrequencies>500);
commonData = data(ismember(data.WorkNeeded,commonCategories),:);

otherCategories = commonCategories(contains(commonCategories,"Other"));
commonData = commonData(~ismember(commonData.WorkNeeded,otherCategories),:);

commonData.WorkNeeded = removecats(commonData.WorkNeeded);
end
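To see how the two filtering steps of removeRareCategories interact, the following sketch applies the same logic to hypothetical toy labels, with the frequency threshold lowered from 500 to 2 so that the effect is visible on a small array.

```matlab
% Hypothetical toy labels
labels = categorical(["A";"A";"A";"B";"Other - Misc";"Other - Misc";"Other - Misc";"C"]);
cats = categories(labels);
counts = countcats(labels);

% Step 1: keep frequent categories (threshold lowered to 2 for illustration)
common = cats(counts > 2);                    % {'A'; 'Other - Misc'}

% Step 2: drop catch-all categories whose names contain "Other"
common = common(~contains(common,"Other"));   % {'A'}

% Keep matching observations and discard the now-unused categories
keep = ismember(labels,common);
filtered = removecats(labels(keep));
```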

References

[1] Traffic Signal Work Orders. City of Austin Open Data. Retrieved April 30, 2023, from https://data.austintexas.gov/Transportation-and-Mobility/Traffic-Signal-Work-Orders/hst3-hxcz.
