本页面提供的是上一版软件的文档。当前版本中已删除对应的英文页面。

使用深度学习对文本数据进行分类

此示例说明如何使用深度学习长短期记忆 (LSTM) 网络对天气报告的文本描述进行分类。

文本数据本身就是有序的。一段文本是一个单词序列,这些单词之间可能存在依存关系。要学习和使用长期依存关系来对序列数据进行分类,请使用 LSTM 神经网络。LSTM 网络是一种循环神经网络 (RNN),可以学习序列数据的时间步之间的长期依存关系。

要将文本输入到 LSTM 网络,首先将文本数据转换为数值序列。您可以使用将文档映射为数值索引序列的单词编码来实现此目的。为了获得更好的结果,还要在网络中包含一个单词嵌入层。单词嵌入将词汇表中的单词映射为数值向量而不是标量索引。这些嵌入会捕获单词的语义细节,以便具有相似含义的单词具有相似的向量。它们还通过向量算术运算对单词之间的关系进行建模。例如,关系 "king is to queen as man is to woman" 通过公式 king man + woman = queen 进行描述。

在此示例中,训练和使用 LSTM 网络有四个步骤:

  • 导入并预处理数据。

  • 使用单词编码将单词转换为数值序列。

  • 创建和训练具有单词嵌入层的 LSTM 网络。

  • 使用经过训练的 LSTM 网络对新文本数据进行分类。

导入数据

导入天气报告数据。该数据包含带标签的天气事件文本描述。要将文本数据作为字符串导入,请将文本类型指定为 'string'

filename = "weatherReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×16 table
            Time             event_id          state              event_type         damage_property    damage_crops    begin_lat    begin_lon    end_lat    end_lon                                                                                             event_narrative                                                                                             storm_duration    begin_day    end_day    year       end_timestamp    
    ____________________    __________    ________________    ___________________    _______________    ____________    _________    _________    _______    _______    _________________________________________________________________________________________________________________________________________________________________________________________________    ______________    _________    _______    ____    ____________________

    22-Jul-2016 16:10:00    6.4433e+05    "MISSISSIPPI"       "Thunderstorm Wind"       ""                "0.00K"         34.14        -88.63     34.122     -88.626    "Large tree down between Plantersville and Nettleton."                                                                                                                                                  00:05:00          22          22       2016    22-Jul-0016 16:15:00
    15-Jul-2016 17:15:00    6.5182e+05    "SOUTH CAROLINA"    "Heavy Rain"              "2.00K"           "0.00K"         34.94        -81.03      34.94      -81.03    "One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water."       00:00:00          15          15       2016    15-Jul-0016 17:15:00
    15-Jul-2016 17:25:00    6.5183e+05    "SOUTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.01        -80.93      35.01      -80.93    "NWS Columbia relayed a report of trees blown down along Tom Hall St."                                                                                                                                  00:00:00          15          15       2016    15-Jul-0016 17:25:00
    16-Jul-2016 12:46:00    6.5183e+05    "NORTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.64        -82.14      35.64      -82.14    "Media reported two trees blown down along I-40 in the Old Fort area."                                                                                                                                  00:00:00          16          16       2016    16-Jul-0016 12:46:00
    15-Jul-2016 14:28:00    6.4332e+05    "MISSOURI"          "Hail"                    ""                ""              36.45        -89.97      36.45      -89.97    ""                                                                                                                                                                                                      00:07:00          15          15       2016    15-Jul-0016 14:35:00
    15-Jul-2016 16:31:00    6.4332e+05    "ARKANSAS"          "Thunderstorm Wind"       ""                "0.00K"         35.85         -90.1     35.838     -90.087    "A few tree limbs greater than 6 inches down on HWY 18 in Roseland."                                                                                                                                    00:09:00          15          15       2016    15-Jul-0016 16:40:00
    15-Jul-2016 16:03:00    6.4343e+05    "TENNESSEE"         "Thunderstorm Wind"       "20.00K"          "0.00K"        35.056       -89.937      35.05     -89.904    "Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins."                                                                                     00:07:00          15          15       2016    15-Jul-0016 16:10:00
    15-Jul-2016 17:27:00    6.4344e+05    "TENNESSEE"         "Hail"                    ""                ""             35.385        -89.78     35.385      -89.78    "Quarter size hail near Rosemark."                                                                                                                                                                      00:05:00          15          15       2016    15-Jul-0016 17:32:00

删除具有空报告的表行。

idxEmpty = strlength(data.event_narrative) == 0;
data(idxEmpty,:) = [];

此示例的目标是按 event_type 列中的标签对事件进行分类。要将数据划分到各个类,请将这些标签转换为分类。

data.event_type = categorical(data.event_type);

使用直方图查看数据中类的分布。为了使标签更易于阅读,请增加图窗的宽度。

f = figure;
f.Position(3) = 1.5*f.Position(3);

h = histogram(data.event_type);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

数据的类是不均衡的,许多类只包含很少的观测值。当类像这样呈现不均衡时,网络可能会收敛到不太准确的模型。为防止出现此问题,请删除任何出现少于十次的类。

从直方图中获取类和类名的频率计数。

classCounts = h.BinCounts;
classNames = h.Categories;

查找包含少于十个观测值的类。

idxLowCounts = classCounts < 10;
infrequentClasses = classNames(idxLowCounts)
infrequentClasses = 1×8 cell array
    {'Freezing Fog'}    {'Hurricane'}    {'Lakeshore Flood'}    {'Marine Dense Fog'}    {'Marine Strong Wind'}    {'Marine Tropical Depression'}    {'Seiche'}    {'Sneakerwave'}

从数据中删除这些稀少的类。使用 removecats 从分类数据中删除未使用的分类。

idxInfrequent = ismember(data.event_type,infrequentClasses);
data(idxInfrequent,:) = [];
data.event_type = removecats(data.event_type);

现在,数据已分类到大小合理的类中。下一步是将其划分为训练集、验证集和测试集。将数据划分为训练分区和用于验证和测试的保留分区。将保留百分比指定为 30%。

cvp = cvpartition(data.event_type,'Holdout',0.3);
dataTrain = data(training(cvp),:);
dataHeldOut = data(test(cvp),:);

再次划分保留集以获得验证集。将保留百分比指定为 50%。由此得到 70% 训练观测值、15% 验证观测值和 15% 测试观测值的划分。

cvp = cvpartition(dataHeldOut.event_type,'HoldOut',0.5);
dataValidation = dataHeldOut(training(cvp),:);
dataTest = dataHeldOut(test(cvp),:);

从分区后的表中提取文本数据和标签。

textDataTrain = dataTrain.event_narrative;
textDataValidation = dataValidation.event_narrative;
textDataTest = dataTest.event_narrative;
YTrain = dataTrain.event_type;
YValidation = dataValidation.event_type;
YTest = dataTest.event_type;

要检查是否已正确导入数据,请使用文字云将训练文本数据可视化。

figure
wordcloud(textDataTrain);
title("Training Data")

预处理文本数据

创建一个对文本数据进行分词和预处理的函数。在示例末尾列出的函数 preprocessText 执行以下步骤:

  1. 使用 tokenizedDocument 对文本进行分词。

  2. 使用 lower 将文本转换为小写。

  3. 使用 erasePunctuation 删除标点符号。

使用 preprocessText 函数预处理训练数据和验证数据。

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

查看前几个预处理的训练文档。

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     7 tokens: large tree down between plantersville and nettleton
    37 tokens: one to two feet of deep standing water developed on a street on the winthrop university campus after more than an inch of rain fell in less than an hour one vehicle was stalled in the water
    13 tokens: nws columbia relayed a report of trees blown down along tom hall st
    13 tokens: media reported two trees blown down along i40 in the old fort area
    14 tokens: a few tree limbs greater than 6 inches down on hwy 18 in roseland

将文档转换为序列

要将文档输入到 LSTM 网络中,请使用单词编码将文档转换为数值索引序列。

要创建单词编码,请使用 wordEncoding 函数。

enc = wordEncoding(documentsTrain);

下一个转换步骤是填充和截断文档,使全部文档的长度相同。trainingOptions 函数提供了自动填充和截断输入序列的选项。但是,这些选项不太适合单词向量序列。请改为手动填充和截断序列。如果对单词向量序列进行左填充和截断,训练效果可能会得到改善。

要填充和截断文档,请先选择目标长度,然后对长于它的文档进行截断,并对短于它的文档进行左填充。为获得最佳结果,目标长度应该较短,但又不至于丢弃大量数据。要找到合适的目标长度,请查看训练文档长度的直方图。

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

大多数训练文档的词数少于 75 个。将此数字用作截断和填充的目标长度。

使用 doc2sequence 将文档转换为数值索引序列。要对长度为 75 的序列进行截断或左填充,请将 'Length' 选项设置为 75。

XTrain = doc2sequence(enc,documentsTrain,'Length',75);
XTrain(1:5)
ans=5×1 cell
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

使用相同选项将验证文档转换为序列。

XValidation = doc2sequence(enc,documentsValidation,'Length',75);

创建和训练 LSTM 网络

定义 LSTM 网络架构。要将序列数据输入到网络中,请包含一个序列输入层并将输入大小设置为 1。接下来,包含一个维度为 100 且与单词编码具有相同单词数的单词嵌入层。然后,包含一个 LSTM 层并将隐含单元个数设置为 180。要将该 LSTM 层用于“序列到标签”分类问题,请将输出模式设置为 'last'。最后,添加一个大小与类数相同的全连接层、一个 softmax 层和一个分类层。

inputSize = 1;
embeddingDimension = 100;
numWords = enc.NumWords;
numHiddenUnits = 180;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 100 dimensions and 16954 unique words
     3   ''   LSTM                    LSTM with 180 hidden units
     4   ''   Fully Connected         39 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

指定训练选项。将求解器设置为 'adam',进行 10 轮训练,并将梯度阈值设置为 1。将初始学习率设置为 0.01。要监控训练进度,请将 'Plots' 选项设置为 'training-progress'。使用 'ValidationData' 选项指定验证数据。要隐藏详细输出,请将 'Verbose' 设置为 false

默认情况下,如果有 GPU 可用,trainNetwork 就会使用 GPU(需要 Parallel Computing Toolbox™ 和具有 3.0 或更高计算能力的支持 CUDA® 的 GPU)。否则将使用 CPU。要手动指定执行环境,请使用 trainingOptions'ExecutionEnvironment' 名称-值对组参数。在 CPU 上进行训练所需的时间要明显长于在 GPU 上进行训练所需的时间。

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...    
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

使用 trainNetwork 函数训练 LSTM 网络。

net = trainNetwork(XTrain,YTrain,layers,options);

测试 LSTM 网络

要测试 LSTM 网络,请先使用与准备训练数据相同的方式准备测试数据。然后使用经过训练的 LSTM 网络 net 对预处理的测试数据进行预测。

使用与预处理训练文档相同的步骤来预处理测试数据。

textDataTest = lower(textDataTest);
documentsTest = tokenizedDocument(textDataTest);
documentsTest = erasePunctuation(documentsTest);

使用 doc2sequence 将测试文档转换为序列,所用选项与创建训练序列时的选项相同。

XTest = doc2sequence(enc,documentsTest,'Length',75);
XTest(1:5)
ans=5×1 cell
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

使用经过训练的 LSTM 网络对测试文档进行分类。

YPred = classify(net,XTest);

计算分类准确度。准确度是网络预测正确的标签的比例。

accuracy = sum(YPred == YTest)/numel(YPred)
accuracy = 0.8684

使用新数据进行预测

对三个新天气报告的事件类型进行分类。创建包含新天气报告的字符串数组。

reportsNew = [ ...
    "Lots of water damage to computer equipment inside the office."
    "A large tree is downed and blocking traffic outside Apple Hill."
    "Damage to many car windshields in parking lot."];

使用与预处理训练文档相同的步骤来预处理文本数据。

documentsNew = preprocessText(reportsNew);

使用 doc2sequence 将文本数据转换为序列,所用选项与创建训练序列时的选项相同。

XNew = doc2sequence(enc,documentsNew,'Length',75);

使用经过训练的 LSTM 网络对新序列进行分类。

[labelsNew,score] = classify(net,XNew);

显示天气报告及其预测标签。

[reportsNew string(labelsNew)]
ans = 3×2 string array
    "Lots of water damage to computer equipment inside the office."      "Flash Flood"      
    "A large tree is downed and blocking traffic outside Apple Hill."    "Thunderstorm Wind"
    "Damage to many car windshields in parking lot."                     "Hail"             

预处理函数

函数 preprocessText 执行以下步骤:

  1. 使用 tokenizedDocument 对文本进行分词。

  2. 使用 lower 将文本转换为小写。

  3. 使用 erasePunctuation 删除标点符号。

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

另请参阅

| | | | | | | |

相关主题