Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

使用深度学习对无法放入内存的文本数据进行分类

此示例说明在深度学习网络中,如何使用变换后的数据存储对无法放入内存的文本数据进行分类。

变换后的数据存储对从基础数据存储读取的数据进行变换或处理。您可以使用变换后的数据存储作为深度学习应用的训练数据集、验证数据集、测试数据集以及预测数据集的数据源。使用变换后的数据存储可读取无法放入内存的数据,或者在读取批量数据时执行特定的预处理操作。

在训练网络时,软件通过填充、截断或拆分输入数据来创建长度相同的小批量序列。trainingOptions 函数提供了填充和截断输入序列的选项,但是,这些选项不太适合单词向量序列。此外,此函数不支持在自定义数据存储中填充数据。您必须改为手动填充和截断序列。如果对单词向量序列进行左填充和截断,训练效果可能会得到改善。

Classify Text Data Using Deep Learning (Text Analytics Toolbox) 示例手动将所有文档截断和填充到相同长度。此过程为非常短的文档添加了大量填充,同时也丢弃了非常长的文档中的大量数据。

取而代之,为了防止添加过多填充或丢弃过多数据,请创建一个变换后的数据存储,以将小批量数据输入到网络中。此示例中创建的转换后的数据存储将小批量文档转换为序列或单词索引,并对每个小批量进行左填充,使其长度等于小批量中最长文档的长度。

加载预训练的单词嵌入

数据存储需要使用单词嵌入以将文档转换为向量序列。使用 fastTextWordEmbedding 加载预训练的单词嵌入。此函数需要 Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding 支持包。如果未安装此支持包,则函数会提供下载链接。

emb = fastTextWordEmbedding;

加载数据

根据 factoryReports.csv 中的数据创建一个表格文本数据存储。指定仅读取 "Description""Category" 列中的数据。

filenameTrain = "factoryReports.csv";
textName = "Description";
labelName = "Category";
ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);

查看数据存储的预览。

preview(ttdsTrain)
ans=8×2 table
                                  Description                                         Category       
    _______________________________________________________________________    ______________________

    {'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}
    {'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}
    {'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}
    {'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}
    {'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}
    {'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }
    {'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}
    {'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}

变换数据存储

创建一个自定义变换函数,该函数将从数据存储中读取的数据转换为包含预测变量和响应的表。transformText 函数获取从 tabularTextDatastore 对象读取的数据,并返回包含预测变量和响应的表。预测变量是由单词嵌入 emb 给出的 C×S 单词向量数组,其中 C 是嵌入维数,S 是序列长度。响应是类的分类标签。

要获取类名,请使用在示例末尾列出的 readLabels 函数从训练数据中读取标签,并找出唯一的类名。

labels = readLabels(ttdsTrain,labelName);
classNames = unique(labels);
numObservations = numel(labels);

由于表格文本数据存储可以在一次读取中读取多行数据,因此您可以在变换函数中处理一个完整的小批量数据。为了确保变换函数处理完整的小批量数据,请将表格文本数据存储的读取大小设置为将用于训练的小批量大小。

miniBatchSize = 64;
ttdsTrain.ReadSize = miniBatchSize;

要将表格文本数据的输出转换为训练序列,请使用 transform 函数变换数据存储。

tdsTrain = transform(ttdsTrain, @(data) transformText(data,emb,classNames))
tdsTrain = 
  TransformedDatastore with properties:

       UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore]
    SupportedOutputFormats: ["txt"    "csv"    "xlsx"    "xls"    "parquet"    "parq"    "png"    "jpg"    "jpeg"    "tif"    "tiff"    "wav"    "flac"    "ogg"    "mp4"    "m4a"]
                Transforms: {@(data)transformText(data,emb,classNames)}
               IncludeInfo: 0

变换后的数据存储的预览。预测变量是 C×S 数组,其中 S 是序列长度,C 是特征数(嵌入维数)。响应是分类标签。

preview(tdsTrain)
ans=8×2 table
      predictors           responses     
    _______________    __________________

    {300×11 single}    Mechanical Failure
    {300×11 single}    Mechanical Failure
    {300×11 single}    Electronic Failure
    {300×11 single}    Electronic Failure
    {300×11 single}    Electronic Failure
    {300×11 single}    Leak              
    {300×11 single}    Electronic Failure
    {300×11 single}    Mechanical Failure

创建和训练 LSTM 网络

定义 LSTM 网络架构。要将序列数据输入到网络中,请包含一个序列输入层并将输入大小设置为嵌入维度。接下来,添加一个具有 180 个隐含单元的 LSTM 层。要将该 LSTM 层用于“序列到标签”分类问题,请将输出模式设置为 'last'。最后,添加一个输出大小等于类数的全连接层、一个 softmax 层和一个分类层。

numFeatures = emb.Dimension;
numHiddenUnits = 180;
numClasses = numel(classNames);
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

指定训练选项。指定求解器为 'adam' ,梯度阈值为 2。数据存储不支持乱序,因此请将 'Shuffle' 设置为 'never'。每轮训练后对网络进行一次验证。要监控训练进度,请将 'Plots' 选项设置为 'training-progress'。要隐藏详细输出,请将 'Verbose' 设置为 false

默认情况下,如果有 GPU 可用,trainNetwork 就会使用 GPU(需要 Parallel Computing Toolbox™ 和具有 3.0 或更高计算能力的支持 CUDA® 的 GPU)。否则将使用 CPU。要手动指定执行环境,请使用 trainingOptions'ExecutionEnvironment' 名称-值对组参数。在 CPU 上进行训练所需的时间要明显长于在 GPU 上进行训练所需的时间。

numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions('adam', ...
    'MaxEpochs',15, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

使用 trainNetwork 函数训练 LSTM 网络。

net = trainNetwork(tdsTrain,layers,options);

使用新数据进行预测

对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

使用与预处理训练文档相同的步骤来预处理文本数据。

documentsNew = preprocessText(reportsNew);

使用 doc2sequence 将文本数据转换为嵌入向量序列。

XNew = doc2sequence(emb,documentsNew);

使用经过训练的 LSTM 网络对新序列进行分类。

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

变换文本函数

transformText 函数获取从 tabularTextDatastore 对象读取的数据,并返回包含预测变量和响应的表。预测变量是由单词嵌入 emb 给出的 C×S 单词向量数组,其中 C 是嵌入维数,S 是序列长度。这些响应是 classNames 中的类的分类标签。

function dataTransformed = transformText(data,emb,classNames)

% Preprocess documents.
textData = data{:,1};
documents = preprocessText(textData);

% Convert to sequences.
predictors = doc2sequence(emb,documents);

% Read labels.
labels = data{:,2};
responses = categorical(labels,classNames);

% Convert data to table.
dataTransformed = table(predictors,responses);

end

预处理函数

函数 preprocessText 执行以下步骤:

  1. 使用 tokenizedDocument 对文本进行分词。

  2. 使用 lower 将文本转换为小写。

  3. 使用 erasePunctuation 删除标点符号。

function documents = preprocessText(textData)

documents = tokenizedDocument(textData);
documents = lower(documents);
documents = erasePunctuation(documents);

end

读取标签函数

readLabels 函数创建 tabularTextDatastore 对象 ttds 的一个副本,并读取 labelName 列中的标签。

function labels = readLabels(ttds,labelName)

ttdsNew = copy(ttds);
ttdsNew.SelectedVariableNames = labelName;
tbl = readall(ttdsNew);
labels = tbl.(labelName);

end

另请参阅

| | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

相关主题