Main Content

本页翻译不是最新的。点击此处可查看最新英文版本。

使用卷积神经网络对文本数据进行分类

此示例说明如何使用卷积神经网络对文本数据进行分类。

要使用卷积对文本数据进行分类,请使用在输入的时间维度上进行卷积的一维卷积层。

此示例训练具有不同宽度的一维卷积滤波器的网络。每个滤波器的宽度对应于滤波器可以检测到的单词数(n 元分词长度)。网络有多个卷积层分支,因此它可以使用不同 n 元分词长度。

加载数据

根据 factoryReports.csv 中的数据创建一个表格文本数据存储,并查看前几个报告。

data = readtable("factoryReports.csv");
head(data)
ans=8×5 table
                                  Description                                         Category            Urgency            Resolution          Cost 
    _______________________________________________________________________    ______________________    __________    ______________________    _____

    {'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       45
    {'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       35
    {'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}    {'High'  }    {'Full Replacement'  }    16200
    {'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}    {'High'  }    {'Replace Components'}      352
    {'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}    {'Low'   }    {'Add to Watch List' }       55
    {'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }    {'High'  }    {'Replace Components'}      371
    {'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}    {'Low'   }    {'Replace Components'}      441
    {'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}    {'Low'   }    {'Readjust Machine'  }       38

将数据划分为训练分区和验证分区。将 80% 的数据用于训练,其余的数据用于验证。

cvp = cvpartition(data.Category,Holdout=0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

预处理文本数据

从表的 "Description" 列中提取文本数据,并使用示例的预处理文本函数一节中列出的 preprocessText 函数对其进行预处理。

documentsTrain = preprocessText(dataTrain.Description);

"Category" 列中提取标签,并将其转换为分类。

TTrain = categorical(dataTrain.Category);

查看类名称和观测值数目。

classNames = unique(TTrain)
classNames = 4×1 categorical
     Electronic Failure 
     Leak 
     Mechanical Failure 
     Software Failure 

numObservations = numel(TTrain)
numObservations = 384

使用相同的步骤提取和预处理验证数据。

documentsValidation = preprocessText(dataValidation.Description);
TValidation = categorical(dataValidation.Category);

将文档转换为序列

要将文档输入到神经网络中,请使用单词编码将文档转换为数值索引序列。

从文档创建单词编码。

enc = wordEncoding(documentsTrain);

查看单词编码的词汇量。词汇量是单词编码中具有唯一行的单词的数量。

numWords = enc.NumWords
numWords = 436

使用 doc2sequence 函数将文档转换为整数序列。

XTrain = doc2sequence(enc,documentsTrain);

使用从训练数据创建的单词编码将验证文档转换为序列。

XValidation = doc2sequence(enc,documentsValidation);

定义网络架构

为分类任务定义网络架构。

以下步骤说明如何定义网络架构。

  • 将输入大小指定为 1,它对应于整数序列输入的通道维度。

  • 使用维度为 100 的单词嵌入来嵌入输入。

  • 对于 n 元分词长度 2、3、4 和 5,创建包含卷积层、批量归一化层、ReLU 层、丢弃层和最大池化层的层模块。

  • 对于每个模块,指定 200 个大小为 1×N 的卷积滤波器和一个全局最大池化层。

  • 将输入层连接到每个模块,并使用串联层串联各模块的输出。

  • 要对输出进行分类,请包括一个输出大小为 K 的全连接层、一个 softmax 层和一个分类层,其中 K 是类的数量。

指定网络超参数。

embeddingDimension = 100;
ngramLengths = [2 3 4 5];
numFilters = 200;

首先,创建一个包含输入层和一个维度为 100 的单词嵌入层的层图。为了帮助将单词嵌入层连接到卷积层,请将单词嵌入层名称设置为 "emb"。要检查卷积层在训练期间不将序列卷积为零长度,请将 MinLength 选项设置为训练数据中最短序列的长度。

minLength = min(doclength(documentsTrain));
layers = [ 
    sequenceInputLayer(1,MinLength=minLength)
    wordEmbeddingLayer(embeddingDimension,numWords,Name="emb")];
lgraph = layerGraph(layers);

对于每个 n 元分词长度,创建一个由一维卷积层、批量归一化层、ReLU 层、丢弃层和一维全局最大池化层构成的模块。将每个模块连接到单词嵌入层。

numBlocks = numel(ngramLengths);
for j = 1:numBlocks
    N = ngramLengths(j);
    
    block = [
        convolution1dLayer(N,numFilters,Name="conv"+N,Padding="same")
        batchNormalizationLayer(Name="bn"+N)
        reluLayer(Name="relu"+N)
        dropoutLayer(0.2,Name="drop"+N)
        globalMaxPooling1dLayer(Name="max"+N)];
    
    lgraph = addLayers(lgraph,block);
    lgraph = connectLayers(lgraph,"emb","conv"+N);
end

添加串联层、全连接层、softmax 层和分类层。

numClasses = numel(classNames);

layers = [
    concatenationLayer(1,numBlocks,Name="cat")
    fullyConnectedLayer(numClasses,Name="fc")
    softmaxLayer(Name="soft")
    classificationLayer(Name="classification")];

lgraph = addLayers(lgraph,layers);

将全局最大池化层连接到串联层,并在绘图中查看网络架构。

for j = 1:numBlocks
    N = ngramLengths(j);
    lgraph = connectLayers(lgraph,"max"+N,"cat/in"+j);
end

figure
plot(lgraph)
title("Network Architecture")

训练网络

指定训练选项:

  • 使用小批量大小 128 进行训练。

  • 使用验证数据验证网络。

  • 返回验证损失最低的网络。

  • 显示训练进度图并隐藏详尽输出。

options = trainingOptions("adam", ...
    MiniBatchSize=128, ...
    ValidationData={XValidation,TValidation}, ...
    OutputNetwork="best-validation-loss", ...
    Plots="training-progress", ...
    Verbose=false);

使用 trainNetwork 函数训练网络。

net = trainNetwork(XTrain,TTrain,lgraph,options);

测试网络

使用经过训练的网络对验证数据进行分类。

YValidation = classify(net,XValidation);

在混淆图中可视化预测。

figure
confusionchart(TValidation,YValidation)

计算分类准确度。准确度是预测正确的标签的比例。

accuracy = mean(TValidation == YValidation)
accuracy = 0.9375

使用新数据进行预测

对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。

reportsNew = [ 
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

使用与预处理训练和验证文档相同的步骤来预处理文本数据。

documentsNew = preprocessText(reportsNew);
XNew = doc2sequence(enc,documentsNew);

使用经过训练的网络对新序列进行分类。

YNew = classify(net,XNew)
YNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

预处理文本函数

preprocessTextData 函数接受文本数据作为输入并执行以下步骤:

  1. 对文本进行分词。

  2. 将文本转换为小写。

function documents = preprocessText(textData)

documents = tokenizedDocument(textData);
documents = lower(documents);

end

另请参阅

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) |

相关主题