Anomaly Detection Using Convolutional Autoencoder with Wavelet Scattering Sequences

This example shows how to use wavelet scattering sequences with the deepSignalAnomalyDetector to detect anomalies in acoustic data. The data are acoustic recordings of a normally functioning air compressor and one with bearing faults. The example compares the results obtained with scattering sequences against those obtained with raw data. In doing so, it shows that the principles of data-centric AI are powerful considerations when applying deep learning to signal data.

Dataset Description and Download

The dataset consists of acoustic recordings collected on a single-stage reciprocating-type air compressor [1]. The data are sampled at 16 kHz. Specifications of the air compressor are as follows:

  • Air Pressure Range: 0-500 lb/in2, 0-35 kg/cm2

  • Induction Motor: 5 HP, 415 V, 5 A, 50 Hz, 1440 rpm

  • Pressure Switch: Type PR-15, Range 100-213 PSI

Each recording represents one of 8 states: the healthy state and 7 faulty states. The 7 faulty states are:

  • Leakage inlet valve (LIV) fault

  • Leakage outlet valve (LOV) fault

  • Non-return valve (NRV) fault

  • Piston ring fault

  • Flywheel fault

  • Rider belt fault

  • Bearing fault

In this example the focus is only on normal (healthy) air compressors and those with bearing faults.

Download the dataset and unzip the data file in a folder where you have write permission. This example assumes you are downloading the data to the temporary directory designated as tempdir in MATLAB®. If you choose to use a different folder, substitute that folder for tempdir in the following code. The recordings are stored as .wav files in folders named for their respective condition.

url = 'https://www.mathworks.com/supportfiles/audio/AirCompressorDataset/AirCompressorDataset.zip';
downloadFolder = fullfile(tempdir,'AirCompressorDataSet');
if ~exist(downloadFolder,'dir')
    loc = websave(downloadFolder,url);
    unzip(loc,downloadFolder)
end

Because the data are stored in .wav files, use an audioDatastore to manage data access. Each subfolder contains only recordings of the designated class. Use the folder names as the class labels.

rng("default")
ads = audioDatastore(downloadFolder,IncludeSubfolders=true,...
    LabelSource="foldernames"); 

This example focuses only on differentiating recordings obtained in a normally functioning air compressor from those with bearing faults. Create a training set consisting of 50% of the recordings, a validation set consisting of 20%, and a hold-out test set consisting of 30%.

[adsTrain,adsValidation,adsTest] = splitEachLabel(ads,0.5,0.2,0.3,...
    Include=["Healthy","Bearing"]);

Show the number of examples in each category for the three sets. Each class (normal and faulty) has 113 examples in the training set, 45 in the validation set, and 67 in the test set. Training and validation sets are needed only for the normal data. Later in the example, the recordings with bearing faults from the training, validation, and test sets are combined into one test set, because the anomaly detector does not see those recordings during training and validation.

C = categories(adsTrain.Labels);
categoriesToRemove = C(~ismember(C,unique(adsTrain.Labels)));
uniqueLabels = unique(removecats(adsTrain.Labels,categoriesToRemove));
tblTrain = countEachLabel(adsTrain);
tblValidation = countEachLabel(adsValidation);
tblTest = countEachLabel(adsTest);
H = bar(uniqueLabels,[tblTrain.Count, tblValidation.Count, tblTest.Count],'stacked');
ax = gca;
ax.YTick = [113 113+45 113+45+67];
ax.YLabel.String = "Cumulative Number of Examples";
legend(H,["Training","Validation","Test"],Location="NorthEastOutside",FontSize=12)

Figure: stacked bar chart of the number of examples in each class (y-axis: Cumulative Number of Examples; legend: Training, Validation, Test).
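The counts quoted above imply 225 recordings per class. As a quick check in base MATLAB, the following sketch reproduces the per-class split sizes from the requested proportions. It assumes that the training and validation sizes are rounded to the nearest integer and that the test set receives the remainder. That rule is an inference from the stated counts, not a description of how splitEachLabel is implemented.

```matlab
% Per-class counts stated in the text
nTrain = 113; nValidation = 45; nTest = 67;
nPerClass = nTrain + nValidation + nTest;   % 225 recordings per class

% Assumed rule (for illustration only): round the first two proportions
% and give the remainder to the test set.
nTrainCheck = round(0.5*nPerClass);
nValidationCheck = round(0.2*nPerClass);
nTestCheck = nPerClass - nTrainCheck - nValidationCheck;

splitMatches = isequal([nTrainCheck nValidationCheck nTestCheck], ...
    [nTrain nValidation nTest])
```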

Plot the first signal of each class in the training set, along with its power spectrum, for illustration purposes.

tiledlayout(2,2)
for n = 1:numel(uniqueLabels)
    idx = find(adsTrain.Labels==uniqueLabels(n),1);
    [x,fs] = audioread(adsTrain.Files{idx});    
    t = (0:size(x,1)-1)/fs;
    nexttile
    plot(t,x);
    grid on
    xlabel("Time (seconds)")
    ylabel("Amplitude")
    title(string(uniqueLabels(n)));
    nexttile
    pspectrum(x,fs)
    title(string(uniqueLabels(n)));
end

Figure: 2-by-2 tiled plot. Top row: Bearing time series (x-axis: Time (seconds); y-axis: Amplitude) and Bearing power spectrum (x-axis: Frequency (kHz); y-axis: Power Spectrum (dB)). Bottom row: the same two plots for Healthy.

Wavelet Scattering Network

Construct the wavelet scattering network for a signal length of 50,000 samples. Set the invariance scale to 0.3 seconds, which corresponds to 4800 samples at the 16 kHz sample rate. Set the base-2 logarithm of the oversampling factor to 1, which reduces the amount of downsampling by a factor of 2. These settings result in 98 coefficients per scattering path and 324 paths in total.

N = 5e4;
sn = waveletScattering(SignalLength=N,SamplingFrequency=fs,...
    InvarianceScale=0.3,OversamplingFactor=1); 
numCoefficients(sn)
ans = 98
[~,npaths] = paths(sn);
sum(npaths)
ans = 324
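The numbers quoted for the network follow from simple arithmetic, sketched below in base MATLAB. The sample count for the invariance scale comes directly from the text. The downsampling factor of 2^9 is an assumption, chosen because it is consistent with the reported 98 coefficients per path for a 50,000-sample signal. It is not read from the waveletScattering object.

```matlab
fs = 16e3;                 % sample rate in Hz
N = 5e4;                   % signal length in samples
invarianceScale = 0.3;     % invariance scale in seconds

% Invariance scale expressed in samples
samplesPerScale = invarianceScale*fs

% Assumed downsampling factor (hypothetical, inferred from the output above)
dsFactor = 2^9;
coeffsPerPath = ceil(N/dsFactor)     % number of time steps per scattering path
timeStep = dsFactor/fs               % seconds between consecutive coefficients
```

Under this assumption, each scattering sequence has 98 time steps spaced 0.032 seconds apart, so the detector sees a much shorter sequence than the 50,000-sample raw signal.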

Set up the batch processing for obtaining the wavelet scattering coefficients for the training, validation, and test sets. The scattering coefficients are computed by the helper function helperbatchscatfeatures, included at the end of this example. Each signal is standardized to have zero mean and unit standard deviation before obtaining the scattering coefficients. Additionally, the data are cast to single precision. In this example, an NVIDIA® Titan V GPU is used to speed up the computation of the scattering coefficients. If you do not have access to a GPU with sufficient compute capability, set useGPU to false.

useGPU = true; 
batchsize = 64;
scTrain = [];
reset(adsTrain)
while hasdata(adsTrain)
    sc = helperbatchscatfeatures(adsTrain,sn,N,batchsize,useGPU);
    scTrain = cat(3,scTrain,sc);
end

Repeat the same process for the validation and test sets.

scValidation = [];
reset(adsValidation)
while hasdata(adsValidation)
   sc = helperbatchscatfeatures(adsValidation,sn,N,batchsize,useGPU);
   scValidation = cat(3,scValidation,sc); 
end
scTest = [];
reset(adsTest)
while hasdata(adsTest)
   sc = helperbatchscatfeatures(adsTest,sn,N,batchsize,useGPU);
   scTest = cat(3,scTest,sc); 
end
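The standardization step can be illustrated in isolation. The following minimal sketch uses synthetic data in place of the recordings (the matrix X is hypothetical) and applies the same column-wise operation that the helper function uses: subtract each column's mean, divide by its standard deviation, and work in single precision.

```matlab
rng(0)
% Hypothetical batch: 1000 samples by 4 signals with different offsets and scales
X = single(randn(1000,4).*[1 5 0.1 20] + [0 3 -2 100]);

% Column-wise standardization, as in helperbatchscatfeatures
Xs = (X - mean(X,1))./std(X,[],1);

colMeans = mean(Xs,1)      % close to 0 for every column
colStds = std(Xs,[],1)     % close to 1 for every column
outClass = class(Xs)       % the result stays in single precision
```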

For each set, remove the 0th-order scattering coefficients and put the remaining coefficients in cell arrays for use in the deepSignalAnomalyDetector. This reduces the number of scattering paths (channels) from 324 to 323.

trainFeatures = scTrain(2:end,:,:);
trainFeatures = squeeze(num2cell(trainFeatures,[1 2])); 
trainLabels = adsTrain.Labels;
validationFeatures = scValidation(2:end,:,:);
validationFeatures = squeeze(num2cell(validationFeatures,[1 2]));
validationLabels = adsValidation.Labels;
testFeatures = scTest(2:end,:,:);
testFeatures = squeeze(num2cell(testFeatures,[1 2]));
testLabels = adsTest.Labels;

To prepare for training and testing, split the training, validation, and test sets into the two classes. Only the normal ("Healthy") data are used to train and validate the detector.

trainNormal = trainFeatures(trainLabels=="Healthy");
trainFaulty = trainFeatures(trainLabels=="Bearing");
validationNormal = validationFeatures(validationLabels=="Healthy");
validationFaulty = validationFeatures(validationLabels=="Bearing");
testNormal = testFeatures(testLabels=="Healthy");
testFaulty = testFeatures(testLabels=="Bearing");
testLabels = removecats(testLabels,categoriesToRemove);

Transpose the scattering coefficients so that each sequence is time-by-channel for input into the deepSignalAnomalyDetector.

trainNormal = cellfun(@transpose,trainNormal,UniformOutput=false);
validationNormal = cellfun(@transpose,validationNormal,UniformOutput=false);
testNormal = cellfun(@transpose,testNormal,UniformOutput=false);
trainFaulty = cellfun(@transpose,trainFaulty,UniformOutput=false);
validationFaulty = cellfun(@transpose,validationFaulty,UniformOutput=false);
testFaulty = cellfun(@transpose,testFaulty,UniformOutput=false);
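The reshaping steps above can be checked on a small stand-in array. The following sketch uses a random array with the same path and time dimensions as the scattering output but only three signals (the array S is hypothetical), and confirms that each cell ends up holding one time-by-channel matrix.

```matlab
% Hypothetical scattering output: 324 paths by 98 time steps by 3 signals
S = rand(324,98,3,"single");

F = S(2:end,:,:);                       % drop the 0th-order path, leaving 323
C = squeeze(num2cell(F,[1 2]));         % one 323-by-98 matrix per cell
C = cellfun(@transpose,C,UniformOutput=false);   % each cell is now 98-by-323

cellArraySize = size(C)
sequenceSize = size(C{1})
sameData = isequal(C{2},F(:,:,2).')     % each cell holds one signal, transposed
```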

Combine the faulty training, validation, and test data into one dataset consisting of 225 examples labeled testFaulty.

testFaulty = cat(1,trainFaulty,validationFaulty,testFaulty);

Perform similar preparations for the raw data.

reset(adsTrain)
reset(adsValidation)
reset(adsTest)
trainSequences = readall(adsTrain);
validationSequences = readall(adsValidation);
testSequences = readall(adsTest);

Similar to what was done with the scattering sequences, standardize the time series prior to training and testing. Subset the standardized sequences into the normal and faulty classes.

trainSequences = cellfun(@(x)(x-mean(x))./std(x),trainSequences,UniformOutput=false);
validationSequences = cellfun(@(x)(x-mean(x))./std(x),validationSequences,UniformOutput=false);
testSequences = cellfun(@(x)(x-mean(x))./std(x),testSequences,UniformOutput=false);
normalTrainSequences = trainSequences(adsTrain.Labels == "Healthy");
normalValidationSequences = validationSequences(adsValidation.Labels=="Healthy");
normalTestSequences = testSequences(adsTest.Labels=="Healthy");
faultyTrainSequences = trainSequences(adsTrain.Labels == "Bearing");
faultyValidationSequences = validationSequences(adsValidation.Labels == "Bearing");
faultyTestSequences = testSequences(adsTest.Labels=="Bearing");

Combine the training, validation, and test sets for the faulty data into one set consisting of 225 recordings for model testing.

faultySequences = cat(1,faultyTrainSequences,faultyValidationSequences,faultyTestSequences);

Training

In this section, the deepSignalAnomalyDetector is trained using both the scattering sequences and the raw data. The training options use a suitable GPU if available. For the Titan V GPU used in this example, training requires about four minutes for each model. If you wish to skip the training, set trainModels to false and proceed to the Testing section of the example where you can load the pre-trained anomaly detectors.

Use the default deepSignalAnomalyDetector architecture, which is a convolutional autoencoder. Set the window length to use the entire time window in both cases. First, set up the deep signal anomaly detector for the scattering sequences. Here the number of channels is equal to one less than the number of scattering paths because the 0th order scattering coefficients have been removed.

trainModels = true;
if trainModels
    numChannels = size(trainNormal{1},2);
    dsadSCAT = deepSignalAnomalyDetector(numChannels,WindowLength="fullSignal");
end

Next, set up the anomaly detector for the raw data. Here the number of channels is equal to 1.

if trainModels
    numChannels = 1;
    dsadRAW = deepSignalAnomalyDetector(numChannels,WindowLength="fullSignal");
end

Set up the training options. Use an Adam optimizer, a minibatch size of 16, and shuffle the data each epoch. Run the training for 300 epochs. Because the validation data is different for the scattering sequences and the raw data, you need to define two separate trainingOptions objects. Those objects only differ in the included validation data sets.

if trainModels
    optsSCAT = trainingOptions("adam", MaxEpochs=300,MiniBatchSize=16, ...
        ValidationData={validationNormal,validationNormal},...
        Shuffle="every-epoch",...
        OutputNetwork = "best-validation-loss",...
        Verbose=false,Plots="training-progress"); 
    optsRAW = trainingOptions("adam",MaxEpochs=300,MiniBatchSize=16, ...
        ValidationData={normalValidationSequences,normalValidationSequences},...
        Shuffle="every-epoch",...
        OutputNetwork = "best-validation-loss",...
        Verbose=false,Plots="training-progress");
end

Train the detector using the scattering sequences.

if trainModels
    trainDetector(dsadSCAT,trainNormal,optsSCAT)
end

Repeat the training for raw data if trainModels is true.

if trainModels
    trainDetector(dsadRAW,normalTrainSequences,optsRAW)
end

Testing

If you elected to skip training by setting trainModels to false, you can load the pre-trained models with the following code.

if ~trainModels
    load dsadSCAT.mat %#ok<*UNRCH>
    load dsadRAW.mat
end

To assess the trained models, start by plotting the reconstruction loss distributions for the normal and faulty recordings. The following does this for the model trained with wavelet scattering sequences.

figure
figh = plotLossDistribution(dsadSCAT,testNormal,testFaulty);
figh.Children(1).String = ["Healthy","Bearing","Normal CDF","Faulty CDF"];
ax = gca;
ax.Title.String = "Reconstruction Loss -- Scattering Sequences";

Figure: Reconstruction Loss -- Scattering Sequences. Histograms and CDFs of the reconstruction loss for the Healthy and Bearing recordings, with a constant line marking the threshold (x-axis: Reconstruction Loss; y-axis: CDF).

The probability histograms for the reconstruction losses for the normal and faulty classes demonstrate a very good separation. Determine the accuracy of the model on the held-out test set and plot the confusion chart.

First, create a categorical vector consisting of all the test set labels.

fullTestLabels = cat(1,trainLabels(trainLabels=="Bearing"),...
    validationLabels(validationLabels=="Bearing"),...
    testLabels(testLabels=="Bearing"),...
    testLabels(testLabels=="Healthy"));
fullTestLabels = removecats(fullTestLabels,categoriesToRemove);
predNormalSCAT = detect(dsadSCAT,testNormal);
predBearingSCAT = detect(dsadSCAT,testFaulty);
predSCAT = cat(1,cell2mat(predBearingSCAT),cell2mat(predNormalSCAT));
predSCAT = categorical(predSCAT,[1 0],["Bearing" "Healthy"]);
cm = confusionchart(fullTestLabels,predSCAT);
cm.RowSummary = "row-normalized";
cm.ColumnSummary = "column-normalized";
cm.Title = "Wavelet Scattering Sequences with deepSignalAnomalyDetector";

Figure: confusion chart titled "Wavelet Scattering Sequences with deepSignalAnomalyDetector".

In terms of precision (positive predictive value) and recall (sensitivity), the deepSignalAnomalyDetector paired with the wavelet scattering sequences performs quite well. Use getModel with the deepSignalAnomalyDetector to inspect the convolutional autoencoder that the detector trained and uses for detection.

SCATMdl = getModel(dsadSCAT);
SCATMdl.Layers
ans = 
  16×1 Layer array with layers:

     1   'sequenceinput'         Sequence Input               Sequence input with 323 dimensions
     2   'padding'               Function                     @(X)cat(3,X,zeros(size(X,1),size(X,2),ceil(size(X,3)/minLength)*minLength-size(X,3)))
     3   'conv1d_1'              1-D Convolution              32 8×323 convolutions with stride 2 and padding 'same'
     4   'relu_1'                ReLU                         ReLU
     5   'dropout_1'             Dropout                      20% dropout
     6   'conv1d_2'              1-D Convolution              32 8×32 convolutions with stride 2 and padding 'same'
     7   'relu_2'                ReLU                         ReLU
     8   'dropout_2'             Dropout                      20% dropout
     9   'transposed-conv1d_1'   1-D Transposed Convolution   1-D transposed convolution with 32 filters of size 8, stride 2, and cropping 'same'
    10   'relu_3'                ReLU                         ReLU
    11   'dropout_3'             Dropout                      20% dropout
    12   'transposed-conv1d_2'   1-D Transposed Convolution   1-D transposed convolution with 32 filters of size 8, stride 2, and cropping 'same'
    13   'relu_4'                ReLU                         ReLU
    14   'dropout_4'             Dropout                      20% dropout
    15   'fc'                    Fully Connected              323 fully connected layer
    16   'truncating'            Function                     @(X,Y)X(1:size(Y,1),1:size(Y,2),1:size(Y,3))
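The 'padding' and 'truncating' function layers exist so that the sequence length survives the two stride-2 stages and their transposed counterparts. The following sketch traces the time dimension through the layers for a 98-step scattering sequence. The value of minLength is not shown in the layer listing, so the value 4 used here (the product of the two strides) is an assumption for illustration.

```matlab
T = 98;                    % time steps per scattering sequence
minLength = 4;             % assumed value: product of the two stride-2 stages

Tpad = ceil(T/minLength)*minLength;   % 'padding' layer pads up to a multiple of minLength
Tenc = Tpad/2/2;                      % after conv1d_1 and conv1d_2 (stride 2 each)
Tdec = Tenc*2*2;                      % after the two transposed convolutions
Tout = min(Tdec,T);                   % 'truncating' layer crops back to the input length

lengths = [Tpad Tenc Tdec Tout]
```

Under this assumption, the encoder compresses each sequence to 25 time steps, and the truncating layer guarantees that the reconstruction has the same size as the input so that the reconstruction loss can be computed element by element.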

Plot the reconstruction loss distributions for the anomaly detector trained on the raw time series.

figure
figh = plotLossDistribution(dsadRAW,normalTestSequences,faultySequences);
figh.Children(1).String = ["Normal","Faulty","Normal CDF","Faulty CDF"];
ax = gca;
ax.Title.String = "Reconstruction Loss Distribution -- Raw Data";

Figure: Reconstruction Loss Distribution -- Raw Data. Histograms and CDFs of the reconstruction loss for the Normal and Faulty recordings, with a constant line marking the threshold (x-axis: Reconstruction Loss; y-axis: CDF).

The reconstruction loss distributions for the two classes demonstrate a puzzling reversal in the case of the raw data. Specifically, the reconstruction loss distribution for the normal data is shifted to the right of the faulty-data distribution. The overall range of the loss is small, indicating that there is little reconstruction loss for either class. Given the threshold indicated on the plot, you can infer that this model will perform very poorly. In fact, it is likely to classify the vast majority of signals as normal.

You can confirm this by running the following code, but in this case you can already discern this information from the plot of the reconstruction loss distributions.

predHealthyRAW = detect(dsadRAW,normalTestSequences);
predBearingRAW = detect(dsadRAW,faultySequences);
predRAW = cat(1,cell2mat(predBearingRAW),cell2mat(predHealthyRAW));
predRAW = categorical(predRAW,[1 0],["Bearing" "Healthy"]);
cm = confusionchart(fullTestLabels,predRAW);
cm.RowSummary = "row-normalized";
cm.ColumnSummary = "column-normalized";
cm.Title = "Raw Sequences with deepSignalAnomalyDetector";

Figure: confusion chart titled "Raw Sequences with deepSignalAnomalyDetector".
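The row-normalized and column-normalized summaries of a confusion chart correspond to recall and precision, respectively. As a minimal sketch, the following computes these quantities in base MATLAB from the counts that the raw-data chart reports: none of the 225 faulty signals detected, and one of the 67 normal test signals flagged. The variable names are illustrative.

```matlab
% Counts for the raw-data model, with "Bearing" as the positive class
TP = 0;      % faulty signals flagged as anomalous
FN = 225;    % faulty signals missed
FP = 1;      % normal signals flagged as anomalous
TN = 66;     % normal signals passed (67 normal test signals in total)

recall = TP/(TP+FN)                     % fraction of faulty signals detected
precision = TP/(TP+FP)                  % fraction of flagged signals that are faulty
accuracy = (TP+TN)/(TP+FN+FP+TN)        % fraction of all signals labeled correctly
```

Both recall and precision are zero, and the accuracy of about 0.23 reflects only the normal signals that were passed through.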

None of the 225 faulty signals has been correctly identified as containing anomalies, while one normal signal has been falsely classified as faulty. One useful diagnostic for assessing what has happened with the deep signal anomaly detector on raw data is plotAnomalies, which plots the anomalies along with the reconstruction for sample waveforms from the normal and faulty test sets.

idx = randi(numel(normalTestSequences));
plotAnomalies(dsadRAW,normalTestSequences{idx},PlotReconstruction=true)
title("Normal Sequence With Reconstruction Loss")

Figure: Normal Sequence With Reconstruction Loss. Raw signal, anomalies, and reconstructed signal for channel 1 (x-axis: Samples).

Zoom in on the plot to see the reconstruction in more detail.

xlim([2.5e4 3e4])

Figure: Normal Sequence With Reconstruction Loss, zoomed in to samples 25,000 through 30,000. Raw signal, anomalies, and reconstructed signal for channel 1 (x-axis: Samples).

Next plot the reconstruction for a randomly selected sequence with a bearing fault. Zoom in to see the reconstruction in more detail.

figure
idx = randi(numel(faultySequences));
plotAnomalies(dsadRAW,faultySequences{idx},PlotReconstruction=true)
xlim([2.5e4 3e4])
title("Faulty Sequence With Reconstruction Loss")

Figure: Faulty Sequence With Reconstruction Loss, zoomed in to samples 25,000 through 30,000. Raw signal, anomalies, and reconstructed signal for channel 1 (x-axis: Samples).

You see the data has been reconstructed almost perfectly (without loss) in both cases. Ideally, the reconstruction loss would be small for the normal data and much larger for the faulty data. It is clear that the anomaly detector has not learned to differentiate normal from faulty signals in this configuration.
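The ideal behavior described above can be illustrated with a toy example in base MATLAB. This is a conceptual sketch only: the "reconstructor" below is a hypothetical stand-in that projects a signal onto a single pattern learned from normal data, and the signals, the fault component, and the threshold are all invented for illustration. It does not reproduce the internals of deepSignalAnomalyDetector.

```matlab
t = (0:999)'/1000;                               % 1 second sampled at 1 kHz
normalSig = sin(2*pi*5*t);                       % hypothetical normal pattern
faultySig = normalSig + 0.5*sin(2*pi*120*t);     % hypothetical fault component

% Stand-in for a model trained on normal data only: it can reproduce
% the normal pattern (up to a scale factor) and nothing else.
reconstruct = @(x) ((normalSig'*x)/(normalSig'*normalSig))*normalSig;

% Reconstruction loss: mean squared error between signal and reconstruction
lossNormal = mean((normalSig - reconstruct(normalSig)).^2);
lossFaulty = mean((faultySig - reconstruct(faultySig)).^2);

threshold = 0.05;                                % hypothetical decision threshold
isAnomaly = [lossNormal lossFaulty] > threshold
```

Because the stand-in model cannot represent the fault component, the faulty signal's loss (0.125) exceeds the threshold while the normal signal's loss is zero. The raw-data detector in this example fails for the opposite reason: it reconstructs both classes almost perfectly, so the loss carries no information about the fault.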

Summary

This example has demonstrated that the deepSignalAnomalyDetector can perform markedly better in certain applications if data-centric AI principles are used. In data-centric AI, the focus is more on the data input to the network than on the network architecture. Here, the use of wavelet scattering sequences greatly improved the ability of the anomaly detector to differentiate normal from faulty recordings. The key to the success of wavelet scattering in this application is that the scattering transform provides a robust time-frequency feature extractor.

References

[1] N. K. Verma, R. K. Sevakula, S. Dixit and A. Salour, "Intelligent Condition Based Monitoring Using Acoustic Signals for Air Compressors," in IEEE Transactions on Reliability, vol. 65, no. 1, pp. 291-309, March 2016, doi: 10.1109/TR.2015.2459684.

Appendix

Helper functions for the example.

function sc = helperbatchscatfeatures(ds,sn,N,batchsize,useGPU)
% This function is only intended to support examples in the Signal
% Processing and Wavelet Toolboxes. It may be changed or removed in a future
% release.
% Read batch of data from audio datastore
batch = helperReadBatch(ds,N,batchsize);
if useGPU
    batch = gpuArray(batch);
end
% Standardize each signal (column) to zero mean and unit standard deviation
batch = (batch-mean(batch,1))./std(batch,[],1);
% Obtain the scattering features and return them in host memory
S = sn.featureMatrix(batch);
sc = gather(S);
end

function batchout = helperReadBatch(ds,N,batchsize)
% This function is only intended to support examples in the Signal
% Processing and Wavelet Toolboxes. It may change or be removed in a future
% release.
kk = 1;
while(hasdata(ds)) && kk <= batchsize
    tmpRead = read(ds);
    batchout(:,kk) = cast(tmpRead(1:N),"single"); %#ok<AGROW>
    kk = kk+1;
end
end