使用序列数据的自定义小批量数据存储来训练网络
此示例说明如何使用自定义小批量数据存储基于无法放入内存的序列数据来训练深度学习网络。
小批量数据存储是支持批量读取数据的数据存储实现。使用小批量数据存储可读取无法放入内存的数据,或者在读取批量数据时执行特定的预处理操作。您可以使用小批量数据存储作为深度学习应用程序的训练数据集、验证数据集、测试数据集以及预测数据集的源。
此示例使用自定义小批量数据存储 sequenceDatastore,它作为支持文件包含在此示例中。您可以通过自定义数据存储函数来调整此数据存储以适应您的数据。有关说明如何创建您自己的自定义小批量数据存储的示例,请参阅Develop Custom Mini-Batch Datastore。
加载训练数据
按照 [1] 和 [2] 中的说明加载日语元音数据集。zip 文件 japaneseVowels.zip 包含不同长度的序列。这些序列分成两个文件夹 Train 和 Test,分别包含训练序列和测试序列。在每个文件夹中,序列又分成编号从 1 到 9 的子文件夹。这些子文件夹的名称是标签名称。一个 MAT 文件表示一个序列。每个序列均为一个包含 12 行(每个特征占一行)和不同列数(每个时间步占一列)的矩阵。行数是序列的维度,列数是序列的长度。
解压缩序列数据。
filename = "japaneseVowels.zip"; outputFolder = fullfile(tempdir,"japaneseVowels"); unzip(filename,outputFolder);
创建自定义小批量数据存储
创建一个自定义小批量数据存储。小批量数据存储 sequenceDatastore 从文件夹中读取数据,并从子文件夹名称中获取标签。
使用 sequenceDatastore 创建包含序列数据的数据存储。
folderTrain = fullfile(outputFolder,"Train");
dsTrain = sequenceDatastore(folderTrain)dsTrain =
sequenceDatastore with properties:
Datastore: [1×1 matlab.io.datastore.FileDatastore]
Labels: [270×1 categorical]
NumClasses: 9
SequenceDimension: 12
MiniBatchSize: 128
NumObservations: 270
定义 LSTM 网络架构
定义 LSTM 网络架构。将输入数据的序列维度指定为输入大小。指定具有 100 个隐含单元的 LSTM 层,并输出序列的最后一个元素。最后,指定一个输出大小等于类数的全连接层,后接一个 softmax 层。
inputSize = dsTrain.SequenceDimension;
numClasses = dsTrain.NumClasses;
numHiddenUnits = 100;
layers = [
sequenceInputLayer(inputSize)
lstmLayer(numHiddenUnits,OutputMode="last")
fullyConnectedLayer(numClasses)
softmaxLayer];指定训练选项。将 adam 指定为求解器,并将 GradientThreshold 指定为 1。将小批量大小设置为 27,并将最大训练轮数设置为 75。为确保数据存储创建的小批量的大小是 trainnet 函数所需的大小,还应将数据存储的小批量大小设置为相同的值。
由于小批量数据存储较小且序列较短,因此更适合在 CPU 上训练。将 ExecutionEnvironment 设置为 cpu。要在 GPU(如果可用)上进行训练,请将 ExecutionEnvironment 设置为 "auto"(默认值)。
miniBatchSize = 27; options = trainingOptions("adam", ... InputDataFormats="CTB", ... Metrics="accuracy", ... ExecutionEnvironment="cpu", ... MaxEpochs=40, ... MiniBatchSize=miniBatchSize, ... GradientThreshold=1, ... Verbose=false, ... Plots="training-progress"); dsTrain.MiniBatchSize = miniBatchSize;
使用 trainnet 函数训练神经网络。对于分类,使用交叉熵损失。
net = trainnet(dsTrain,layers,"crossentropy",options);
测试网络
根据测试数据创建一个序列数据存储。指定与训练数据相同的小批量大小。
folderTest = fullfile(outputFolder,"Test");
dsTest = sequenceDatastore(folderTest);
dsTest.MiniBatchSize = miniBatchSize;使用 testnet 函数测试神经网络。对于单标签分类,需评估准确度。准确度是指正确预测的百分比。默认情况下,testnet 函数使用 GPU(如果有)。要手动选择执行环境,请使用 testnet 函数的 ExecutionEnvironment 参量。要确保该函数使用大小和格式与训练相同的小批量,请将 ExecutionEnvironment、InputDataFormats 和 MiniBatchSize 参量设置为与训练相同的值。
accuracy = testnet(net,dsTest,"accuracy", ... ExecutionEnvironment="cpu", ... InputDataFormats="CTB", ... MiniBatchSize=miniBatchSize)
accuracy = 92.4324
参考
[1] Kudo, M., J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pp. 1103–1111.
[2] Kudo, M., J. Toyama, and M. Shimbo. Japanese Vowels Data Set. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
另请参阅
trainnet | trainingOptions | dlnetwork | lstmLayer | sequenceInputLayer