Main Content

《傲慢与偏见》与 MATLAB

此示例说明如何训练深度学习 LSTM 网络来通过字符嵌入生成文本。

要训练深度学习网络以生成文本,请训练“序列到序列”的 LSTM 网络,以预测字符序列中的下一个字符。要训练网络以预测下一个字符,请将响应指定为移位一个时间步的输入序列。

要使用字符嵌入,请将每个训练观测值转换为整数序列,其中的整数对字符词汇表进行索引。在网络中包含一个单词嵌入层,该层学习字符的嵌入并将整数映射到向量。

加载训练数据

读取《傲慢与偏见》(简·奥斯汀著)的古腾堡计划电子书的 HTML 代码,并使用 webreadhtmlTree 对其进行解析。

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

通过查找 p 元素来提取段落。使用 CSS 选择器 ':not(.toc)' 指定忽略带有 "toc" 类的段落元素。

paragraphs = findElement(tree,'p:not(.toc)');

使用 extractHTMLText 从段落中提取文本数据,并删除空字符串。

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

删除短于 20 个字符的字符串。

idx = strlength(textData) < 20;
textData(idx) = [];

用文字云可视化文本数据。

figure
wordcloud(textData);
title("Pride and Prejudice")

将文本数据转换为序列

将文本数据转换为预测变量的字符索引序列和响应的分类序列。

分类函数将换行符和空白字符条目视为未定义。要为这些字符创建分类元素,请分别使用特殊字符 ""(段落符号,"\x00B6")和"·"(间隔号,"\x00B7")替换它们。为防止出现歧义,您必须选择文本中未出现的特殊字符。这些字符未出现在训练数据中,因此可用于此目的。

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

遍历文本数据,并创建表示每个观测值字符的字符索引序列以及响应的字符分类序列。要表示每个观测值的结束,请包含特殊字符“␃”(文本结尾,"\x2403")。

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

在训练过程中,默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度。过多填充会对网络性能产生负面影响。

为了防止训练过程添加过多填充,您可以按序列长度对训练数据进行排序,并选择合适的小批量大小,以使同一小批量中的序列长度相近。

获取每个观测值的序列长度。

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

按序列长度对数据进行排序。

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

创建和训练 LSTM 网络

定义 LSTM 架构。指定一个“序列到序列”LSTM 分类网络,其中包含 400 个隐含单元。将输入大小设置为训练数据的特征维度。对于字符索引序列,特征维度为 1。指定维度为 200 的单词嵌入层,并指定单词数(对应于字符数)为输入数据中的最高字符值。将全连接层的输出大小设置为响应中的类别数。为帮助防止过拟合,在 LSTM 层后面包含一个丢弃层。

单词嵌入层学习字符嵌入并将每个字符映射到一个 200 维向量。

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

指定训练选项。指定以小批量大小 32 和初始学习率 0.01 进行训练。要防止梯度爆炸,请将梯度阈值设置为 1。要确保数据保持排序,请将 'Shuffle' 设置为 'never'。要监控训练进度,请将 'Plots' 选项设置为 'training-progress'。要隐藏详尽输出,请将 'Verbose' 设置为 false

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

训练网络。

net = trainNetwork(XTrain,YTrain,layers,options);

生成新文本

根据训练数据中文本的所有首字符的概率分布抽取一个字符来生成文本的第一个字符。接着使用经过训练的 LSTM 网络基于当前已生成的文本序列预测下一序列,以生成其余字符。继续逐个生成字符,直到网络预测到“文本结尾”字符。

根据训练数据中所有首字符的分布抽取第一个字符。

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

将第一个字符转换为数值索引。

X = double(char(firstCharacter));

在后续的预测中,会根据网络的预测分数对来抽取下一个字符。预测分数表示下一个字符的概率分布。从网络输出层的类名给出的字符词汇表中抽取字符。从网络的分类层获取词汇表。

vocabulary = string(net.Layers(end).ClassNames);

使用 predictAndUpdateState 逐个字符进行预测。对于每次预测,都输入前一个字符的索引。当网络预测到文本结尾字符或生成的文本长度达到 500 个字符时,停止预测。对于大型数据集合、长序列或大型网络,在 GPU 上进行预测计算通常比在 CPU 上快。其他情况下,在 CPU 上进行预测计算通常更快。对于单时间步预测,请使用 CPU。要使用 CPU 进行预测,请将 predictAndUpdateState'ExecutionEnvironment' 选项设置为 'cpu'

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

通过将特殊字符替换为对应的空白字符和换行符,重新构造生成的文本。

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

要生成多篇文本,请在每次生成完成后使用 resetState 重置网络状态。

net = resetState(net);

另请参阅

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

相关主题