Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

使用深度网络设计器创建简单的序列分类网络

此示例说明如何使用深度网络设计器创建简单的长短期记忆 (LSTM) 分类网络。

要训练深度神经网络以对序列数据进行分类,可以使用 LSTM 网络。LSTM 网络是一种循环神经网络 (RNN),可学习序列数据的时间步之间的长期依存关系。

该示例演示如何:

  • 加载序列数据。

  • 构造网络架构。

  • 指定训练选项。

  • 训练网络。

  • 预测新数据的标签并计算分类准确度。

加载数据

按照 [1][2] 中的说明加载日语元音数据集。预测变量是包含不同长度序列的元胞数组,特征维度为 12。标签是由标签 1、2、...、9 组成的分类向量。

[XTrain,YTrain] = japaneseVowelsTrainData;
[XValidation,YValidation] = japaneseVowelsTestData;

查看前几个训练序列的大小。序列是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。

XTrain(1:5)
ans=5×1 cell array
    {12×20 double}
    {12×26 double}
    {12×22 double}
    {12×20 double}
    {12×21 double}

定义网络架构

打开深度网络设计器。

deepNetworkDesigner

序列到标签上暂停,然后点击打开。这会打开一个适合序列分类问题的预置网络。

深度网络设计器显示该预置网络。

您可以轻松地将此序列网络用于日语元音字母数据集。

选择 sequenceInputLayer,检查并确认 InputSize 设置为 12,与特征维度匹配。

选择 lstmLayer 并将 NumHiddenUnits 设置为 100。

选择 fullyConnectedLayer,检查并确认 OutputSize 设置为 9,即类的数目。

检查网络架构

要检查网络并查看层的详细信息,请点击分析

导出网络架构

要将网络架构导出到工作区,请在设计器选项卡上,点击导出。深度网络设计器将网络保存为变量 layers_1

您还可以通过选择导出 > 生成代码来生成用于构造网络架构的代码。

训练网络

指定训练选项并训练网络。

由于小批量数据存储较小且序列较短,因此更适合在 CPU 上训练。将 'ExecutionEnvironment' 设置为 'cpu'。要在 GPU(如果可用)上进行训练,请将 'ExecutionEnvironment' 设置为 'auto'(默认值)。

miniBatchSize = 27;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',{XValidation,YValidation}, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(XTrain,YTrain,layers_1,options);

测试网络

对测试数据进行分类,并计算分类准确度。指定与训练相同的小批量大小。

YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc = 0.9432

在接下来的步骤中,您可以尝试通过使用双向 LSTM (BiLSTM) 层或创建更深的网络来提高准确度。有关详细信息,请参阅长短期记忆网络

有关说明如何使用卷积网络对序列数据进行分类的示例,请参阅使用深度学习进行语音命令识别

参考资料

[1] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo.“Multidimensional Curve Classification Using Passing-through Regions.”Pattern Recognition Letters 20, no. 11–13 (November 1999):1103–11. https://doi.org/10.1016/S0167-8655(99)00077-X.

[2] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo.Japanese Vowels Data Set.Distributed by UCI Machine Learning Repository. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

另请参阅

相关主题