Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

使用卷积神经网络对文本数据进行分类

此示例说明如何使用卷积神经网络对文本数据进行分类。

要使用卷积对文本数据进行分类,必须将文本数据转换为图像。为此,请填充或截断观测值,使其具有恒定长度 S,并使用单词嵌入将文档转换为长度为 C 的单词向量序列。然后,您可以将文档表示为 1×S×C 的图像(高度为 1、宽度为 S 且具有 C 个通道的图像)。

要将 CSV 文件中的文本数据转换为图像,请创建一个 tabularTextDatastore 对象。通过使用自定义变换函数调用 transform,将从 tabularTextDatastore 对象读取的数据转换为图像以便进行深度学习。在示例末尾列出的 transformTextData 函数接受从数据存储中读取的数据和一个预训练的单词嵌入,并将每个观测值转换为单词向量数组。

此示例训练具有不同宽度的一维卷积滤波器的网络。每个滤波器的宽度对应于滤波器可以检测到的单词数(n 元分词长度)。网络有多个卷积层分支,因此它可以使用不同 n 元分词长度。

加载预训练的单词嵌入

加载预训练的 fastText 单词嵌入。此函数需要 Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding 支持包。如果未安装此支持包,则函数会提供下载链接。

emb = fastTextWordEmbedding;

加载数据

根据 factoryReports.csv 中的数据创建一个表格文本数据存储。仅读取 "Description""Category" 列中的数据。

filenameTrain = "factoryReports.csv";
textName = "Description";
labelName = "Category";
ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);

预览数据存储。

ttdsTrain.ReadSize = 8;
preview(ttdsTrain)
ans=8×2 table
                                  Description                                         Category       
    _______________________________________________________________________    ______________________

    {'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}
    {'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}
    {'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}
    {'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}
    {'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}
    {'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }
    {'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}
    {'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}

创建一个自定义变换函数,该函数将从数据存储中读取的数据转换为包含预测变量和响应的表。在示例末尾列出的 transformTextData 函数接受从 tabularTextDatastore 对象读取的数据,并返回包含预测变量和响应的表。预测变量是由单词嵌入 emb 给出的 1×sequenceLength×C 单词向量数组,其中 C 是嵌入维度。这些响应是 classNames 中的类的分类标签。

使用在示例末尾列出的 readLabels 函数从训练数据中读取标签,并找出具有唯一性的类名。

labels = readLabels(ttdsTrain,labelName);
classNames = unique(labels);
numObservations = numel(labels);

使用 transformTextData 函数变换数据存储,并将序列长度指定为 14。

sequenceLength = 14;
tdsTrain = transform(ttdsTrain, @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsTrain = 
  TransformedDatastore with properties:

       UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore]
    SupportedOutputFormats: ["txt"    "csv"    "xlsx"    "xls"    "parquet"    "parq"    "png"    "jpg"    "jpeg"    "tif"    "tiff"    "wav"    "flac"    "ogg"    "mp4"    "m4a"]
                Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)}
               IncludeInfo: 0

预览变换后的数据存储。预测变量是 1×S×C 数组,其中 S 是序列长度,C 是特征数(嵌入维度)。响应是分类标签。

preview(tdsTrain)
ans=8×2 table
       Predictors            Responses     
    _________________    __________________

    {1×14×300 single}    Mechanical Failure
    {1×14×300 single}    Mechanical Failure
    {1×14×300 single}    Electronic Failure
    {1×14×300 single}    Electronic Failure
    {1×14×300 single}    Electronic Failure
    {1×14×300 single}    Leak              
    {1×14×300 single}    Electronic Failure
    {1×14×300 single}    Mechanical Failure

定义网络架构

为分类任务定义网络架构。

以下步骤说明如何定义网络架构。

  • 指定 1×S×C 的输入大小,其中 S 是序列长度,C 是特征数(嵌入维度)。

  • 对于 n 元分词长度 2、3、4 和 5,创建包含卷积层、批量归一化层、ReLU 层、丢弃层和最大池化层的层块。

  • 对于每个块,指定 200 个大小为 1×N 的卷积滤波器和大小为 1×S 的池化区域,其中 N 是 n 元分词长度。

  • 将输入层连接到每个块,并使用深度连接层串联各块的输出。

  • 要对输出进行分类,请包括一个输出大小为 K 的全连接层、一个 softmax 层和一个分类层,其中 K 是类的数量。

首先,在一个层数组中,指定输入层、首个一元分词块、深度串联层、全连接层、softmax 层和分类层。

numFeatures = emb.Dimension;
inputSize = [1 sequenceLength numFeatures];
numFilters = 200;

ngramLengths = [2 3 4 5];
numBlocks = numel(ngramLengths);

numClasses = numel(classNames);

创建一个包含输入层的层次图。将归一化选项设置为 'none',层名称设置为 'input'

layer = imageInputLayer(inputSize,'Normalization','none','Name','input');
lgraph = layerGraph(layer);

对于每个 n 元分词长度,创建一个由卷积层、批量归一化层、ReLU 层、丢弃层和最大池化层构成的块。将每个块连接到输入层。

for j = 1:numBlocks
    N = ngramLengths(j);
    
    block = [
        convolution2dLayer([1 N],numFilters,'Name',"conv"+N,'Padding','same')
        batchNormalizationLayer('Name',"bn"+N)
        reluLayer('Name',"relu"+N)
        dropoutLayer(0.2,'Name',"drop"+N)
        maxPooling2dLayer([1 sequenceLength],'Name',"max"+N)];
    
    lgraph = addLayers(lgraph,block);
    lgraph = connectLayers(lgraph,'input',"conv"+N);
end

在图中查看网络架构。

figure
plot(lgraph)
title("Network Architecture")

添加深度串联层、全连接层、softmax 层和分类层。

layers = [
    depthConcatenationLayer(numBlocks,'Name','depth')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','soft')
    classificationLayer('Name','classification')];

lgraph = addLayers(lgraph,layers);

figure
plot(lgraph)
title("Network Architecture")

将最大池化层连接到深度串联层,并在图中查看最终网络架构。

for j = 1:numBlocks
    N = ngramLengths(j);
    lgraph = connectLayers(lgraph,"max"+N,"depth/in"+j);
end

figure
plot(lgraph)
title("Network Architecture")

训练网络

指定训练选项:

  • 使用小批量大小 128 进行训练。

  • 不要打乱数据,因为数据存储不可乱序。

  • 显示训练进度图并隐藏详细输出。

miniBatchSize = 128;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

使用 trainNetwork 函数训练网络。

net = trainNetwork(tdsTrain,lgraph,options);

使用新数据进行预测

对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。

reportsNew = [ 
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

使用与预处理训练文档相同的步骤来预处理文本数据。

XNew = preprocessText(reportsNew,sequenceLength,emb);

使用经过训练的 LSTM 网络对新序列进行分类。

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

读取标签函数

readLabels 函数创建 tabularTextDatastore 对象 ttds 的一个副本,并读取 labelName 列中的标签。

function labels = readLabels(ttds,labelName)

ttdsNew = copy(ttds);
ttdsNew.SelectedVariableNames = labelName;
tbl = readall(ttdsNew);
labels = tbl.(labelName);

end

变换文本数据函数

transformTextData 函数获取从 tabularTextDatastore 对象读取的数据,并返回包含预测变量和响应的表。预测变量是由单词嵌入 emb 给出的 1×sequenceLength×C 单词向量数组,其中 C 是嵌入维度。这些响应是 classNames 中的类的分类标签。

function dataTransformed = transformTextData(data,sequenceLength,emb,classNames)

% Preprocess documents.
textData = data{:,1};

% Prepocess text
dataTransformed = preprocessText(textData,sequenceLength,emb);

% Read labels.
labels = data{:,2};
responses = categorical(labels,classNames);

% Convert data to table.
dataTransformed.Responses = responses;

end

预处理文本函数

preprocessTextData 函数获取文本数据、序列长度和单词嵌入,并执行以下步骤:

  1. 对文本进行分词。

  2. 将文本转换为小写。

  3. 使用嵌入将文档转换为指定长度的单词向量序列。

  4. 重构单词向量序列以输入到网络中。

function tbl = preprocessText(textData,sequenceLength,emb)

documents = tokenizedDocument(textData);
documents = lower(documents);

% Convert documents to embeddingDimension-by-sequenceLength-by-1 images.
predictors = doc2sequence(emb,documents,'Length',sequenceLength);

% Reshape images to be of size 1-by-sequenceLength-embeddingDimension.
predictors = cellfun(@(X) permute(X,[3 2 1]),predictors,'UniformOutput',false);

tbl = table;
tbl.Predictors = predictors;

end

另请参阅

| | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

相关主题