Main Content

本页翻译不是最新的。点击此处可查看最新英文版本。

使用序列数据的自定义小批量数据存储来训练网络

此示例说明如何使用自定义小批量数据存储基于无法放入内存的序列数据来训练深度学习网络。

小批量数据存储是支持批量读取数据的数据存储实现。使用小批量数据存储可读取无法放入内存的数据,或者在读取批量数据时执行特定的预处理操作。您可以使用小批量数据存储作为深度学习应用程序的训练数据集、验证数据集、测试数据集以及预测数据集的源。

此示例使用自定义小批量数据存储 sequenceDatastore,它作为支持文件包含在此示例中。您可以通过自定义数据存储函数来调整此数据存储以适应您的数据。有关说明如何创建您自己的自定义小批量数据存储的示例,请参阅Develop Custom Mini-Batch Datastore

加载训练数据

按照 [1] 和 [2] 中的说明加载日语元音数据集。zip 文件 japaneseVowels.zip 包含不同长度的序列。这些序列分成两个文件夹 TrainTest,分别包含训练序列和测试序列。在每个文件夹中,序列又分成编号从 19 的子文件夹。这些子文件夹的名称是标签名称。一个 MAT 文件表示一个序列。每个序列均为一个包含 12 行(每个特征占一行)和不同列数(每个时间步占一列)的矩阵。行数是序列的维度,列数是序列的长度。

解压缩序列数据。

filename = "japaneseVowels.zip";
outputFolder = fullfile(tempdir,"japaneseVowels");
unzip(filename,outputFolder);

创建自定义小批量数据存储

创建一个自定义小批量数据存储。小批量数据存储 sequenceDatastore 从文件夹中读取数据,并从子文件夹名称中获取标签。

使用 sequenceDatastore 创建包含序列数据的数据存储。

folderTrain = fullfile(outputFolder,"Train");
dsTrain = sequenceDatastore(folderTrain)
dsTrain = 
  sequenceDatastore with properties:

            Datastore: [1×1 matlab.io.datastore.FileDatastore]
               Labels: [270×1 categorical]
           NumClasses: 9
    SequenceDimension: 12
        MiniBatchSize: 128
      NumObservations: 270

定义 LSTM 网络架构

定义 LSTM 网络架构。将输入数据的序列维度指定为输入大小。指定具有 100 个隐含单元的 LSTM 层,并输出序列的最后一个元素。最后,指定一个输出大小等于类数的全连接层,后接一个 softmax 层和一个分类层。

inputSize = dsTrain.SequenceDimension;
numClasses = dsTrain.NumClasses;
numHiddenUnits = 100;
layers = [
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

指定训练选项。将 'adam' 指定为求解器,并将 'GradientThreshold' 指定为 1。将小批量大小设置为 27,并将最大训练轮数设置为 75。为确保数据存储创建的小批量的大小是 trainNetwork 函数所需的大小,还应将数据存储的小批量大小设置为相同的值。

由于小批量数据存储较小且序列较短,因此更适合在 CPU 上训练。将 'ExecutionEnvironment' 设置为 'cpu'。要在 GPU(如果可用)上进行训练,请将 'ExecutionEnvironment' 设置为 'auto'(默认值)。

miniBatchSize = 27;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',75, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',0, ...
    'Plots','training-progress');
dsTrain.MiniBatchSize = miniBatchSize;

使用指定的训练选项训练 LSTM 网络。

net = trainNetwork(dsTrain,layers,options);

Figure Training Progress (10-Oct-2022 16:51:45) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 3 objects of type line. Axes object 2 contains 3 objects of type line.

测试网络

根据测试数据创建一个序列数据存储。

folderTest = fullfile(outputFolder,"Test");
dsTest = sequenceDatastore(folderTest);

对测试数据进行分类。指定与训练数据相同的小批量大小。为确保数据存储创建的小批量的大小是 classify 函数所需的大小,还应将数据存储的小批量大小设置为相同的值。

dsTest.MiniBatchSize = miniBatchSize;
YPred = classify(net,dsTest,'MiniBatchSize',miniBatchSize);

计算预测值的分类准确度。

YTest = dsTest.Labels;
acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9243

参考

[1] Kudo, M., J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pp. 1103–1111.

[2] Kudo, M., J. Toyama, and M. Shimbo. Japanese Vowels Data Set. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

另请参阅

| | |

相关主题