Main Content

本页翻译不是最新的。点击此处可查看最新英文版本。

使用深度学习对文本数据进行分类

此示例说明如何使用深度学习长短期记忆 (LSTM) 网络对文本数据进行分类。

文本数据本身就是有序的。一段文本是一个单词序列,这些单词之间可能存在依存关系。要学习和使用长期依存关系来对序列数据进行分类,请使用 LSTM 神经网络。LSTM 网络是一种循环神经网络 (RNN),可以学习序列数据的时间步之间的长期依存关系。

要将文本输入到 LSTM 网络,首先将文本数据转换为数值序列。您可以使用将文档映射为数值索引序列的单词编码来实现此目的。为了获得更好的结果,还要在网络中包含一个单词嵌入层。单词嵌入将词汇表中的单词映射为数值向量而不是标量索引。这些嵌入会捕获单词的语义细节,以便具有相似含义的单词具有相似的向量。它们还通过向量算术运算对单词之间的关系进行建模。例如,关系 "Rome is to Italy as Paris is to France" 通过公式 Italy Rome + Paris = France 进行描述。

在此示例中,训练和使用 LSTM 网络有四个步骤:

  • 导入并预处理数据。

  • 使用单词编码将单词转换为数值序列。

  • 创建和训练具有单词嵌入层的 LSTM 网络。

  • 使用经过训练的 LSTM 网络对新文本数据进行分类。

导入数据

导入工厂报告数据。该数据包含已标注的工厂事件文本描述。要将文本数据作为字符串导入,请将文本类型指定为 'string'

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

此示例的目标是按 Category 列中的标签对事件进行分类。要将数据划分到各个类,请将这些标签转换为分类。

data.Category = categorical(data.Category);

使用直方图查看数据中类的分布。

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

下一步是将其划分为训练集和验证集。将数据划分为训练分区和用于验证和测试的保留分区。将保留百分比指定为 20%。

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

从分区后的表中提取文本数据和标签。

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;

要检查是否已正确导入数据,请使用文字云将训练文本数据可视化。

figure
wordcloud(textDataTrain);
title("Training Data")

预处理文本数据

创建一个对文本数据进行分词和预处理的函数。在示例末尾列出的函数 preprocessText 执行以下步骤:

  1. 使用 tokenizedDocument 对文本进行分词。

  2. 使用 lower 将文本转换为小写。

  3. 使用 erasePunctuation 删除标点符号。

使用 preprocessText 函数预处理训练数据和验证数据。

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

查看前几个预处理的训练文档。

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses

将文档转换为序列

要将文档输入到 LSTM 网络中,请使用单词编码将文档转换为数值索引序列。

要创建单词编码,请使用 wordEncoding 函数。

enc = wordEncoding(documentsTrain);

下一个转换步骤是填充和截断文档,使全部文档的长度相同。trainingOptions 函数提供了自动填充和截断输入序列的选项。但是,这些选项不太适合单词向量序列。请改为手动填充和截断序列。如果对单词向量序列进行左填充和截断,训练效果可能会得到改善。

要填充和截断文档,请先选择目标长度,然后对长于它的文档进行截断,并对短于它的文档进行左填充。为获得最佳结果,目标长度应该较短,但又不至于丢弃大量数据。要找到合适的目标长度,请查看训练文档长度的直方图。

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

大多数训练文档的词数少于 10 个。将此数字用作截断和填充的目标长度。

使用 doc2sequence 将文档转换为数值索引序列。要对长度为 10 的序列进行截断或左填充,请将 'Length' 选项设置为 10。

sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}

使用相同选项将验证文档转换为序列。

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);

创建和训练 LSTM 网络

定义 LSTM 网络架构。要将序列数据输入到网络中,请包含一个序列输入层并将输入大小设置为 1。接下来,包含一个维度为 50 且与单词编码具有相同单词数的单词嵌入层。然后,包含一个 LSTM 层并将隐含单元个数设置为 80。要将该 LSTM 层用于“序列到标签”分类问题,请将输出模式设置为 'last'。最后,添加一个大小与类数相同的全连接层、一个 softmax 层和一个分类层。

inputSize = 1;
embeddingDimension = 50;
numHiddenUnits = 80;

numWords = enc.NumWords;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 50 dimensions and 423 unique words
     3   ''   LSTM                    LSTM with 80 hidden units
     4   ''   Fully Connected         4 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

指定训练选项

指定训练选项:

  • 使用 Adam 求解器进行训练。

  • 指定小批量大小为 16。

  • 每轮训练都会打乱数据。

  • 通过将 'Plots' 选项设置为 'training-progress',监控训练进度。

  • 使用 'ValidationData' 选项指定验证数据。

  • 通过将 'Verbose' 选项设置为 false,隐藏详尽输出。

默认情况下,trainNetwork 使用 GPU(如果有)。否则将使用 CPU。要手动指定执行环境,请使用 trainingOptions'ExecutionEnvironment' 名称-值对组参数。在 CPU 上进行训练所需的时间要明显长于在 GPU 上进行训练所需的时间。使用 GPU 进行训练需要 Parallel Computing Toolbox™ 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)

options = trainingOptions('adam', ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

使用 trainNetwork 函数训练 LSTM 网络。

net = trainNetwork(XTrain,YTrain,layers,options);

使用新数据进行预测

对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

使用与预处理训练文档相同的步骤来预处理文本数据。

documentsNew = preprocessText(reportsNew);

使用 doc2sequence 将文本数据转换为序列,所用选项与创建训练序列时的选项相同。

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

使用经过训练的 LSTM 网络对新序列进行分类。

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

预处理函数

函数 preprocessText 执行以下步骤:

  1. 使用 tokenizedDocument 对文本进行分词。

  2. 使用 lower 将文本转换为小写。

  3. 使用 erasePunctuation 删除标点符号。

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

另请参阅

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | (Text Analytics Toolbox) | | (Text Analytics Toolbox)

相关主题