
Detect Defects Using Tiled Training of EfficientAD Anomaly Detector

Since R2024b

This example shows how to detect and localize industrial production defects on chewing gum images by training an EfficientAD anomaly detection network on tiled images.

The EfficientAD model [1] is a one-class anomaly detector that uses features extracted from a lightweight convolutional neural network (CNN) to distinguish normal and anomalous images in real time. You train the model using only normal (non-anomalous) images. To detect local anomalies, the detector uses a student-teacher model, where the teacher is a pretrained deep neural network that encodes normal data behavior, and the student is a smaller network that learns to predict the features that the teacher outputs for normal images. At inference time, both the student and the teacher evaluate the image data. Because the student has learned to reproduce the teacher features only for normal images, it fails to predict the teacher features for anomalous regions, and the model detects anomalies from the resulting discrepancy between the student and teacher outputs. In addition to structural anomalies, EfficientAD can detect logical anomalies, such as those that result from incorrect object orientation, by using an autoencoder that is tuned to detect global anomalies. The high accuracy, low latency, and high throughput of EfficientAD enable you to detect and localize defects in large-scale applications that require a real-time or near-real-time response.
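As a rough illustration of the student-teacher principle, the following minimal sketch computes a per-pixel anomaly map as the mean squared difference between two feature maps. The arrays teacherFeatures and studentFeatures are synthetic stand-ins, not outputs of the actual EfficientAD networks, and the squared-difference map is a simplification of the scoring described in [1].

```matlab
% Synthetic H-by-W-by-C feature maps (hypothetical stand-ins for network outputs)
teacherFeatures = zeros(4,4,3);
studentFeatures = zeros(4,4,3);

% Simulate one location where the student fails to predict the teacher features
teacherFeatures(2,3,:) = [1 2 2];

% Per-pixel discrepancy: mean squared difference across the channel dimension
localMap = mean((teacherFeatures - studentFeatures).^2,3);

% Image-level score taken as the largest per-pixel discrepancy
[imageScore,idx] = max(localMap(:));
[row,col] = ind2sub(size(localMap),idx);
fprintf("Score %.2f at pixel (%d,%d)\n",imageScore,row,col)
```

Every location where the two feature maps agree contributes zero to the map, so only the mismatched location stands out.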

In this example, you train an EfficientAD network on tiled normal (non-anomalous) images and perform inference on full-size images. During training, you split the images in the data set into full-resolution segments, or tiles. This process preserves the full spatial resolution in the trained detector and is ideal for the detection of small local defects when the autoencoder is off. Evaluate the results using anomaly detection metrics and visual inspection of anomaly maps.

Download Pretrained EfficientAD Detector

By default, this example downloads a pretrained version of the EfficientAD anomaly detector using the helper function downloadTrainedNetwork. The function is attached to this example as a supporting file. You can use the pretrained network to run the entire example without waiting for training to complete. The fully convolutional architecture of EfficientAD enables tiled training and full-size inference without the need for multiple inference calls per image.

trainedGumDefectDetectorNet_url = "https://ssd.mathworks.com/supportfiles/" + ...
     "vision/data/trainedGumDefectDetectorEfficientAD.zip";
downloadTrainedNetwork(trainedGumDefectDetectorNet_url,pwd)
load("trainedGumDefectDetectorEfficientAD.mat")

Download VisA Data Set

Load the Visual Anomaly (VisA) data set of 10,821 high-resolution color images (9,621 normal and 1,200 anomalous samples) covering 12 different object subsets in 3 domains [2]. Several data subsets contain industrial production images of chewing gum, cashews, fryums, and other food. The anomalous images in the chewing gum test set contain surface defects, such as scratches, dents, and color spots, as well as structural defects, such as broken edges.

This example uses the chewing gum data subset. This data subset contains train and test folders, which include the normal training images, and the normal and anomalous test images, respectively.

Specify dataDir as the location of the data set. Download the data set using the downloadVisAData helper function. This function, which is attached to the example as a supporting file, downloads a ZIP file and extracts the data.

dataDir = fullfile(tempdir,"VisA");
downloadVisAData(dataDir)

Localize Defects in Image

Read a sample anomalous image with a "bad" label from the data set.

sampleImage = imread(fullfile(dataDir,"VisA", ...
    "chewinggum","test","bad","000.JPG"));

Visualize the localization of defects by displaying the original chewing gum image with the predicted per-pixel anomaly score map overlaid. Use the anomalyMap function to generate the anomaly score heatmap for the sample image. Then overlay the heatmap on the image by using the anomalyMapOverlay function.

anomalyHeatMap = anomalyMap(detector,sampleImage);
heatMapImage = anomalyMapOverlay(sampleImage,anomalyHeatMap);
montage({sampleImage,heatMapImage})
title("Heatmap of Anomalous Image")

Prepare Data for Training

Create imageDatastore objects that hold the training and test sets, from the train and test folders of the downloaded VisA data set.

exts = {".jpg",".png",".tif"};
dsTrain = imageDatastore(fullfile(dataDir,"VisA","chewinggum","train"),IncludeSubfolders=true,LabelSource="foldernames");
summary(dsTrain.Labels)
453×1 categorical

     good             453 
     <undefined>        0 
normalLabelStr = "good";
dsTest = imageDatastore(fullfile(dataDir,"VisA","chewinggum","test"),IncludeSubfolders=true,LabelSource="foldernames");
summary(dsTest.Labels)
150×1 categorical

     bad              100 
     good              50 
     <undefined>        0 

Display images of a normal piece of chewing gum and a defective piece of chewing gum from the test data set.

badImage = find(dsTest.Labels=="bad",1);
badImage = read(subset(dsTest,badImage));
normalImage = find(dsTest.Labels=="good",1);
normalImage = read(subset(dsTest,normalImage));
montage({normalImage,badImage})
title("Test Chewing Gum Images Without (Left) and With (Right) Defects")

Partition Data into Calibration and Test Sets

Prior to partitioning the data, set the global random state to "default" to ensure reproducibility.

rng("default")

Use a calibration set to determine the threshold for the classifier. Using separate calibration and test sets avoids information leaking from the test set into the design of the classifier. The classifier labels images with anomaly scores above the threshold as anomalous.

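To establish a suitable threshold for the classifier, allocate 50% of the original test set as the calibration set dsCal. Because the split is performed per label, the calibration set keeps the class ratio of the original test set and contains 25 normal and 50 anomalous images. The remaining 25 normal and 50 anomalous images form the new test set.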

[dsCal,dsTest] = splitEachLabel(dsTest,0.5,"randomized");
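The per-label split can be sketched with plain arrays. This is an illustration of the idea only, not the implementation of splitEachLabel. The label counts come from the summary output of the test set, and the variable names are hypothetical.

```matlab
% Label counts of the original test set, taken from the summary output above
labels = [repmat("bad",100,1); repmat("good",50,1)];
fraction = 0.5;

rng("default")                                % reproducible shuffling
isCal = false(size(labels));
classNames = unique(labels);
for k = 1:numel(classNames)
    idx = find(labels == classNames(k));
    idx = idx(randperm(numel(idx)));          % shuffle within the class
    numCal = round(fraction*numel(idx));      % per-class calibration count
    isCal(idx(1:numCal)) = true;
end

% Class counts in the calibration and remaining test partitions
numCalBad   = nnz(isCal & labels == "bad");
numCalGood  = nnz(isCal & labels == "good");
numTestBad  = nnz(~isCal & labels == "bad");
numTestGood = nnz(~isCal & labels == "good");
```

Splitting each class separately keeps the 2:1 ratio of anomalous to normal images in both partitions.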

Prepare Tiled Data for Training

Split Data into Masked Tiles

Create a blockedImage object from the images in the training set. The blockedImage object represents an image as a collection of smaller tiles, or blocks, of size 256-by-256 pixels. For each block, create a logical mask in which true (or 1) pixels represent regions of interest that contain meaningful normal data. To create the masks, pass a handle to the maskObjectsInNormalImages helper function as the processing function input of the apply function. Save the blocked image masks, normalMasks, to the mask folder specified by trainNormalMaskDir.

trainNormalMaskDir = fullfile(dataDir,"EAD","chewinggum","train",normalLabelStr+"_masks");
blockSize = [256 256];
if ~isfolder(trainNormalMaskDir)
    bims = blockedImage(dsTrain.Files);

    % Find a global threshold from one normal image to segment
    % chewing gum images. Images in this data set are mostly uniformly
    % illuminated so the same threshold can be used for all images.
    im = gather(bims(1));
    T = graythresh(im);

    normalMasks = apply(bims,@(x)maskObjectsInNormalImages(x,T), ...
        BlockSize=blockSize,OutputLocation=trainNormalMaskDir, ...
        Adapter=images.blocked.PNGBlocks,UseParallel=false);

    save(fullfile(trainNormalMaskDir,normalLabelStr+"_masks.mat"),"normalMasks")

else
    load(fullfile(trainNormalMaskDir,normalLabelStr+"_masks.mat"),"normalMasks")
end

Select Training Tiles and Write Data to Disk

To speed up training, create an on-disk representation of the selected image tiles. Select tiles for training that contain meaningful normal image data of the object by using the selectBlockLocations function with the masks from the previous step. To help the detector learn the background features also present in anomalous images, specify the InclusionThreshold argument as 0.

Write image tiles to disk using the writeall function, and display sample tiles. If you have a Parallel Computing Toolbox™ license, specify the UseParallel argument as true to reduce processing time.

normalBlocksDir = fullfile(dataDir,"EAD","chewinggum","train",normalLabelStr);
if ~isfolder(normalBlocksDir)
    bims = blockedImage(dsTrain.Files);

    % Select blocks based on the masks computed in the previous step
    bls = selectBlockLocations(bims,Levels=1,BlockSize=blockSize, ...
        Mask=normalMasks,ExcludeIncompleteBlocks=true,BlockOffsets=blockSize, ...
        InclusionThreshold=0,UseParallel=false);
    bimds = blockedImageDatastore(bims,BlockLocationSet=bls);

    % Write all tiles to disk
    writeall(bimds,normalBlocksDir,UseParallel=false)

    % Display a few blocks
    bimds.ReadSize = 10;
    blocks = read(bimds);
    figure
    montage(blocks,BorderSize=5,BackgroundColor="b")
end
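The tile selection can be illustrated with plain array arithmetic. The following sketch is not the implementation of selectBlockLocations. The image size and the rectangular mask are hypothetical, and the sketch assumes that an inclusion threshold of 0 keeps every complete block that contains at least one mask pixel.

```matlab
imageSize = [1000 1500];     % hypothetical image size in pixels (rows, columns)
blockSize = [256 256];

% Keep only blocks that fit entirely inside the image
numBlocks = floor(imageSize./blockSize);

% Top-left corner (row, column) of each complete, non-overlapping block
[rowIdx,colIdx] = ndgrid(0:numBlocks(1)-1,0:numBlocks(2)-1);
blockOrigins = [rowIdx(:) colIdx(:)].*blockSize + 1;

% Synthetic mask: the object occupies a rectangle near the image center
mask = false(imageSize);
mask(300:700,400:1100) = true;

% Fraction of mask pixels inside each block
maskFraction = zeros(size(blockOrigins,1),1);
for k = 1:size(blockOrigins,1)
    r = blockOrigins(k,1):blockOrigins(k,1)+blockSize(1)-1;
    c = blockOrigins(k,2):blockOrigins(k,2)+blockSize(2)-1;
    maskFraction(k) = mean(mask(r,c),"all");
end

% Assumed rule for a threshold of 0: any overlap with the mask selects the block
selected = maskFraction > 0;
fprintf("%d of %d complete blocks selected\n",nnz(selected),numel(selected))
```

With a low threshold, blocks that are mostly background but touch the object are still selected, which is what exposes the detector to background features.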

Create Training, Test, and Calibration Datastores

Create an ImageDatastore object that contains the training tiles.

dsTrainTiles = imageDatastore(normalBlocksDir,IncludeSubfolders=true,FileExtensions=".png",LabelSource="foldernames");

Add one-hot encoded labels to the training data using the transform function. Specify the transform function using the addLabelData helper function.

dsTrainTiles = transform(dsTrainTiles,@addLabelData,IncludeInfo=true);

Similarly, add one-hot encoded labels to the calibration and test data using the transform function and the addLabelData helper function.

dsCal = transform(dsCal,@addLabelData,IncludeInfo=true);
dsTest = transform(dsTest,@addLabelData,IncludeInfo=true);
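The label that addLabelData attaches to each image is a single logical value, where true marks an anomalous image. A minimal sketch of that mapping on a synthetic label array (the labels and images here are placeholders, not data from the VisA set):

```matlab
% Synthetic folder-name labels, as an imageDatastore would produce them
labels = categorical(["good";"bad";"good";"bad"]);

% Any label other than "good" counts as anomalous
isAnomaly = labels ~= "good";

% Pair a placeholder image with its label, mirroring the {data,label} cell layout
sample = {zeros(2,2,"uint8"),isAnomaly(2)};
```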

Define EfficientAD Anomaly Detector Network Architecture

Create the EfficientAD anomaly detector using the efficientADAnomalyDetector object. Specify the UseGlobalAnomalyMap argument as false to turn off the autoencoder and use only the student-teacher model. This enables the model to generalize learned features from training on image tiles to the full-resolution image at inference time.

untrainedDetector = efficientADAnomalyDetector(Network="pdn-small",UseGlobalAnomalyMap=false);
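To see why a convolutional model trained on tiles can run on full-size images, consider a single convolution filter. The following sketch uses a synthetic image and an arbitrary kernel, not the EfficientAD network, and shows that filtering a tile gives the same response as the corresponding region of the filtered full image.

```matlab
fullImage = magic(8);              % synthetic stand-in for a full-size image
kernel = [0 1 0; 1 -4 1; 0 1 0];   % arbitrary 3-by-3 filter

% Filter the full image, keeping only outputs computed without padding
fullResponse = conv2(fullImage,kernel,'valid');

% Filter a 4-by-4 tile cut from the top-left corner of the same image
tile = fullImage(1:4,1:4);
tileResponse = conv2(tile,kernel,'valid');

% The tile response equals the matching region of the full-image response
tilesMatch = isequal(tileResponse,fullResponse(1:2,1:2));
```

Each output value depends only on a local neighborhood of the input, so the learned filters respond in the same way regardless of the size of the surrounding image. A global component such as the autoencoder does not share this property, which is consistent with turning it off for tiled training.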

Specify Training Options and Train Detector

To train the detector, set the doTraining variable to true.

Train the EfficientAD anomaly detector by using the trainEfficientADAnomalyDetector function. Specify the MapNormalizationDataRatio argument as 0.2 to compute percentile normalization statistics for the student model for 20% of the training data. For the rest of the training data, the autoencoder is turned off and the detector does not update its weights or perform normalization.

Specify network training options using the trainingOptions (Deep Learning Toolbox) function. Specify the Metrics argument as the aucMetric (Deep Learning Toolbox) function handle to monitor the AUC metric, or the area under the receiver operating characteristic (ROC) curve, during training. Specify the MiniBatchSize name-value argument as 4. If your computing device memory is limited, decrease the mini-batch size to 1 to prevent out-of-memory errors.

Train on one or more GPUs, if they are available. Using a GPU requires a Parallel Computing Toolbox license and a CUDA® enabled NVIDIA® GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox). If you have a Parallel Computing Toolbox license, specify the PreprocessingEnvironment argument as "parallel" to speed up training.

doTraining = false;
if doTraining
    maximumEpochs = 20;
    options = trainingOptions("adam", ...
        InitialLearnRate=1e-4, ...
        L2Regularization=1e-5, ...
        LearnRateSchedule="piecewise", ...
        LearnRateDropPeriod=floor(0.95*maximumEpochs), ...
        LearnRateDropFactor=0.1, ...
        MaxEpochs=maximumEpochs, ...
        VerboseFrequency=2, ...
        MiniBatchSize=4, ...
        Shuffle="every-epoch", ...
        ValidationData=dsCal, ...
        ValidationPatience=Inf, ...
        OutputNetwork="best-validation", ...
        Metrics=aucMetric(Name="auc"), ...
        ObjectiveMetricName="auc", ...
        ResetInputNormalization=true,...
        Plots="training-progress", ...
        PreprocessingEnvironment="background");
    detector = trainEfficientADAnomalyDetector(dsTrainTiles,untrainedDetector,options,MapNormalizationDataRatio=0.2);
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(string(tempdir) + filesep + "trainedGumDefectDetectorEfficientAD_" + modelDateTime + ".mat", ...
        "detector")
end
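The learning rate schedule implied by these options can be worked out by hand. The sketch below assumes the usual piecewise rule, in which the learning rate is multiplied by the drop factor once every LearnRateDropPeriod epochs. The variable names are illustrative and not part of the trainingOptions API.

```matlab
maximumEpochs = 20;
initialLearnRate = 1e-4;
dropPeriod = floor(0.95*maximumEpochs);   % 19 epochs
dropFactor = 0.1;

epochs = 1:maximumEpochs;
% Assumed piecewise rule: one drop after every dropPeriod completed epochs
learnRate = initialLearnRate*dropFactor.^floor((epochs-1)/dropPeriod);

fprintf("Epochs 1-%d: %g, epoch %d: %g\n", ...
    dropPeriod,learnRate(1),maximumEpochs,learnRate(end))
```

Under this assumption, the detector trains at the initial rate for 19 epochs and uses the reduced rate only for the final epoch.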

Set Anomaly Threshold

During semi-supervised anomaly detection, you must choose an anomaly score threshold that separates normal images from anomalous images. The anomaly detector classifies each image based on whether its score is above or below the threshold value. To select the threshold, this example uses the calibration data set defined in the Partition Data into Calibration and Test Sets section, which contains both normal and anomalous images.

Obtain the maximum anomaly score and ground truth label for each image in the calibration set by using the predict object function. Specify the MiniBatchSize argument value as 4, or another small number. Note that the value of this argument must be the same as the mini-batch size used for training in the Specify Training Options and Train Detector section. Because the predict function evaluates the calibration images at full-size resolution, you must specify a small mini-batch size to prevent out-of-memory errors.

scoresCal = predict(detector,dsCal,MiniBatchSize=4);
labelsCal = dsCal.UnderlyingDatastores{1}.Labels~="good";

Plot a histogram of the maximum anomaly scores for the normal and anomalous classes. The distributions are well-separated by the model-predicted anomaly score.

numBins = 20;
[~,edges] = histcounts(scoresCal,numBins);
figure
hold on
hNormal = histogram(scoresCal(labelsCal==0),edges);
hAnomaly = histogram(scoresCal(labelsCal==1),edges);
hold off
legend([hNormal,hAnomaly],"Normal","Anomaly")
xlabel("Max Anomaly Score")
ylabel("Counts")

Calculate the optimal anomaly threshold by using the anomalyThreshold function. Specify the first two input arguments as the ground truth labels, labelsCal, and the predicted anomaly scores, scoresCal, for the calibration data set. Specify the third input argument as true because true positive anomaly images have a labelsCal value of true. The anomalyThreshold function returns the optimal threshold value as a scalar and the receiver operating characteristic (ROC) curve for the detector as an rocmetrics (Deep Learning Toolbox) object.

[thresh,roc] = anomalyThreshold(labelsCal,scoresCal,true);
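To make the idea of threshold selection concrete, the following sketch scans candidate thresholds and keeps the one that maximizes the difference between the true positive rate and the false positive rate (Youden's J statistic). This criterion is chosen for illustration only and is not necessarily the one that anomalyThreshold applies. The scores and labels are synthetic.

```matlab
% Synthetic calibration scores and labels (true = anomalous), for illustration only
scores = [0.10 0.20 0.25 0.40 0.55 0.60 0.80 0.90]';
labels = logical([0 0 0 1 0 1 1 1]');

candidates = unique(scores);
J = zeros(size(candidates));
for k = 1:numel(candidates)
    predicted = scores >= candidates(k);           % anomalous if at or above the threshold
    tpr = nnz(predicted & labels)/nnz(labels);     % true positive rate
    fpr = nnz(predicted & ~labels)/nnz(~labels);   % false positive rate
    J(k) = tpr - fpr;
end

% max returns the first maximum, so ties resolve to the lowest threshold
[bestJ,best] = max(J);
thresh = candidates(best);
```

In this toy data, two thresholds tie on J. Picking the lower one favors detecting every anomaly at the cost of one false positive, which shows that the choice of criterion is a trade-off between false negatives and false positives.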

Set the Threshold property of the anomaly detector to the optimal threshold value.

detector.Threshold = thresh;

Plot the ROC curve by using the plot (Deep Learning Toolbox) object function of rocmetrics. The ROC curve illustrates the performance of the classifier for a range of possible threshold values. Each point on the ROC curve represents the false positive rate (x-coordinate) and true positive rate (y-coordinate) when you classify the calibration set images using a different threshold value. The solid blue line represents the ROC curve. The area under the ROC curve (AUC) metric indicates classifier performance, with a perfect classifier corresponding to the maximum ROC AUC of 1.0.

plot(roc)
title("ROC AUC: " + roc.AUC)

Evaluate Classification Model

Classify each image in the test set as either normal or anomalous by using the classify object function. Specify the MiniBatchSize argument as 1. Because the classify function evaluates the test images at full-size resolution, you must specify a small mini-batch size to prevent out-of-memory errors.

testSetPredictedLabels = classify(detector,dsTest,MiniBatchSize=1);
testSetPredictedLabels = testSetPredictedLabels';

Get the ground truth labels of each test image.

testSetGTLabels = dsTest.UnderlyingDatastores{1}.Labels ~= "good";

Evaluate the anomaly detector by calculating performance metrics using the evaluateAnomalyDetection function. The function calculates several metrics that evaluate the accuracy, precision, sensitivity, and specificity of the detector for the test data set.

metrics = evaluateAnomalyDetection(testSetPredictedLabels,testSetGTLabels,1);
Evaluating anomaly detection results
------------------------------------
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    Precision    Recall    Specificity    F1Score    FalsePositiveRate    FalseNegativeRate
    ______________    ____________    _________    ______    ___________    _______    _________________    _________________

       0.90667            0.87         0.89091      0.98        0.76        0.93333          0.24                 0.02       

The ConfusionMatrix property of metrics contains the confusion matrix for the test set. Extract the confusion matrix and display a confusion plot. The classification model in this example misses few anomalies (false negative rate of 0.02) but flags about a quarter of the normal test images as anomalous (false positive rate of 0.24), for a global accuracy of approximately 0.91.

M = metrics.ConfusionMatrix{:,:};
confusionchart(M,["Normal","Anomaly"])
acc = sum(diag(M))/sum(M,"all");
title("Accuracy: " + acc)
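The reported metrics can be reproduced by hand from confusion-matrix counts. The counts below are inferred from the reported rates and from a test set of 50 anomalous and 25 normal images. They are not read from the metrics object.

```matlab
% Inferred counts: anomalous images are the positive class
TP = 49; FN = 1;    % 50 anomalous test images
TN = 19; FP = 6;    % 25 normal test images

recall         = TP/(TP + FN);                    % sensitivity
specificity    = TN/(TN + FP);
precision      = TP/(TP + FP);
globalAccuracy = (TP + TN)/(TP + TN + FP + FN);
meanAccuracy   = (recall + specificity)/2;        % average of per-class accuracies
f1Score        = 2*precision*recall/(precision + recall);

fprintf("Accuracy %.5f, precision %.5f, F1 %.5f\n", ...
    globalAccuracy,precision,f1Score)
```

Because the test set contains twice as many anomalous as normal images, the global accuracy is higher than the mean accuracy, which weights both classes equally.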

Explain Classification Decisions

You can use the anomaly score heatmap that the anomaly detector predicts to explain why the detector classifies an image as normal or anomalous. This approach is useful for identifying patterns in false negatives and false positives. You can use these patterns to identify strategies for improving the class balance of the training data or improving the network performance.

Normalize Anomaly Heat Map Display Range

To compare anomaly scores observed across the entire calibration set, including normal and anomalous images, normalize the anomaly score map display range. By using the same display range across images, you can compare images more easily than if you scale each image to its own minimum and maximum.

Define a normalized display range, displayRange, that reflects the range of anomaly scores observed across the entire calibration set, including normal and anomalous images.

minMapVal = inf;
maxMapVal = -inf;
reset(dsCal)
while hasdata(dsCal)
    data = read(dsCal);
    img = data{1};
    map = anomalyMap(detector,img);
    minMapVal = min(min(map,[],"all"),minMapVal);
    maxMapVal = max(max(map,[],"all"),maxMapVal);
end
displayRange = [minMapVal 0.7*maxMapVal];
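A shared display range amounts to rescaling every anomaly map with the same two limits and clipping values outside them. The following sketch shows that mapping on a synthetic map. It illustrates the idea and is not the internal behavior of anomalyMapOverlay.

```matlab
map = [0.1 0.5; 2.0 4.0];        % synthetic anomaly map
displayRange = [0.1 0.7*4.0];    % [minimum, 0.7*maximum], as in the code above

% Rescale to [0,1] with the shared limits, then clip values outside the range
scaled = (map - displayRange(1))/(displayRange(2) - displayRange(1));
scaled = min(max(scaled,0),1);
```

Setting the upper limit below the observed maximum saturates the strongest responses, which leaves more of the color range for weaker anomalies.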

View Heatmap of Anomalous Image

Select an image of a correctly classified anomaly. Display the image, with the heatmap overlaid by using the anomalyMapOverlay function.

testSetAnomalyLabels = testSetGTLabels;
testSetOutputLabels = testSetPredictedLabels;
idxTruePositive = find(testSetAnomalyLabels & testSetOutputLabels);
dsExample = subset(dsTest,idxTruePositive);
data = read(dsExample);
img = data{1};
map = anomalyMap(detector,img);
imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))

View Heatmap of Normal Image

Select and display an image of a correctly classified normal image, with the heatmap overlaid by using the anomalyMapOverlay function.

idxTrueNegative = find(~(testSetAnomalyLabels | testSetOutputLabels));
dsExample = subset(dsTest,idxTrueNegative);
data = read(dsExample);
img = data{1};
map = anomalyMap(detector,img);
imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))

View Heatmap of False Positive Image

False positives are images without defects that the network misclassifies as anomalous. Use the anomaly heatmap from the EfficientAD model [1] to gain insight into the misclassifications.

Select a false positive image and display it with the heatmap overlaid by using the anomalyMapOverlay function. For this test image, high anomaly scores are localized to image areas with uneven lighting, so you might consider adjusting the image contrast during preprocessing, increasing the number of images used for training, or choosing a different threshold at the calibration step.

idxFalsePositive = find(~(testSetAnomalyLabels) & testSetOutputLabels);
if ~isempty(idxFalsePositive)
    dsExample = subset(dsTest,idxFalsePositive);
    data = read(dsExample);
    img = data{1};
    map = anomalyMap(detector,img);
    figure
    imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))
end

View Heatmap of False Negative Image

False negatives are images with defects that the network misclassifies as normal. Use the anomaly heatmap from the EfficientAD model [1] to gain insight into the misclassifications.

Select a false negative image, and display it, with the heatmap overlaid by using the anomalyMapOverlay function. To decrease false negative results, consider adjusting the anomaly threshold or CompressionRatio of the detector.

idxFalseNegative = find(testSetAnomalyLabels & (~testSetOutputLabels));
if ~isempty(idxFalseNegative)
    dsExample = subset(dsTest,idxFalseNegative);
    data = read(dsExample);
    img = data{1};
    map = anomalyMap(detector,img);
    figure
    imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))
end
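The four index sets used in this section partition the test set. A small sketch with synthetic labels shows the logical indexing side by side (the label vectors are made up for illustration):

```matlab
gtLabels   = logical([1 1 0 0 1 0]');   % ground truth, true = anomalous
predLabels = logical([1 0 0 1 1 0]');   % predictions, true = anomalous

idxTruePositive  = find(gtLabels & predLabels);      % anomalous, predicted anomalous
idxTrueNegative  = find(~(gtLabels | predLabels));   % normal, predicted normal
idxFalsePositive = find(~gtLabels & predLabels);     % normal, predicted anomalous
idxFalseNegative = find(gtLabels & ~predLabels);     % anomalous, predicted normal

% Every image falls into exactly one of the four sets
numAssigned = numel(idxTruePositive) + numel(idxTrueNegative) + ...
    numel(idxFalsePositive) + numel(idxFalseNegative);
```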

Supporting Functions

addLabelData

The addLabelData helper function creates a one-hot encoded representation of label information in data.

function [data,info] = addLabelData(data,info)
    classNames = [info.Label];
    onehotencoding = classNames ~= categorical("good");
    data = {data,onehotencoding};
end

maskObjectsInNormalImages

The maskObjectsInNormalImages helper function creates a logical mask from a blocked image, in which true (or 1) pixels represent ROIs with meaningful normal data.

function bout = maskObjectsInNormalImages(bim,T)
    im = im2gray(bim.Data);

    % Threshold image using global threshold
    BW = imbinarize(im,T);

    % Morphologically open mask with square-shaped
    % structuring element
    width = 3;
    se = strel("square",width);
    BW = imopen(BW,se);

    bout = BW;
end

References

[1] Batzner, Kilian, Lars Heckler, and Rebecca König. “EfficientAD: Accurate Visual Anomaly Detection at Millisecond-Level Latencies.” In 2024 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 127–37. Waikoloa, HI, USA: IEEE, 2024. https://doi.org/10.1109/WACV57701.2024.00020.

[2] Zou, Yang, Jongheon Jeong, Latha Pemula, Dongqing Zhang, and Onkar Dabeer. “SPot-the-Difference Self-Supervised Pre-Training for Anomaly Detection and Segmentation.” arXiv, July 28, 2022. https://doi.org/10.48550/arXiv.2207.14315.
