使用深度学习对无法放入内存的文本数据进行分类
此示例说明在深度学习网络中,如何使用变换后的数据存储对无法放入内存的文本数据进行分类。
变换后的数据存储对从基础数据存储读取的数据进行变换或处理。您可以使用变换后的数据存储作为深度学习应用的训练数据集、验证数据集、测试数据集以及预测数据集的数据源。使用变换后的数据存储可读取无法放入内存的数据,或者在读取批量数据时执行特定的预处理操作。
在训练网络时,软件通过填充、截断或拆分输入数据来创建长度相同的小批量序列。trainingOptions
函数提供了填充和截断输入序列的选项,但是,这些选项不太适合单词向量序列。此外,此函数不支持在自定义数据存储中填充数据。您必须改为手动填充和截断序列。如果对单词向量序列进行左填充和截断,训练效果可能会得到改善。
Classify Text Data Using Deep Learning (Text Analytics Toolbox) 示例手动将所有文档截断和填充到相同长度。此过程为非常短的文档添加了大量填充,同时也丢弃了非常长的文档中的大量数据。
取而代之,为了防止添加过多填充或丢弃过多数据,请创建一个变换后的数据存储,以将小批量数据输入到网络中。此示例中创建的转换后的数据存储将小批量文档转换为序列或单词索引,并对每个小批量进行左填充,使其长度等于小批量中最长文档的长度。
加载预训练的单词嵌入
数据存储需要使用单词嵌入以将文档转换为向量序列。使用 fastTextWordEmbedding
加载预训练的单词嵌入。此函数需要 Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding 支持包。如果未安装此支持包,则函数会提供下载链接。
emb = fastTextWordEmbedding;
加载数据
根据 factoryReports.csv
中的数据创建一个表格文本数据存储。指定仅读取 "Description"
和 "Category"
列中的数据。
filenameTrain = "factoryReports.csv"; textName = "Description"; labelName = "Category"; ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);
查看数据存储的预览。
preview(ttdsTrain)
ans=8×2 table
Description Category
_______________________________________________________________________ ______________________
{'Items are occasionally getting stuck in the scanner spools.' } {'Mechanical Failure'}
{'Loud rattling and banging sounds are coming from assembler pistons.'} {'Mechanical Failure'}
{'There are cuts to the power when starting the plant.' } {'Electronic Failure'}
{'Fried capacitors in the assembler.' } {'Electronic Failure'}
{'Mixer tripped the fuses.' } {'Electronic Failure'}
{'Burst pipe in the constructing agent is spraying coolant.' } {'Leak' }
{'A fuse is blown in the mixer.' } {'Electronic Failure'}
{'Things continue to tumble off of the belt.' } {'Mechanical Failure'}
变换数据存储
创建一个自定义变换函数,该函数将从数据存储中读取的数据转换为包含预测变量和响应变量的表。transformText
函数获取从 tabularTextDatastore
对象读取的数据,并返回包含预测变量和响应变量的表。预测变量是由单词嵌入 emb
给出的 numFeatures
×lenSequence
单词向量数组,其中 numFeatures
是嵌入维数,lenSequence
是序列长度。响应变量是类的分类标签。
要获取类名,请使用在示例末尾列出的 readLabels
函数从训练数据中读取标签,并找出唯一的类名。
labels = readLabels(ttdsTrain,labelName); classNames = unique(labels); numObservations = numel(labels);
由于表格文本数据存储可以在一次读取中读取多行数据,因此您可以在变换函数中处理一个完整的小批量数据。为了确保变换函数处理完整的小批量数据,请将表格文本数据存储的读取大小设置为将用于训练的小批量大小。
miniBatchSize = 64; ttdsTrain.ReadSize = miniBatchSize;
要将表格文本数据的输出转换为训练序列,请使用 transform
函数变换数据存储。
tdsTrain = transform(ttdsTrain, @(data) transformText(data,emb,classNames))
tdsTrain = TransformedDatastore with properties: UnderlyingDatastores: {matlab.io.datastore.TabularTextDatastore} SupportedOutputFormats: ["txt" "csv" "dat" "asc" "xlsx" "xls" "parquet" "parq" "png" "jpg" "jpeg" "tif" "tiff" "wav" "flac" "ogg" "opus" "mp3" "mp4" "m4a"] Transforms: {@(data)transformText(data,emb,classNames)} IncludeInfo: 0
预览变换后的数据存储。预测变量是 numFeatures
×lenSequence
数组,其中 lenSequence
是序列长度,numFeatures
是特征数(嵌入维数)。响应变量是分类标签。
preview(tdsTrain)
ans=8×2 table
predictors responses
_______________ __________________
{300×11 single} Mechanical Failure
{300×11 single} Mechanical Failure
{300×11 single} Electronic Failure
{300×11 single} Electronic Failure
{300×11 single} Electronic Failure
{300×11 single} Leak
{300×11 single} Electronic Failure
{300×11 single} Mechanical Failure
创建和训练 LSTM 网络
定义 LSTM 网络架构。要将序列数据输入到网络中,请包含一个序列输入层并将输入大小设置为嵌入维度。接下来,添加一个具有 180 个隐含单元的 LSTM 层。要将该 LSTM 层用于“序列到标签”分类问题,请将输出模式设置为 'last'
。最后,添加一个输出大小等于类数的全连接层和一个 softmax 层。
numFeatures = emb.Dimension; numHiddenUnits = 180; numClasses = numel(classNames); layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer];
指定训练选项。在选项中进行选择需要经验分析。要通过运行试验探索不同训练选项配置,您可以使用Experiment Manager。
使用 Adam 优化器进行训练。
将输入数据格式设置为 'CTB'(通道、时间、批量)。
指定小批量大小。
将梯度阈值设置为
2
。数据存储不支持乱序,因此请将
'Shuffle'
设置为'never'
。在图中显示训练进度并监控准确度。
禁用详尽输出。
默认情况下,trainnet
使用 GPU(如果有)。要手动指定执行环境,请使用 trainingOptions
的 'ExecutionEnvironment'
名称-值对组参量。在 CPU 上进行训练所需的时间要明显长于在 GPU 上进行训练所需的时间。使用 GPU 进行训练需要 Parallel Computing Toolbox™ 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU 计算要求 (Parallel Computing Toolbox)。
numIterationsPerEpoch = floor(numObservations / miniBatchSize); options = trainingOptions('adam', ... 'MaxEpochs',15, ... 'InputDataFormats','CTB', ... 'MiniBatchSize',miniBatchSize, ... 'GradientThreshold',2, ... 'Shuffle','never', ... 'Plots','training-progress', ... 'Metrics','accuracy', ... 'Verbose',false);
使用 trainnet
函数训练神经网络。对于分类,使用交叉熵损失。
net = trainnet(tdsTrain,layers,"crossentropy",options);
使用新数据进行预测
对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。
reportsNew = [ ... "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
使用与预处理训练文档相同的步骤来预处理文本数据。
documentsNew = preprocessText(reportsNew);
使用 doc2sequence
将文本数据转换为嵌入向量序列。
XNew = doc2sequence(emb,documentsNew);
使用经过训练的 LSTM 网络对新序列进行分类。
scores = minibatchpredict(net,XNew,InputDataFormats="CTB");
Y = scores2label(scores,classNames)
Y = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
变换文本函数
transformText
函数获取从 tabularTextDatastore
对象读取的数据,并返回包含预测变量和响应变量的表。预测变量是由单词嵌入 emb
给出的 numFeatures
×lenSequence
单词向量数组,其中 numFeatures
是嵌入维数,lenSequence
是序列长度。这些响应变量是 classNames
中的类的分类标签。
function dataTransformed = transformText(data,emb,classNames) % Preprocess documents. textData = data{:,1}; documents = preprocessText(textData); % Convert to sequences. predictors = doc2sequence(emb,documents); % Read labels. labels = data{:,2}; responses = categorical(labels,classNames); % Convert data to table. dataTransformed = table(predictors,responses); end
预处理函数
函数 preprocessText
执行以下步骤:
使用
tokenizedDocument
对文本进行分词。使用
lower
将文本转换为小写。使用
erasePunctuation
删除标点符号。
function documents = preprocessText(textData) documents = tokenizedDocument(textData); documents = lower(documents); documents = erasePunctuation(documents); end
读取标签函数
readLabels
函数创建 tabularTextDatastore
对象 ttds
的一个副本,并读取 labelName
列中的标签。
function labels = readLabels(ttds,labelName) ttdsNew = copy(ttds); ttdsNew.SelectedVariableNames = labelName; tbl = readall(ttdsNew); labels = tbl.(labelName); end
另请参阅
trainnet
| trainingOptions
| dlnetwork
| fastTextWordEmbedding
(Text Analytics Toolbox) | wordEmbeddingLayer
(Text Analytics Toolbox) | doc2sequence
(Text Analytics Toolbox) | tokenizedDocument
(Text Analytics Toolbox) | lstmLayer
| sequenceInputLayer
| transform