Main Content

本页翻译不是最新的。点击此处可查看最新英文版本。

使用无法放入内存的序列数据训练网络

此示例说明如何通过变换和合并数据存储基于无法放入内存的序列数据来训练深度学习网络。

变换后的数据存储对从基础数据存储读取的数据进行变换或处理。您可以使用变换后的数据存储作为深度学习应用的训练数据集、验证数据集、测试数据集以及预测数据集的数据源。使用变换后的数据存储可读取无法放入内存的数据,或者在读取批量数据时执行特定的预处理操作。当您有若干单独的包含预测变量和标签的数据存储时,您可以将它们合并起来,以便将数据输入到深度学习网络中。

在训练网络时,软件通过填充、截断或拆分输入数据来创建长度相同的小批量序列。对于内存数据,trainingOptions 函数提供了填充和截断输入序列的选项,但对于内存外的数据,您必须手动填充和截断序列。

加载训练数据

按照 [1] 和 [2] 中的说明加载日语元音数据集。zip 文件 japaneseVowels.zip 包含不同长度的序列。这些序列分成两个文件夹 TrainTest,分别包含训练序列和测试序列。在每个文件夹中,序列又分成编号从 19 的子文件夹。这些子文件夹的名称是标签名称。一个 MAT 文件表示一个序列。每个序列均为一个包含 12 行(每个特征占一行)和不同列数(每个时间步占一列)的矩阵。行数是序列的维度,列数是序列的长度。

解压缩序列数据。

filename = "japaneseVowels.zip";
outputFolder = fullfile(tempdir,"japaneseVowels");
unzip(filename,outputFolder);

对于训练预测变量,请创建一个文件数据存储,并将读取函数指定为 load 函数。load 函数将数据从 MAT 文件加载到结构体数组中。要从训练文件夹的子文件夹中读取文件,请将 'IncludeSubfolders' 选项设置为 true

folderTrain = fullfile(outputFolder,"Train");
fdsPredictorTrain = fileDatastore(folderTrain, ...
    'ReadFcn',@load, ...
    'IncludeSubfolders',true);

预览数据存储。返回的结构体包含来自第一个文件的单个序列。

preview(fdsPredictorTrain)
ans = struct with fields:
    X: [12×20 double]

对于标签,请创建一个文件数据存储,并将读取函数指定为 readLabel 函数,该函数在示例末尾定义。readLabel 函数从子文件夹名称中提取标签。

classNames = string(1:9);
fdsLabelTrain = fileDatastore(folderTrain, ...
    'ReadFcn',@(filename) readLabel(filename,classNames), ...
    'IncludeSubfolders',true);

预览数据存储。输出对应于第一个文件的标签。

preview(fdsLabelTrain)
ans = categorical
     1 

变换和合并数据存储

要将预测变量的数据存储中的序列数据输入深度学习网络,序列的小批量必须具有相同的长度。使用在数据存储末尾定义的 padSequence 函数变换数据存储,该函数填充或截断序列以使其长度为 20。

sequenceLength = 20;
tdsTrain = transform(fdsPredictorTrain,@(data) padSequence(data,sequenceLength));

预览变换后的数据存储。输出对应于来自第一个文件的填充序列。

X = preview(tdsTrain)
X = 1×1 cell array
    {12×20 double}

要将来自两个数据存储的预测变量和标签输入一个深度学习网络,请使用 combine 函数将其合并。

cdsTrain = combine(tdsTrain,fdsLabelTrain);

预览合并后的数据存储。数据存储返回一个 1×2 元胞数组。第一个元素对应于预测变量。第二个元素对应于标签。

preview(cdsTrain)
ans = 1×2 cell array
    {12×20 double}    {[1]}

定义 LSTM 网络架构

定义 LSTM 网络架构。将输入数据的特征数量指定为输入大小。指定具有 100 个隐含单元的 LSTM 层,并输出序列的最后一个元素。最后,指定一个输出大小等于类数的全连接层,后接一个 softmax 层和一个分类层。

numFeatures = 12;
numClasses = numel(classNames);
numHiddenUnits = 100;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

指定训练选项。将求解器设置为 'adam',将 'GradientThreshold' 设置为 2。将小批量大小设置为 27,并将最大训练轮数设置为 75。数据存储不支持乱序,因此将 'Shuffle' 设置为 'never'

由于小批量数据存储较小且序列较短,因此更适合在 CPU 上训练。将 'ExecutionEnvironment' 设置为 'cpu'。要在 GPU(如果可用)上进行训练,请将 'ExecutionEnvironment' 设置为 'auto'(默认值)。

miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',75, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never',...
    'Verbose',0, ...
    'Plots','training-progress');

使用指定的训练选项训练 LSTM 网络。

net = trainNetwork(cdsTrain,layers,options);

测试网络

使用与训练数据相同的步骤来创建包含保留测试数据的变换后的数据存储。

folderTest = fullfile(outputFolder,"Test");

fdsPredictorTest = fileDatastore(folderTest, ...
    'ReadFcn',@load, ...
    'IncludeSubfolders',true);
tdsTest = transform(fdsPredictorTest,@(data) padSequence(data,sequenceLength));

使用经过训练的网络对测试数据进行预测。

YPred = classify(net,tdsTest,'MiniBatchSize',miniBatchSize);

基于测试数据计算分类准确度。要获取测试集的标签,请使用读取函数 readLabel 创建一个文件数据存储,并指定包含子文件夹。通过将 'UniformRead' 选项设置为 true,指定输出可垂直串联。

fdsLabelTest = fileDatastore(folderTest, ...
    'ReadFcn',@(filename) readLabel(filename,classNames), ...
    'IncludeSubfolders',true, ...
    'UniformRead',true);
YTest = readall(fdsLabelTest);
accuracy = mean(YPred == YTest)
accuracy = 0.9351

函数

readLabel 函数根据 classNames 中的类别从指定文件名中提取标签。

function label = readLabel(filename,classNames)

filepath = fileparts(filename);
[~,label] = fileparts(filepath);

label = categorical(string(label),classNames);

end

padSequence 函数填充或截断 data.X 中的序列,使其具有指定的序列长度,并以 1×1 元胞形式返回结果。

function sequence = padSequence(data,sequenceLength)

sequence = data.X;
[C,S] = size(sequence);

if S < sequenceLength
    padding = zeros(C,sequenceLength-S);
    sequence = [sequence padding];
else
    sequence = sequence(:,1:sequenceLength);
end

sequence = {sequence};

end

另请参阅

| | | | |

相关主题