Deep Learning Code Generation on ARM for Fault Detection Using Wavelet Scattering and Recurrent Neural Networks
This example demonstrates code generation for acoustics-based machine fault detection using a wavelet scattering network paired with a recurrent neural network. This example uses MATLAB® Coder™, MATLAB Coder Interface for Deep Learning, and MATLAB Support Package for Raspberry Pi® Hardware to generate a standalone executable (.elf) file on a Raspberry Pi that leverages the performance of the ARM® Compute Library. The input data consists of acoustic time-series recordings from air compressors, and the output is the state of the machine as predicted by the LSTM-based RNN. The standalone executable on the Raspberry Pi runs the streaming classifier on the input data received from MATLAB and transfers the computed scores for each label back to MATLAB on the host. For more details on audio preprocessing and network training, refer to Fault Detection Using Wavelet Scattering and Recurrent Deep Networks.
Generating code for wavelet time scattering can significantly improve its performance on ARM targets. For more information, see Generate and Deploy Optimized Code for Wavelet Time Scattering on ARM Targets.
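This example follows these steps:
- Prepare Input Data Set
- Recognize Machine Fault Detection in MATLAB
- Recognize Machine Fault Detection on Raspberry Pi Using PIL Workflow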
Prerequisites
- MATLAB® Coder™
- Embedded Coder®
- Raspberry Pi hardware
- ARM Compute Library version 20.02.1 (on the target ARM hardware)
- Environment variables for the compilers and libraries. For setting up the environment variables, see Environment Variables (MATLAB Coder).
For a list of supported compilers and libraries, see Generate Code That Uses Third-Party Libraries (MATLAB Coder).
Prepare Input Data Set
Download the data set [1] and unzip the data file in a folder where you have write permission. The recordings are stored as .wav files in folders named for the machine state they represent.
% Download AirCompressorDataset.zip
component = "audio";
filename = "AirCompressorDataset/AirCompressorDataset.zip";
localfile = matlab.internal.examples.downloadSupportFile(component,filename);

% Unzip the downloaded zip file to the downloadFolder
downloadFolder = fileparts(localfile);
if ~exist(fullfile(downloadFolder,"AirCompressorDataset"),"dir")
    unzip(localfile,downloadFolder)
end

% Create an audioDatastore object, dataStore, to manage the data
dataStore = audioDatastore(downloadFolder,IncludeSubfolders=true,LabelSource="foldernames");

% Use countEachLabel to get the number of recordings of each category in the data set
countEachLabel(dataStore)
ans=8×2 table
Label Count
_________ _____
Bearing 225
Flywheel 225
Healthy 225
LIV 225
LOV 225
NRV 225
Piston 225
Riderbelt 225
To classify the audio recordings, construct a wavelet scattering network that extracts wavelet scattering coefficients, and use those coefficients as the features for classification. Each record has 50,000 samples sampled at 16 kHz. Base the network settings on these data characteristics, and set the invariance scale to 0.5 seconds.
Fs = 16e3;
windowLength = 5e4;
IS = 0.5;
sn = waveletScattering(SignalLength=windowLength,SamplingFrequency=Fs, ...
    InvarianceScale=IS);
With these network settings, there are 330 scattering paths and 25 time windows per audio record. This leads to a sixfold reduction in the size of the data for each record.
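The sixfold figure comes from comparing the number of raw samples per record with the number of scattering coefficients per record. A minimal arithmetic check, using only the values reported in this example:

```matlab
% Raw samples per record versus scattering coefficients per record
numSamples = 5e4;   % samples per recording (SignalLength)
numPaths   = 330;   % total scattering paths, sum(npaths)
numWindows = 25;    % time windows per record, numCoefficients(sn)

numCfsTotal = numPaths*numWindows             % 8250 coefficients per record
reductionFactor = numSamples/numCfsTotal      % about 6.06
```

The variable is named numCfsTotal rather than numCoefficients so that it does not shadow the numCoefficients function used in the next step.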
[~,npaths] = paths(sn);
Ncfs = numCoefficients(sn);
sum(npaths)
ans = 330
Ncfs
Ncfs = 25
Shuffle dataStore, use splitEachLabel to select half of the recordings of each label, and read the selected recordings into signalToBeTested. You pass the recordings in signalToBeTested to the faultDetect function for classification.
rng default;
dataStore = shuffle(dataStore);
[InputFiles,~] = splitEachLabel(dataStore, 0.5);
signalToBeTested = readall(InputFiles);
Recognize Machine Fault Detection in MATLAB
The faultDetect function reads the input audio samples, calculates the wavelet scattering features, and performs deep learning classification. For more information, enter type faultDetect at the command line.
type faultDetect
function out = faultDetect(in) %#codegen
% Copyright 2022-2024 The MathWorks, Inc.

persistent net;
% persistent classNames;
if isempty(net)
    net = coder.loadDeepLearningNetwork("faultDetectNetwork.mat");
end

persistent sn;
if isempty(sn)
    windowLength = 5e4;
    Fs = 16e3;
    IS = 0.5;
    sn = waveletScattering(SignalLength=windowLength,SamplingFrequency=Fs, ...
        InvarianceScale=IS);
end

S = sn.featureMatrix(in,"transform","log");
TestFeatures = S(2:330,1:25); % Remove the 0-th order scattering coefficients
scores = predict(net,dlarray(single(TestFeatures),"CT"));
% out = scores2label(scores,classNames);
out = scores;
end
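For a single-channel record, featureMatrix returns one row per scattering path and one column per time window, that is, a 330-by-25 matrix. Row 1 holds the zeroth-order coefficients, so S(2:330,1:25) keeps 329 feature channels, each a sequence of 25 time steps, which is what the "CT" (channel, time) dlarray format describes. This sketch uses a random placeholder matrix in place of the real featureMatrix output to show the resulting shape:

```matlab
% Placeholder for the featureMatrix output: 330 paths x 25 time windows
S = randn(330,25);

% Drop row 1 (zeroth-order coefficients), as faultDetect does
TestFeatures = S(2:330,1:25);

% 329 channels (C) by 25 time steps (T), matching the "CT" dlarray format
[numChannels,numTimeSteps] = size(TestFeatures)
```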
Pass each audio input to faultDetect, which extracts the wavelet scattering coefficients and passes them to the LSTM-based RNN. The network returns a score for each of the eight health states, and scores2label converts these scores into the predicted health state for that input. For details on the network creation, refer to Fault Detection Using Wavelet Scattering and Recurrent Deep Networks.
inputCount = 1;
numInputs = 10; % Validate 10 audio input files
load("faultDetectNetwork.mat"); % Provides classNames for scores2label
while inputCount <= numInputs
    % Get a frame of audio data
    x = signalToBeTested{inputCount};
    % Apply streaming classifier function
    outputLabel(inputCount) = scores2label(faultDetect(x),classNames);
    inputCount = inputCount + 1;
end
scatter(1:numInputs,outputLabel,140,"filled")
xlabel("Audio Input");
ylabel("Machine Health Status");
title("Machine Health Status per Audio Input on Host")
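Inside the loop, scores2label turns the eight scores returned by faultDetect into one label. For a single observation, this amounts to picking the class with the highest score. The sketch below reproduces that step with max for illustration only; the class order and the score values are hypothetical placeholders (in the example, classNames comes from faultDetectNetwork.mat), and the example itself uses scores2label.

```matlab
% Hypothetical class order; the real order is stored in faultDetectNetwork.mat
classNamesDemo = categorical(["Bearing" "Flywheel" "Healthy" "LIV" ...
    "LOV" "NRV" "Piston" "Riderbelt"]);

% Hypothetical scores for one audio record (they sum to 1)
scoresDemo = single([0.02 0.01 0.85 0.03 0.02 0.03 0.02 0.02]);

% Pick the class with the largest score
[~,idx] = max(scoresDemo);
predictedLabel = classNamesDemo(idx)   % Healthy
```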
Recognize Machine Fault Detection on Raspberry Pi Using PIL Workflow
This section demonstrates code generation and deployment of machine fault detection using wavelet scattering and RNNs on Raspberry Pi hardware. Use a processor-in-the-loop (PIL) workflow for deployment and profiling. For more information, see SIL/PIL Manager Verification Workflow (Embedded Coder).
Create a code generation configuration object to generate the PIL function.
cfg = coder.config("lib","ecoder",true);
cfg.VerificationMode = "PIL";
Create a deep learning configuration object, dlcfg, for the ARM Compute Library ("arm-compute"). Set the ARM architecture and the ARM Compute Library version, and then attach dlcfg to the code generation configuration object.
dlcfg = coder.DeepLearningConfig("arm-compute");
dlcfg.ArmArchitecture = "armv7";
dlcfg.ArmComputeVersion = "20.02.1";
cfg.DeepLearningConfig = dlcfg;
Use the MATLAB Support Package for Raspberry Pi Hardware function raspi to create a connection to the Raspberry Pi. In this code, replace these placeholders:
- raspiname with the host name of your Raspberry Pi
- username with your user name
- password with your password
if ~exist("r","var")
    r = raspi("raspiname","username","password");
end
hw = coder.hardware("Raspberry Pi");
cfg.Hardware = hw;
Specify the build directory and set the target language to C++.
buildDir = "~/remoteBuildDir";
cfg.Hardware.BuildDir = buildDir;
cfg.TargetLang = "C++";
Enable profiling and generate the PIL code. A MEX file named faultDetect_pil is generated in your current folder.
cfg.CodeExecutionProfiling = true;
audioFrame = ones(windowLength,1);
codegen -config cfg faultDetect -args {audioFrame} -silent;
Deploying code. This may take a few minutes.
### Connectivity configuration for function 'faultDetect': 'Raspberry Pi'
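Because -args {audioFrame} defines the input as a windowLength-by-1 (50,000-by-1) double, every frame passed to faultDetect_pil must have exactly that size. All records in this data set already have 50,000 samples. If you adapt the example to recordings of other lengths, one option is to truncate or zero-pad each recording, as in this sketch. The helper fitToWindow is hypothetical, and zero padding changes the signal content, so scores for padded frames can differ from those for full-length recordings.

```matlab
% Hypothetical helper: return a windowLength-by-1 column vector by
% truncating long recordings and zero-padding short ones
fitToWindow = @(x,windowLength) [reshape(x(1:min(numel(x),windowLength)),[],1); ...
    zeros(max(windowLength - numel(x),0),1)];

windowLength = 5e4;
shortRec = randn(4e4,1);   % placeholder recording, 40,000 samples
longRec  = randn(6e4,1);   % placeholder recording, 60,000 samples

a = fitToWindow(shortRec,windowLength);   % zero-padded to 50,000 samples
b = fitToWindow(longRec,windowLength);    % truncated to 50,000 samples
```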
Call the generated PIL function from MATLAB to get the detected outputs and the execution time.
inputCount = 1;
numInputs = 10; % Validate 10 audio input files
load("faultDetectNetwork.mat");
while inputCount <= numInputs
    % Get a frame of audio data
    x = signalToBeTested{inputCount};
    % Apply streaming classifier function
    outputLabel(inputCount) = scores2label(faultDetect_pil(x),classNames);
    inputCount = inputCount + 1;
end
### Starting application: 'codegen\lib\faultDetect\pil\faultDetect.elf'
    To terminate execution: clear faultDetect_pil
### Launching application faultDetect.elf...
    Execution profiling data is available for viewing. Open Simulation Data Inspector.
    Execution profiling report will be available after termination.
scatter(1:numInputs,outputLabel,140,"filled")
xlabel("Audio Input")
ylabel("Machine Health Status")
title("Machine Health Status per Audio Input on Raspberry Pi")
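Because the files in InputFiles are labeled by folder name, you can also measure how many of the tested records were classified correctly. A minimal sketch, assuming a hypothetical anonymous helper named accuracyOf; the demonstration uses made-up labels so that it runs on its own:

```matlab
% Hypothetical helper: fraction of predictions that match the true labels.
% Comparing as strings avoids mismatched categorical category sets.
accuracyOf = @(predicted,truth) mean(string(predicted(:)) == string(truth(:)));

% Made-up labels for a self-contained demonstration
predDemo  = categorical(["Healthy" "LIV" "Piston" "Healthy"]);
truthDemo = categorical(["Healthy" "LOV" "Piston" "Healthy"]);
acc = accuracyOf(predDemo,truthDemo)   % 0.75
```

With the variables from this example, accuracyOf(outputLabel,InputFiles.Labels(1:numInputs)) compares the Raspberry Pi predictions with the folder-name labels of the first numInputs test files, because readall returns the recordings in the same order as InputFiles.Files. The same call applies to the host results from the previous section.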
Terminate the PIL execution.
clear faultDetect_pil;
### Host application produced the following standard output (stdout) and standard error (stderr) messages:

    Execution profiling report: coder.profile.show(getCoderExecutionProfile('faultDetect'))
Generate an execution profile report to evaluate execution time.
executionProfile = getCoderExecutionProfile("faultDetect");
report(executionProfile, ...
    "Units","Seconds", ...
    "ScaleFactor","1e-03", ...
    "NumericFormat","%0.4f");
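To judge whether the deployed classifier keeps up with incoming audio, you can compare the per-call execution time in the report with the duration of one record: 50,000 samples at 16 kHz, or 3.125 seconds. This sketch assumes that one call to faultDetect processes one full record; realTimeFactor is a hypothetical helper, not part of the profiling API.

```matlab
% Duration of one record: 50,000 samples at 16 kHz
Fs = 16e3;
windowLength = 5e4;
recordDuration = windowLength/Fs   % 3.125 seconds

% Hypothetical helper: ratio of processing time to audio duration.
% Values below 1 mean a record is processed faster than it lasts.
realTimeFactor = @(execTimeSeconds) execTimeSeconds/recordDuration;
```

Pass the per-call execution time that you read from the report, converted to seconds if the report displays it in milliseconds.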
Summary
In this example, you used the wavelet scattering transform with a simple recurrent network to classify faults in an air compressor, and then generated code to run the classifier on a Raspberry Pi. The scattering transform extracted robust features for the learning problem. Additionally, the reduction along the time dimension of the data that the wavelet scattering transform achieved was critical to making the problem computationally feasible for the recurrent network.
References
[1] Verma, Nishchal K., Rahul Kumar Sevakula, Sonal Dixit, and Al Salour. “Intelligent Condition Based Monitoring Using Acoustic Signals for Air Compressors.” IEEE Transactions on Reliability 65, no. 1 (March 2016): 291–309. https://doi.org/10.1109/TR.2015.2459684.
Copyright 2022-2024, The MathWorks, Inc.
See Also
Related Examples
- Air Compressor Fault Detection Using Wavelet Scattering
- Deploy Signal Classifier Using Wavelets and Deep Learning on Raspberry Pi
- Fault Detection Using Wavelet Scattering and Recurrent Deep Networks
- Generate and Deploy Optimized Code for Wavelet Time Scattering on ARM Targets