使用卷积神经网络对文本数据进行分类
此示例说明如何使用卷积神经网络对文本数据进行分类。
要使用卷积对文本数据进行分类,请使用在输入的时间维度上进行卷积的一维卷积层。
此示例训练具有不同宽度的一维卷积滤波器的网络。每个滤波器的宽度对应于滤波器可以检测到的单词数(n 元分词长度)。网络有多个卷积层分支,因此它可以使用不同 n 元分词长度。
加载数据
根据 factoryReports.csv
中的数据创建一个表格文本数据存储,并查看前几个报告。
data = readtable("factoryReports.csv");
head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_______________________________________________________________________ ______________________ __________ ______________________ _____
{'Items are occasionally getting stuck in the scanner spools.' } {'Mechanical Failure'} {'Medium'} {'Readjust Machine' } 45
{'Loud rattling and banging sounds are coming from assembler pistons.'} {'Mechanical Failure'} {'Medium'} {'Readjust Machine' } 35
{'There are cuts to the power when starting the plant.' } {'Electronic Failure'} {'High' } {'Full Replacement' } 16200
{'Fried capacitors in the assembler.' } {'Electronic Failure'} {'High' } {'Replace Components'} 352
{'Mixer tripped the fuses.' } {'Electronic Failure'} {'Low' } {'Add to Watch List' } 55
{'Burst pipe in the constructing agent is spraying coolant.' } {'Leak' } {'High' } {'Replace Components'} 371
{'A fuse is blown in the mixer.' } {'Electronic Failure'} {'Low' } {'Replace Components'} 441
{'Things continue to tumble off of the belt.' } {'Mechanical Failure'} {'Low' } {'Readjust Machine' } 38
将数据划分为训练分区和验证分区。将 80% 的数据用于训练,其余的数据用于验证。
cvp = cvpartition(data.Category,Holdout=0.2); dataTrain = data(training(cvp),:); dataValidation = data(test(cvp),:);
预处理文本数据
从表的 "Description"
列中提取文本数据,并使用示例的预处理文本函数一节中列出的 preprocessText
函数对其进行预处理。
documentsTrain = preprocessText(dataTrain.Description);
从 "Category"
列中提取标签,并将其转换为分类。
TTrain = categorical(dataTrain.Category);
查看类名称和观测值数目。
classNames = unique(TTrain)
classNames = 4×1 categorical
Electronic Failure
Leak
Mechanical Failure
Software Failure
numObservations = numel(TTrain)
numObservations = 384
使用相同的步骤提取和预处理验证数据。
documentsValidation = preprocessText(dataValidation.Description); TValidation = categorical(dataValidation.Category);
将文档转换为序列
要将文档输入到神经网络中,请使用单词编码将文档转换为数值索引序列。
从文档创建单词编码。
enc = wordEncoding(documentsTrain);
查看单词编码的词汇量。词汇量是单词编码中具有唯一行的单词的数量。
numWords = enc.NumWords
numWords = 436
使用 doc2sequence
函数将文档转换为整数序列。
XTrain = doc2sequence(enc,documentsTrain);
使用从训练数据创建的单词编码将验证文档转换为序列。
XValidation = doc2sequence(enc,documentsValidation);
定义网络架构
为分类任务定义网络架构。
以下步骤说明如何定义网络架构。
将输入大小指定为 1,它对应于整数序列输入的通道维度。
使用维度为 100 的单词嵌入来嵌入输入。
对于 n 元分词长度 2、3、4 和 5,创建包含卷积层、批量归一化层、ReLU 层、丢弃层和最大池化层的层模块。
对于每个模块,指定 200 个大小为 1×N 的卷积滤波器和一个全局最大池化层。
将输入层连接到每个模块,并使用串联层串联各模块的输出。
要对输出进行分类,请包括一个输出大小为 K 的全连接层、一个 softmax 层和一个分类层,其中 K 是类的数量。
指定网络超参数。
embeddingDimension = 100; ngramLengths = [2 3 4 5]; numFilters = 200;
首先,创建一个包含输入层和一个维度为 100 的单词嵌入层的层图。为了帮助将单词嵌入层连接到卷积层,请将单词嵌入层名称设置为 "emb"
。要检查卷积层在训练期间不将序列卷积为零长度,请将 MinLength
选项设置为训练数据中最短序列的长度。
minLength = min(doclength(documentsTrain));
layers = [
sequenceInputLayer(1,MinLength=minLength)
wordEmbeddingLayer(embeddingDimension,numWords,Name="emb")];
lgraph = layerGraph(layers);
对于每个 n 元分词长度,创建一个由一维卷积层、批量归一化层、ReLU 层、丢弃层和一维全局最大池化层构成的模块。将每个模块连接到单词嵌入层。
numBlocks = numel(ngramLengths); for j = 1:numBlocks N = ngramLengths(j); block = [ convolution1dLayer(N,numFilters,Name="conv"+N,Padding="same") batchNormalizationLayer(Name="bn"+N) reluLayer(Name="relu"+N) dropoutLayer(0.2,Name="drop"+N) globalMaxPooling1dLayer(Name="max"+N)]; lgraph = addLayers(lgraph,block); lgraph = connectLayers(lgraph,"emb","conv"+N); end
添加串联层、全连接层、softmax 层和分类层。
numClasses = numel(classNames); layers = [ concatenationLayer(1,numBlocks,Name="cat") fullyConnectedLayer(numClasses,Name="fc") softmaxLayer(Name="soft") classificationLayer(Name="classification")]; lgraph = addLayers(lgraph,layers);
将全局最大池化层连接到串联层,并在绘图中查看网络架构。
for j = 1:numBlocks N = ngramLengths(j); lgraph = connectLayers(lgraph,"max"+N,"cat/in"+j); end figure plot(lgraph) title("Network Architecture")
训练网络
指定训练选项:
使用小批量大小 128 进行训练。
使用验证数据验证网络。
返回验证损失最低的网络。
显示训练进度图并隐藏详尽输出。
options = trainingOptions("adam", ... MiniBatchSize=128, ... ValidationData={XValidation,TValidation}, ... OutputNetwork="best-validation-loss", ... Plots="training-progress", ... Verbose=false);
使用 trainNetwork
函数训练网络。
net = trainNetwork(XTrain,TTrain,lgraph,options);
测试网络
使用经过训练的网络对验证数据进行分类。
YValidation = classify(net,XValidation);
在混淆图中可视化预测。
figure confusionchart(TValidation,YValidation)
计算分类准确度。准确度是预测正确的标签的比例。
accuracy = mean(TValidation == YValidation)
accuracy = 0.9375
使用新数据进行预测
对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。
reportsNew = [ "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
使用与预处理训练和验证文档相同的步骤来预处理文本数据。
documentsNew = preprocessText(reportsNew); XNew = doc2sequence(enc,documentsNew);
使用经过训练的网络对新序列进行分类。
YNew = classify(net,XNew)
YNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
预处理文本函数
preprocessTextData
函数接受文本数据作为输入并执行以下步骤:
对文本进行分词。
将文本转换为小写。
function documents = preprocessText(textData) documents = tokenizedDocument(textData); documents = lower(documents); end
另请参阅
fastTextWordEmbedding
(Text Analytics Toolbox) | wordcloud
(Text Analytics Toolbox) | wordEmbedding
(Text Analytics Toolbox) | layerGraph
| convolution2dLayer
| batchNormalizationLayer
| trainingOptions
| trainNetwork
| doc2sequence
(Text Analytics Toolbox) | tokenizedDocument
(Text Analytics Toolbox) | transform
相关主题
- Classify Text Data Using Deep Learning (Text Analytics Toolbox)
- Create Simple Text Model for Classification (Text Analytics Toolbox)
- Analyze Text Data Using Topic Models (Text Analytics Toolbox)
- Analyze Text Data Using Multiword Phrases (Text Analytics Toolbox)
- Train a Sentiment Classifier (Text Analytics Toolbox)
- 使用深度学习进行序列分类
- Datastores for Deep Learning
- 在 MATLAB 中进行深度学习