《傲慢与偏见》与 MATLAB
此示例说明如何训练深度学习 LSTM 网络来通过字符嵌入生成文本。
要训练深度学习网络以生成文本,请训练“序列到序列”的 LSTM 网络,以预测字符序列中的下一个字符。要训练网络以预测下一个字符,请将响应指定为移位一个时间步的输入序列。
要使用字符嵌入,请将每个训练观测值转换为整数序列,其中的整数对字符词汇表进行索引。在网络中包含一个单词嵌入层,该层学习字符的嵌入并将整数映射到向量。
加载训练数据
读取《傲慢与偏见》(简·奥斯汀著)的古腾堡计划电子书的 HTML 代码,并使用 webread 和 htmlTree (Text Analytics Toolbox) 对其进行解析。
url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);通过查找 p 元素来提取段落。要删除与目录相关的信息,请使用 CSS 选择器 ':not(.toc)' 忽略带有 "toc" 类的段落元素。
paragraphs = findElement(tree,'p:not(.toc)');使用 extractHTMLText (Text Analytics Toolbox) 函数从段落中提取文本数据。删除空字符串。
textData = extractHTMLText(paragraphs);
textData(textData == "") = [];删除短于 20 个字符的字符串。
idx = strlength(textData) < 20; textData(idx) = [];
用文字云可视化文本数据。
figure
wordcloud(textData);
title("Pride and Prejudice")
将文本数据转换为序列
将文本数据转换为预测变量的字符索引序列和响应的分类序列。
categorical函数将换行符和空白字符条目视为未定义。要为这些字符创建分类元素,请分别使用特殊字符 "¶"(段落符号,"\x00B6")和"·"(间隔号,"\x00B7")替换它们。为防止出现歧义,请选择文本中未出现的特殊字符。
newlineCharacter = compose("\x00B6"); whitespaceCharacter = compose("\x00B7"); textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);
遍历文本数据,并创建表示每个观测值字符的字符索引序列以及响应的字符分类序列。要表示每个观测值的结束,请包含特殊字符“␃”(文本结尾,"\x2403")。将分类数组的类别指定为文本数据中出现的所有字符。
endOfTextCharacter = compose("\x2403"); numDocuments = numel(textData); uniqueCharacters = unique([textData{:}]); for i = 1:numDocuments characters = textData{i}; X = double(characters); % Create vector of categorical responses with end of text character. charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter]; Y = categorical(charactersShifted, [string(uniqueCharacters'); endOfTextCharacter]); XTrain{i} = X; YTrain{i} = Y; end
在训练过程中,默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度。过多填充会对网络性能产生负面影响。
为了防止训练过程添加过多填充,您可以按序列长度对训练数据进行排序。这样做意味着任何给定小批量中的序列具有相似的长度。
numObservations = numel(XTrain); for i=1:numObservations sequence = XTrain{i}; sequenceLengths(i) = size(sequence,2); end [~,idx] = sort(sequenceLengths); XTrain = XTrain(idx); YTrain = YTrain(idx);
创建和训练 LSTM 网络
定义 LSTM 架构。指定一个“序列到序列”LSTM 分类网络,其中包含 400 个隐藏单元。将输入大小设置为训练数据的特征维度。对于字符索引序列,特征维度为 1。指定维度为 200 的单词嵌入层,并指定单词数(对应于字符数)为输入数据中的最高字符值。将全连接层的输出大小设置为响应中的类别数。为帮助防止过拟合,在 LSTM 层后面包含一个丢弃层。
单词嵌入层学习字符嵌入并将每个字符映射到一个 200 维向量。
inputSize = size(XTrain{1},1);
classes = categories([YTrain{:}]);
numClasses = numel(classes);
numCharacters = max([textData{:}]);
layers = [
sequenceInputLayer(inputSize)
wordEmbeddingLayer(200,numCharacters)
lstmLayer(400,OutputMode="sequence")
dropoutLayer(0.2);
fullyConnectedLayer(numClasses)
softmaxLayer];指定训练选项。指定以小批量大小 32 和初始学习率 0.01 进行训练。要防止梯度爆炸,请将梯度阈值设置为 1。要确保数据保持排序,请将 Shuffle 设置为 "never"。要监控训练进度,请将 Plots 选项设置为 "training-progress"。要隐藏详尽输出,请将 Verbose 设置为 false。由于训练数据具有行和列分别对应于通道和时间步的序列,请指定输入和目标数据格式 "CTB"(通道、时间、批量)。
options = trainingOptions("adam", ... InputDataFormats = "CTB", ... TargetDataFormats = "CTB", ... Metrics = "accuracy", ... MiniBatchSize = 32,... InitialLearnRate = 0.01, ... GradientThreshold = 0.1, ... Shuffle = "never", ... Plots = "training-progress", ... Verbose = false, ... ExecutionEnvironment = "auto");
使用 trainnet 函数训练网络。
net = trainnet(XTrain,YTrain,layers,"crossentropy",options);
生成新文本
根据训练数据中文本的所有首字符的概率分布抽取一个字符来生成文本的第一个字符。接着使用经过训练的 LSTM 网络基于当前已生成的文本序列预测下一序列,以生成其余字符。继续逐个生成字符,直到网络预测到“文本结尾”字符。
根据训练数据中所有首字符的分布抽取第一个字符。
initialCharacters = extractBefore(textData,2); firstCharacter = datasample(initialCharacters,1); generatedText = firstCharacter;
将第一个字符转换为数值索引。
X = double(char(firstCharacter));
在后续的预测中,会根据网络的预测分数对来抽取下一个字符。预测分数表示下一个字符的概率分布。从网络输出层的类名给出的字符词汇表中抽取字符。从训练数据获取词汇表。
vocabulary = string(classes);
使用 predict 逐个字符进行预测。对于每次预测,都输入前一个字符的索引。当网络预测到文本结尾字符或生成的文本长度达到 500 个字符时,停止预测。
maxLength = 500; while strlength(generatedText) < maxLength % Predict the next character scores and output the network state. [characterScores,state] = predict(net,X); % Update the state. net.State = state; % Sample the next character. newCharacter = datasample(vocabulary,1,Weights=gather(characterScores)); % Stop predicting at the end of text. if newCharacter == endOfTextCharacter break end % Add the character to the generated text. generatedText = generatedText + newCharacter; % Get the numeric index of the character. X = double(char(newCharacter)); end
通过将特殊字符替换为对应的空白字符和换行符,重新构造生成的文本。
generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])generatedText = "“You dread, in must at Mr. Darcy’s more she had coldured with, to pubid since mistaken part of their return out general futual limentable or town Mact last ledge of the whire of these attended more longer impostible to looke able. They need marely bedre. Be all looking enough to this vortimily before, shook back beauty, she could inte; and moresning to it parton, that which she had out also on Jane, though to Loss arain Miss depenting, as mightmon that chuel to Mr. Darcy’s much man’s po rithment"
要生成多篇文本,请在每次生成完成后使用 resetState 重置网络状态。
net = resetState(net);
另请参阅
wordEmbeddingLayer (Text Analytics Toolbox) | doc2sequence (Text Analytics Toolbox) | tokenizedDocument (Text Analytics Toolbox) | lstmLayer | trainnet | trainingOptions | dlnetwork | sequenceInputLayer | wordcloud (Text Analytics Toolbox) | extractHTMLText (Text Analytics Toolbox) | findElement (Text Analytics Toolbox) | htmlTree (Text Analytics Toolbox)
主题
- 使用深度学习生成文本
- Word-by-Word Text Generation Using Deep Learning (Text Analytics Toolbox)
- Create Simple Text Model for Classification (Text Analytics Toolbox)
- Analyze Text Data Using Topic Models (Text Analytics Toolbox)
- Analyze Text Data Using Multiword Phrases (Text Analytics Toolbox)
- Train a Sentiment Classifier (Text Analytics Toolbox)
- 使用深度学习进行序列分类
- 在 MATLAB 中进行深度学习