Code Generation for LSTM Network That Classifies Text Data
This example shows how to generate generic C code for a pretrained long short-term memory (LSTM) network that classifies text data. The example generates a MEX function that makes predictions for each step of an input time series, and it demonstrates two approaches: the first uses a standard LSTM network, and the second leverages the stateful behavior of the same LSTM network. The example uses textual descriptions of factory events, each of which can be classified into one of four categories: Electronic Failure, Leak, Mechanical Failure, and Software Failure. For more information about the pretrained LSTM network, see Classify Text Data Using Deep Learning (Text Analytics Toolbox).
This example is supported on Mac®, Linux®, and Windows® platforms and is not supported in MATLAB Online.
Prepare Input
Load the wordEncoding MAT-file. This MAT-file stores the word encoding enc, which maps words to numeric indices. This encoding was created during training of the network. For more information, see Classify Text Data Using Deep Learning (Text Analytics Toolbox).
load("wordEncoding.mat");
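Conceptually, a word encoding behaves like a vocabulary in which each word's index is its position. The base-MATLAB sketch below illustrates only that idea, using a small hypothetical vocabulary; it does not read the enc object, whose words and indices come from training. For the real object, the Text Analytics Toolbox functions word2ind and ind2word perform these lookups.

```matlab
% Hypothetical vocabulary for illustration only; the real encoding in
% wordEncoding.mat is larger and was built from the training data.
vocab = ["coolant" "is" "pooling" "underneath" "sorter" "mixer" "output" "stuck"];

% Word -> index: the index of a word is its position in the vocabulary.
idx = find(vocab == "sorter");

% Index -> word: indexing the vocabulary recovers the word.
word = vocab(idx);

fprintf("'%s' has index %d\n", word, idx);   % prints: 'sorter' has index 5
```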
Create a string array containing the new reports whose event types you want to classify.
reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."
    "At times mechanical arrangement software freezes."
    "Mixer output is stuck."];
Tokenize the input text by using the preprocessText helper function.
documentsNew = preprocessText(reportsNew);
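The preprocessText helper is not listed in this example. As a rough, assumed illustration of the kind of normalization such a helper performs, the following base-MATLAB stand-in lowercases each report, removes punctuation, and splits on whitespace. These steps are assumptions, and the stand-in returns plain string arrays, whereas doc2sequence works on tokenizedDocument arrays.

```matlab
% Hypothetical stand-in for preprocessText that uses only base MATLAB.
reports = ["Coolant is pooling underneath sorter."
           "Mixer output is stuck."];

tokens = cell(numel(reports), 1);
for k = 1:numel(reports)
    txt = lower(reports(k));                 % normalize case
    txt = regexprep(txt, "[^\w\s]", "");     % drop punctuation characters
    tokens{k} = split(strtrim(txt)).';       % split on whitespace into a 1-by-N string row
end

disp(tokens{1})
```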
Use the doc2sequence (Text Analytics Toolbox) function to convert the documents to sequences of word indices, and define the class labels.
XNew = doc2sequence(enc,documentsNew);
labels = categorical({'Electronic Failure', 'Leak', 'Mechanical Failure', 'Software Failure'});
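To make the shape of XNew concrete, the following sketch converts tokenized documents into a cell array of 1-by-S index vectors by looking each word up in a hypothetical vocabulary with ismember. It discards words that are not in the vocabulary and does no padding; it does not attempt to reproduce the options of doc2sequence itself. The differing lengths of the resulting sequences are why the entry-point input is later declared with a variable sequence length.

```matlab
% Hypothetical vocabulary and tokenized documents (stand-ins for enc and documentsNew).
vocab = ["coolant" "is" "pooling" "underneath" "sorter" "mixer" "output" "stuck"];
docs = {["coolant" "is" "pooling" "underneath" "sorter"], ...
        ["mixer" "output" "is" "totally" "stuck"]};

seqs = cell(size(docs));
for k = 1:numel(docs)
    [inVocab, loc] = ismember(docs{k}, vocab);  % loc(j) is the vocabulary index of word j
    seqs{k} = loc(inVocab);                     % discard words missing from the vocabulary
end

% Each sequence is a 1-by-S row vector, and S differs between documents.
cellfun(@numel, seqs)
```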
The lstm_predict Entry-Point Function
A sequence-to-sequence LSTM network enables you to make different predictions for each individual time step of a data sequence. The lstm_predict.m entry-point function takes an input sequence and passes it to a trained LSTM network for prediction. Specifically, the function uses the LSTM network that was trained in the example Classify Text Data Using Deep Learning (Text Analytics Toolbox). On the first call, the function loads the network object from the textClassifierNetwork.mat file into a persistent variable and then performs prediction. On subsequent calls, the function reuses the persistent object instead of reloading the network.
type('lstm_predict.m')
function out = lstm_predict(in) %#codegen
% Copyright 2020-2024 The MathWorks, Inc.
dlIn = dlarray(in,'CT');
persistent dlnet;
if isempty(dlnet)
    dlnet = coder.loadDeepLearningNetwork('textClassifierNetwork.mat');
end
dlOut = predict(dlnet, dlIn);
out = extractdata(dlOut);
end
To display an interactive visualization of the network architecture and information about the network layers, use the analyzeNetwork function.
Generate MEX
To generate code, create a code configuration object for a MEX target and set the target language to C. Use the coder.DeepLearningConfig function to create a deep learning configuration object that does not depend on third-party libraries, and assign it to the DeepLearningConfig property of the code configuration object.
cfg = coder.config('mex');
cfg.TargetLang = 'C';
cfg.IntegrityChecks = false;
cfg.DeepLearningConfig = coder.DeepLearningConfig(TargetLibrary = 'none');
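Use the coder.typeof (MATLAB Coder) function to specify the type and size of the input argument to the entry-point function. In this example, the input is of type single, with a feature dimension of 1 and a variable sequence length: the size vector [1 Inf] leaves the sequence length unbounded, and [false true] marks only the second (sequence) dimension as variable-size.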
matrixInput = coder.typeof(single(0),[1 Inf],[false true]);
Generate a MEX function by running the codegen (MATLAB Coder) command.
codegen -config cfg lstm_predict -args {matrixInput} -report
Code generation successful: View report
Run Generated MEX
Call lstm_predict_mex on the first observation.
YPred1 = lstm_predict_mex(single(XNew{1}));
YPred1 contains the probabilities for the four classes. Find the predicted class by calculating the index of the maximum probability.
[~, maxIndex] = max(YPred1);
Map the index of the maximum probability to the corresponding label and display the classification. The result shows that the network predicts the first event to be a Leak.
predictedLabels1 = labels(maxIndex);
disp(predictedLabels1)
Leak
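The same argmax-to-label mapping extends to several observations at once when the class probabilities are arranged with one column per observation. The probabilities below are made-up numbers for illustration, not outputs of the network.

```matlab
labels = categorical({'Electronic Failure', 'Leak', 'Mechanical Failure', 'Software Failure'});

% Made-up class probabilities: one row per class, one column per observation.
P = single([0.05 0.70 0.10
            0.80 0.10 0.05
            0.10 0.15 0.05
            0.05 0.05 0.80]);

[~, maxIndex] = max(P, [], 1);   % row index of the largest probability in each column
predicted = labels(maxIndex);    % one categorical label per observation
disp(predicted)
```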
Generate MEX with Stateful LSTM
Instead of passing the entire time series to predict in one call, you can stream the input one time step at a time and update the state of the dlnetwork object after each step. The predict function returns the updated network state as a second output, along with the prediction. The lstm_predict_and_update entry-point function takes a single-time-step input and updates the network state so that subsequent inputs are treated as subsequent time steps of the same sequence. After you pass in all time steps one at a time, the final output is the same as if you had passed in all time steps as a single input.
type('lstm_predict_and_update.m')
function out = lstm_predict_and_update(in) %#codegen
% Copyright 2020-2024 The MathWorks, Inc.
dlIn = dlarray(in,'CT');
persistent dlnet;
if isempty(dlnet)
    dlnet = coder.loadDeepLearningNetwork('textClassifierNetwork.mat');
end
[dlOut, updatedState] = predict(dlnet, dlIn);
dlnet.State = updatedState;
out = extractdata(dlOut);
end
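The following base-MATLAB sketch illustrates why streaming one time step at a time reproduces the single-call result. It runs a toy LSTM layer with random, hypothetical weights (three hidden units, with gates stacked in the order input, forget, candidate, output for this sketch only) over the same sequence in three ways: as one input, one step per "call" with the hidden and cell states carried over, and one step per call with the states reset each time. It illustrates the state-passing idea only and is not the trained network or the generated code.

```matlab
% Toy LSTM layer with random (hypothetical) weights, NOT the trained network.
rng(0);                               % reproducible weights
H = 3;                                % hypothetical number of hidden units
W = randn(4*H, 1);                    % input weights, gates stacked as [input; forget; candidate; output]
R = randn(4*H, H);                    % recurrent weights
b = randn(4*H, 1);                    % biases
sig = @(z) 1 ./ (1 + exp(-z));        % logistic sigmoid

x = [3 17 42 8 5] / 50;               % hypothetical, scaled 1-by-S input sequence

% Mode 1: whole sequence in one call. Mode 2: one step per call, state kept.
% Mode 3: one step per call, state reset before every call.
feeds = {{x}, num2cell(x), num2cell(x)};
resetBetweenCalls = [false false true];
hOut = zeros(H, 3);

for m = 1:3
    h = zeros(H, 1);                  % hidden state
    c = zeros(H, 1);                  % cell state
    for k = 1:numel(feeds{m})         % each k is one "call"
        if resetBetweenCalls(m)
            h = zeros(H, 1);
            c = zeros(H, 1);
        end
        chunk = feeds{m}{k};
        for t = 1:numel(chunk)        % time steps inside this call
            z = W*chunk(t) + R*h + b;
            inGate     = sig(z(1:H));
            forgetGate = sig(z(H+1:2*H));
            candidate  = tanh(z(2*H+1:3*H));
            outGate    = sig(z(3*H+1:4*H));
            c = forgetGate.*c + inGate.*candidate;
            h = outGate.*tanh(c);
        end
    end
    hOut(:, m) = h;                   % final hidden state for this mode
end

fprintf("single call vs. streamed with state kept:  identical = %d\n", isequal(hOut(:,1), hOut(:,2)));
fprintf("single call vs. streamed with state reset: identical = %d\n", isequal(hOut(:,1), hOut(:,3)));
```

Carrying the state makes the streamed result match the single-call result exactly, because the same operations run in the same order; resetting it does not. For the same reason, the state has to be reset before you stream an unrelated sequence, which is what clearing lstm_predict_and_update_mex after the loop does.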
Generate code for lstm_predict_and_update. Because this function accepts a single time step at each call, specify matrixInput to have a fixed sequence dimension of 1 instead of a variable sequence length.
matrixInput = coder.typeof(single(0),[1 1]);
codegen -config cfg lstm_predict_and_update -args {matrixInput} -report
Code generation successful: View report
Run the generated MEX function on the first observation, passing in one time step per call.
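% Stream the first observation into the MEX function one time step at a time.
sequenceLength = size(XNew{1},2);
for i = 1:sequenceLength
    inTimeStep = XNew{1}(:,i);
    YPred3 = lstm_predict_and_update_mex(single(inTimeStep));
end
% Clear the MEX function to reset its persistent network state.
clear lstm_predict_and_update_mex;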
Find the index that has the highest probability and map it to the labels.
[~, maxIndex] = max(YPred3);
predictedLabels3 = labels(maxIndex);
disp(predictedLabels3)
Leak
See Also
coder.DeepLearningConfig (MATLAB Coder) | doc2sequence (Text Analytics Toolbox) | coder.typeof (MATLAB Coder) | codegen (MATLAB Coder)
Related Topics
- Classify Text Data Using Deep Learning (Text Analytics Toolbox)
- Prerequisites for Deep Learning with MATLAB Coder (MATLAB Coder)