Main Content

使用序列数据的自定义小批量数据存储来训练网络

此示例说明如何使用自定义小批量数据存储基于无法放入内存的序列数据来训练深度学习网络。

小批量数据存储是支持批量读取数据的数据存储实现。使用小批量数据存储可读取无法放入内存的数据,或者在读取批量数据时执行特定的预处理操作。您可以使用小批量数据存储作为深度学习应用程序的训练数据集、验证数据集、测试数据集以及预测数据集的源。

此示例使用自定义小批量数据存储 sequenceDatastore,它作为支持文件包含在此示例中。您可以通过自定义数据存储函数来调整此数据存储以适应您的数据。有关说明如何创建您自己的自定义小批量数据存储的示例,请参阅Develop Custom Mini-Batch Datastore

加载训练数据

按照 [1] 和 [2] 中的说明加载日语元音数据集。zip 文件 japaneseVowels.zip 包含不同长度的序列。这些序列分成两个文件夹 TrainTest,分别包含训练序列和测试序列。在每个文件夹中,序列又分成编号从 19 的子文件夹。这些子文件夹的名称是标签名称。一个 MAT 文件表示一个序列。每个序列均为一个包含 12 行(每个特征占一行)和不同列数(每个时间步占一列)的矩阵。行数是序列的维度,列数是序列的长度。

解压缩序列数据。

filename = "japaneseVowels.zip";
outputFolder = fullfile(tempdir,"japaneseVowels");
unzip(filename,outputFolder);

创建自定义小批量数据存储

创建一个自定义小批量数据存储。小批量数据存储 sequenceDatastore 从文件夹中读取数据,并从子文件夹名称中获取标签。

使用 sequenceDatastore 创建包含序列数据的数据存储。

folderTrain = fullfile(outputFolder,"Train");
dsTrain = sequenceDatastore(folderTrain)
dsTrain = 
  sequenceDatastore with properties:

            Datastore: [1×1 matlab.io.datastore.FileDatastore]
               Labels: [270×1 categorical]
           NumClasses: 9
    SequenceDimension: 12
        MiniBatchSize: 128
      NumObservations: 270

定义 LSTM 网络架构

定义 LSTM 网络架构。将输入数据的序列维度指定为输入大小。指定具有 100 个隐含单元的 LSTM 层,并输出序列的最后一个元素。最后,指定一个输出大小等于类数的全连接层,后接一个 softmax 层。

inputSize = dsTrain.SequenceDimension;
numClasses = dsTrain.NumClasses;
numHiddenUnits = 100;
layers = [
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

指定训练选项。将 adam 指定为求解器,并将 GradientThreshold 指定为 1。将小批量大小设置为 27,并将最大训练轮数设置为 75。为确保数据存储创建的小批量的大小是 trainnet 函数所需的大小,还应将数据存储的小批量大小设置为相同的值。

由于小批量数据存储较小且序列较短,因此更适合在 CPU 上训练。将 ExecutionEnvironment 设置为 cpu。要在 GPU(如果可用)上进行训练,请将 ExecutionEnvironment 设置为 auto(默认值)。

miniBatchSize = 27;
options = trainingOptions("adam", ...
    InputDataFormats="CTB", ...
    Metrics="accuracy", ...
    ExecutionEnvironment="cpu", ...
    MaxEpochs=40, ...
    MiniBatchSize=miniBatchSize, ...
    GradientThreshold=1, ...
    Verbose=false, ...
    Plots="training-progress");
dsTrain.MiniBatchSize = miniBatchSize;

使用指定的训练选项训练 LSTM 网络。对于分类,使用交叉熵损失。

net = trainnet(dsTrain,layers,"crossentropy",options);

测试网络

根据测试数据创建一个序列数据存储。

folderTest = fullfile(outputFolder,"Test");
dsTest = sequenceDatastore(folderTest);

对测试数据进行分类。要使用多个观测值进行预测,请使用 minibatchpredict 函数。minibatchpredict 函数自动使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅 GPU 计算要求。否则,该函数使用 CPU。要指定执行环境,请使用 ExecutionEnvironment 选项。指定与训练数据相同的小批量大小。为确保数据存储创建的小批量的大小是 minibatchpredict 函数所需的大小,还应将数据存储的小批量大小设置为相同的值。

dsTest.MiniBatchSize = miniBatchSize;
YPred = minibatchpredict(net,dsTest,InputDataFormats="CTB",MiniBatchSize=miniBatchSize);

计算预测值的分类准确度。要将预测分数转换为标签,请使用 scores2label 函数。

YTest = dsTest.Labels;
YPred = scores2label(YPred, categories(YTest));
acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9486

参考

[1] Kudo, M., J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pp. 1103–1111.

[2] Kudo, M., J. Toyama, and M. Shimbo. Japanese Vowels Data Set. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

另请参阅

| | | |

相关主题