Main Content

trainNetwork

训练神经网络

说明

对于分类和回归任务,您可以使用 trainNetwork 函数来训练各种类型的神经网络。

例如,您可以训练:

  • 用于图像数据的卷积神经网络(ConvNet、CNN)

  • 用于序列和时间序列数据的循环神经网络 (RNN),例如长短期记忆 (LSTM) 或门控循环单元 (GRU) 神经网络

  • 用于数值特征数据的多层感知器 (MLP) 神经网络

您可以在 CPU 或 GPU 上训练。对于图像分类和图像回归,您可以使用多个 GPU 或者本地或远程并行池来并行训练单个神经网络。在 GPU 上训练或并行训练需要 Parallel Computing Toolbox™。要使用 GPU 进行深度学习,您还必须拥有支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)要指定训练选项,包括执行环境的选项,请使用 trainingOptions 函数。

在训练神经网络时,您可以将预测变量和响应指定为单个输入或两个单独的输入。

示例

net = trainNetwork(images,layers,options) 使用 images 指定的图像和响应以及 options 定义的训练选项,针对图像分类和回归任务训练 layers 指定的神经网络。

示例

net = trainNetwork(images,responses,layers,options) 使用 images 指定的图像和 responses 指定的响应进行训练。

net = trainNetwork(sequences,layers,options) 使用 sequences 指定的序列和响应,针对序列或时间序列分类和回归任务训练神经网络(例如,LSTM 或 GRU 神经网络)。

示例

net = trainNetwork(sequences,responses,layers,options) 使用 sequences 指定的序列和 responses 指定的响应进行训练。

示例

net = trainNetwork(features,layers,options) 使用 features 指定的特征数据和响应,针对特征分类或回归任务训练神经网络(例如,多层感知器 (MLP) 神经网络)。

net = trainNetwork(features,responses,layers,options) 使用 features 指定的特征数据和 responses 指定的响应进行训练。

net = trainNetwork(mixed,layers,options) 使用 mixed 指定的数据和响应,训练一个具有混合数据类型的多个输入的神经网络。

[net,info] = trainNetwork(___) 还使用上述任一语法返回有关训练的信息。

示例

全部折叠

将数据作为 ImageDatastore 对象加载。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

该数据存储包含 10,000 个数字 0 至 9 的合成图像。这些图像是通过对使用不同字体创建的数字图像应用随机变换生成的。每个数字图像为 28×28 像素。该数据存储包含的每个类别都有相同数量的图像。

显示数据存储中的部分图像。

figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
    drawnow;
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

划分数据存储,使训练集中的每个类别包含 750 个图像,测试集包含对应每个标签的其余图像。

numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');

splitEachLabeldigitData 中的图像文件拆分为两个新的数据存储,imdsTrainimdsTest

定义卷积神经网络架构。

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

将选项设置为带动量的随机梯度下降的默认设置。将最大训练轮数设置为 20,以 0.0001 的初始学习率开始训练。

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(imdsTrain,layers,options);

Figure Training Progress (19-Aug-2023 11:53:04) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 6 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 6 objects of type patch, text, line.

基于未用于训练网络的测试集运行经过训练的网络,并预测图像标签(数字)。

YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;

计算准确度。准确度是测试数据中与来自 classify 的分类匹配的真实标签数量与测试数据中图像数量的比率。

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9400

使用增强的图像数据训练一个卷积神经网络。数据增强有助于防止网络过拟合和记忆训练图像的具体细节。

加载由手写数字的合成图像组成的样本数据。

[XTrain,YTrain] = digitTrain4DArrayData;

digitTrain4DArrayData 将数字训练集作为四维数组数据加载。XTrain 是一个 28×28×1×5000 数组,其中:

  • 28 是图像的高度和宽度。

  • 1 是通道数。

  • 5000 是由手写数字组成的合成图像的数目。

YTrain 是包含每个观测值的标签的分类向量。

留出 1000 个图像用于网络验证。

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

创建一个 imageDataAugmenter 对象,它指定图像增强的预处理选项,如调整大小、旋转、平移和翻转。水平和垂直随机平移图像的最多三个像素,旋转图像角度最多 20 度。

imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-3 3], ...
    'RandYTranslation',[-3 3])
imageAugmenter = 
  imageDataAugmenter with properties:

           FillValue: 0
     RandXReflection: 0
     RandYReflection: 0
        RandRotation: [-20 20]
           RandScale: [1 1]
          RandXScale: [1 1]
          RandYScale: [1 1]
          RandXShear: [0 0]
          RandYShear: [0 0]
    RandXTranslation: [-3 3]
    RandYTranslation: [-3 3]

创建一个用于网络训练的 augmentedImageDatastore 对象,并指定图像输出大小。在训练期间,数据存储执行图像增强并调整图像大小。数据存储会增强图像,但不会将任何图像保存到内存中。trainNetwork 会更新网络参数,然后丢弃增强的图像。

imageSize = [28 28 1];
augimds = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);

指定卷积神经网络架构。

layers = [
    imageInputLayer(imageSize)
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

指定带动量的随机梯度下降的训练选项。

opts = trainingOptions('sgdm', ...
    'MaxEpochs',15, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XValidation,YValidation});

训练网络。由于验证图像未进行增强,因此验证准确度高于训练准确度。

net = trainNetwork(augimds,layers,opts);

加载由手写数字的合成图像组成的样本数据。第三个输出包含每个图像旋转的对应角度,以度为单位。

使用 digitTrain4DArrayData 以四维数组形式加载训练图像。输出 XTrain 是一个 28×28×1×5000 数组,其中:

  • 28 是图像的高度和宽度。

  • 1 是通道数。

  • 5000 是由手写数字组成的合成图像的数目。

YTrain 包含以度为单位的旋转角度。

[XTrain,~,YTrain] = digitTrain4DArrayData;

使用 imshow 显示 20 个随机训练图像。

figure
numTrainImages = numel(YTrain);
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    drawnow;
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

指定卷积神经网络架构。对于回归问题,在网络末尾包含一个回归层。

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

指定网络训练选项。将初始学习率设置为 0.001。

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (19-Aug-2023 11:41:24) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 7 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel RMSE contains 7 objects of type patch, text, line.

通过评估测试数据的预测准确度来测试网络的性能。使用 predict 预测验证图像的旋转角度。

[XTest,~,YTest] = digitTest4DArrayData;
YPred = predict(net,XTest);

通过计算预测旋转角度和实际旋转角度的均方根误差 (RMSE) 来评估模型的性能。

rmse = sqrt(mean((YTest - YPred).^2))
rmse = single
    6.0516

训练一个用于“序列到标签”分类的深度学习 LSTM 网络。

WaveformData.mat 加载示例数据。数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numChannels×-numTimeSteps 数值数组,其中 numChannels 是序列的通道数,numTimeSteps 是序列的时间步数。

load WaveformData

在绘图中可视化一些序列。

numChannels = size(data{1},1);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)}',DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end

留出测试数据。将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据)。要划分数据,请使用 trainingPartitions 函数,此函数作为支持文件包含在此示例中。要访问此文件,请以实时脚本形式打开此示例。

numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations, [0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);

定义 LSTM 网络架构。将输入大小指定为输入数据的通道数量。指定一个 LSTM 层要具有 120 个隐藏单元,并输出序列的最后一个元素。最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层和一个分类层。

numHiddenUnits = 120;
numClasses = numel(categories(TTrain));

layers = [ ...
    sequenceInputLayer(numChannels)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5×1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 3 dimensions
     2   ''   LSTM                    LSTM with 120 hidden units
     3   ''   Fully Connected         4 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

指定训练选项。使用 Adam 求解器进行训练,学习率为 0.01,梯度阈值为 1。将最大训练轮数设置为 150,并对每轮执行乱序。默认情况下,软件会在 GPU 上(如果有)进行训练。使用 GPU 需要 Parallel Computing Toolbox 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)

options = trainingOptions("adam", ...
    MaxEpochs=150, ...
    InitialLearnRate=0.01,...
    Shuffle="every-epoch", ...
    GradientThreshold=1, ...
    Verbose=false, ...
    Plots="training-progress");

使用指定的训练选项训练 LSTM 网络。

net = trainNetwork(XTrain,TTrain,layers,options);

对测试数据进行分类。指定用于训练的相同小批量大小。

YTest = classify(net,XTest);

计算预测值的分类准确度。

acc = mean(YTest == TTest)
acc = 0.8400

在混淆图中显示分类结果。

figure
confusionchart(TTest,YTest)

如果您有包含数值特征的数据集(例如,一个不包含空间或时间维度的数值数据集合),则可以使用特征输入层训练深度学习网络。

从 CSV 文件 "transmissionCasingData.csv" 中读取变速箱数据。

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

使用 convertvars 函数将用于预测的标签转换为分类标签。

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

要使用分类特征训练网络,必须首先将分类特征转换为数值特征。首先,通过指定包含所有分类输入变量名称的字符串数组,使用 convertvars 函数将分类预测变量转换为分类值。在此数据集中,有两个分类特征,名称分别为 "SensorCondition""ShaftCondition"

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

遍历分类输入变量。对于每个变量:

  • 使用 onehotencode 函数将分类值转换为 one-hot 编码向量。

  • 使用 addvars 函数将 one-hot 向量添加到表中。指定在包含对应分类数据的列后插入向量。

  • 删除包含分类数据的对应列。

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end

使用 splitvars 函数将向量分成单独的列。

tbl = splitvars(tbl);

查看表的前几行。请注意,分类预测变量已分成多个列,并以分类值作为变量名称。

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

查看数据集的类名称。

classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

接下来,将数据集划分为训练分区和测试分区。留出 15% 的数据用于测试。

确定每个分区的观测值数目。

numObservations = size(tbl,1);
numObservationsTrain = floor(0.85*numObservations);
numObservationsTest = numObservations - numObservationsTrain;

创建一个与观测值对应的随机索引数组,并使用分区大小对其进行分区。

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxTest = idx(numObservationsTrain+1:end);

使用索引将数据表划分为训练分区和测试分区。

tblTrain = tbl(idxTrain,:);
tblTest = tbl(idxTest,:);

定义一个具有特征输入层的网络并指定特征的数量。此外,配置输入层以使用 Z 分数归一化对数据进行归一化。

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,'Normalization', 'zscore')
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

指定训练选项。

miniBatchSize = 16;

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false);

使用 layers 定义的架构、训练数据和训练选项训练网络。

net = trainNetwork(tblTrain,layers,options);

Figure Training Progress (19-Aug-2023 11:44:07) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 7 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 7 objects of type patch, text, line.

使用经过训练的网络预测测试数据的标签,并计算准确度。准确度是网络正确预测的标签的比例。

YPred = classify(net,tblTest,'MiniBatchSize',miniBatchSize);
YTest = tblTest{:,labelName};

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688

输入参数

全部折叠

图像数据,指定为下列值之一:

数据类型描述用法示例
数据存储ImageDatastore保存在磁盘上的图像数据存储。

使用保存在磁盘上的图像训练图像分类神经网络,其中图像的大小相同。

当图像大小不同时,使用 AugmentedImageDatastore 对象。

ImageDatastore 对象仅支持图像分类任务。要将图像数据存储用于回归神经网络,请分别使用 transformcombine 函数创建包含图像和响应的变换或组合数据存储。

AugmentedImageDatastore应用随机仿射几何变换(包括调整大小、旋转、翻转、剪切和平移)的数据存储。

  • 使用保存在磁盘上的图像训练图像分类神经网络,这些图像的大小各不相同。

  • 训练图像分类神经网络并使用增强生成新数据。

TransformedDatastore这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。

  • 训练图像回归神经网络。

  • 训练具有多个输入的神经网络。

  • 变换具有不受 trainNetwork 支持的输出的数据存储。

  • 将自定义变换应用于数据存储输出。

CombinedDatastore从两个或多个基础数据存储中读取数据的数据存储。

  • 训练图像回归神经网络。

  • 训练具有多个输入的神经网络。

  • 合并来自不同数据源的预测变量和响应。

PixelLabelImageDatastore (Computer Vision Toolbox)将相同的仿射几何变换应用于图像和对应像素标签的数据存储。训练用于语义分割的神经网络。
RandomPatchExtractionDatastore (Image Processing Toolbox)数据存储,它从图像或像素标注图像中提取随机补片对组,并选择性地对这些补片对组应用相同的随机仿射几何变换。训练用于目标检测的神经网络。
DenoisingImageDatastore (Image Processing Toolbox)应用随机生成的高斯噪声的数据存储。训练用于图像去噪的神经网络。
自定义小批量数据存储返回小批量数据的自定义数据存储。

使用其他数据存储不支持的格式的数据训练神经网络。

有关详细信息,请参阅Develop Custom Mini-Batch Datastore

数值数组指定为数值数组的图像。如果将图像指定为数值数组,则还必须指定 responses 参数。使用可放入内存且不需要增强等额外处理的数据训练神经网络。
指定为表的图像。如果您将图像指定为表,则您还可以使用 responses 参数指定哪些列包含响应。使用存储在表中的数据训练神经网络。

对于具有多个输入的神经网络,数据存储必须为 TransformedDatastoreCombinedDatastore 对象。

提示

对于图像序列(例如视频数据),请使用 sequences 输入参数。

数据存储

数据存储用于读取小批量的图像和响应值。当您有无法放入内存的数据或要对数据应用增强或变换时,最适合使用数据存储。

对于图像数据,下表列出了直接与 trainNetwork 兼容的数据存储。

例如,您可以使用 imageDatastore 函数创建一个图像数据存储,并通过将 'LabelSource' 选项设置为 'foldernames' 来使用包含图像的文件夹的名称作为标签。您也可以使用图像数据存储的 Labels 属性手动指定标签。

提示

使用 augmentedImageDatastore 对要用于深度学习的图像进行高效预处理,包括调整图像大小。不要使用 ImageDatastore 对象的 ReadFcn 选项。

ImageDatastore 允许使用预取功能批量读取 JPG 或 PNG 图像文件。如果您将 ReadFcn 选项设置为自定义函数,则 ImageDatastore 不会预取,并且通常会明显变慢。

通过使用 transformcombine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的格式。

对于具有多个输入的神经网络,数据存储必须为 TransformedDatastoreCombinedDatastore 对象。

数据存储输出所需的格式取决于神经网络架构。

神经网络架构数据存储输出示例输出
单个输入层

包含两列的表或元胞数组。

第一列和第二列分别指定预测变量和目标。

表元素必须为标量、行向量或包含数值数组的 1×1 元胞数组。

自定义小批量数据存储必须输出表。

对于具有一个输入和一个输出的神经网络,输出的表为:

data = read(ds)
data =

  4×2 table

        Predictors        Response
    __________________    ________

    {224×224×3 double}       2    
    {224×224×3 double}       7    
    {224×224×3 double}       9    
    {224×224×3 double}       9  

对于具有一个输入和一个输出的神经网络,输出的元胞数组为:

data = read(ds)
data =

  4×2 cell array

    {224×224×3 double}    {[2]}
    {224×224×3 double}    {[7]}
    {224×224×3 double}    {[9]}
    {224×224×3 double}    {[9]}

多个输入层

具有 (numInputs + 1) 列的元胞数组,其中 numInputs 是神经网络输入的数目。

numInputs 个列指定每个输入的预测变量,最后一列指定目标。

输入的顺序由层图 layersInputNames 属性给出。

对于具有双输入和单输出的神经网络,输出以下元胞数组。

data = read(ds)
data =

  4×3 cell array

    {224×224×3 double}    {128×128×3 double}    {[2]}
    {224×224×3 double}    {128×128×3 double}    {[2]}
    {224×224×3 double}    {128×128×3 double}    {[9]}
    {224×224×3 double}    {128×128×3 double}    {[9]}

预测变量的格式取决于数据的类型。

数据格式
二维图像

h×w×c× 数值数组,其中 h、w 和 c 分别是图像的高度、宽度和通道数。

三维图像h×w×d×c 数值数组,其中 h、w、d 和 c 分别是图像的高度、宽度、深度和通道数。

对于表中返回的预测变量,元素必须包含数值标量、数值行向量或包含数值数组的 1×1 元胞数组。

响应的格式取决于任务的类型。

任务响应格式
图像分类分类标量
图像回归
  • 数值标量

  • 数值向量

  • 表示二维图像的三维数值数组

  • 表示三维图像的四维数值数组

对于表中返回的响应,元素必须为分类标量、数值标量、数值行向量或包含数值数组的 1×1 元胞数组。

有关详细信息,请参阅Datastores for Deep Learning

数值数组

对于可放入内存并且不需要增强等额外处理的数据,您可以将图像数据集指定为数值数组。如果将图像指定为数值数组,则还必须指定 responses 参数。

数值数组的大小和形状取决于图像数据的类型。

数据格式
二维图像

h×w×c×N 数值数组,其中 h、w 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。

三维图像h×w×d×c×N 数值数组,其中 h、w、d 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。

作为数据存储或数值数组的替代方法,您还可以在表中指定图像和响应。如果您将图像指定为表,则您还可以使用 responses 参数指定哪些列包含响应。

当在表中指定图像和响应时,表中的每行对应一个观测值。

对于图像输入,预测变量必须位于表的第一列,指定为以下项之一:

  • 图像的绝对或相对文件路径,指定为字符向量

  • 1×1 元胞数组,包含表示二维图像的 h×w×c 数值数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数。

响应的格式取决于任务的类型。

任务响应格式
图像分类分类标量
图像回归
  • 数值标量

  • 两列或多列标量值

  • 1×1 元胞数组,包含表示二维图像的 h×w×c 数值数组

  • 1×1 元胞数组,包含表示三维图像的 h×w×d×c 数值数组

对于具有图像输入的神经网络,如果不指定 responses,则默认情况下,该函数使用 tbl 的第一列作为预测变量,后续列作为响应。

提示

  • 如果预测变量或响应包含 NaN,则它们在训练期间会通过神经网络传播。在这些情况下,训练通常无法收敛。

  • 对于回归任务,将响应归一化通常有助于稳定和加速神经网络的回归训练。有关详细信息,请参阅针对回归训练卷积神经网络

  • 要将复数值数据输入到神经网络中,输入层的 SplitComplexInputs 选项必须为 1

数据类型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | table
复数支持:

序列或时间序列数据,指定为下列各项之一:

数据类型描述用法示例
数据存储TransformedDatastore这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。

  • 变换具有不受 trainNetwork 支持的输出的数据存储。

  • 将自定义变换应用于数据存储输出。

CombinedDatastore从两个或多个基础数据存储中读取数据的数据存储。

合并来自不同数据源的预测变量和响应。

自定义小批量数据存储返回小批量数据的自定义数据存储。

使用其他数据存储不支持的格式的数据训练神经网络。

有关详细信息,请参阅Develop Custom Mini-Batch Datastore

数值数组或元胞数组指定为数值数组的单个序列,或指定为由数值数组组成的元胞数组的序列数据集。如果将序列指定为数值或元胞数组,则还必须指定 responses 参数。使用可放入内存且不需要自定义变换等额外处理的数据训练神经网络。

数据存储

数据存储读取若干小批量序列和响应。当您有无法放入内存的数据或要对数据应用变换时,最适合使用数据存储。

对于序列数据,下表列出了直接与 trainNetwork 兼容的数据存储。

通过使用 transformcombine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的表或元胞数组格式。例如,您可以分别使用 ArrayDatastoreTabularTextDatastore 对象变换和合并从内存数组与 CSV 文件中读取的数据。

数据存储必须以表或元胞数组的形式返回数据。自定义小批量数据存储必须输出表。

数据存储输出示例输出
data = read(ds)
data =

  4×2 table

        Predictors        Response
    __________________    ________

    {12×50 double}           2    
    {12×50 double}           7    
    {12×50 double}           9    
    {12×50 double}           9  
元胞数组
data = read(ds)
data =

  4×2 cell array

    {12×50 double}        {[2]}
    {12×50 double}        {[7]}
    {12×50 double}        {[9]}
    {12×50 double}        {[9]}

预测变量的格式取决于数据的类型。

数据预测变量的格式
向量序列

c×s 矩阵,其中 c 是序列的特征数,s 是序列长度。

一维图像序列

h×c×s 数组,其中 h 和 c 分别对应图像的高度和通道数,s 是序列长度。

小批量中的每个序列必须具有相同的序列长度。

二维图像序列

h×w×c×s 数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数,s 是序列长度。

小批量中的每个序列必须具有相同的序列长度。

三维图像序列

h×w×d×c×s 数组,其中 h、w、d 和 c 分别对应于图像的高度、宽度、深度和通道数,而 s 是序列长度。

小批量中的每个序列必须具有相同的序列长度。

对于表中返回的预测变量,元素必须包含数值标量、数值行向量或包含数值数组的 1×1 元胞数组。

响应的格式取决于任务的类型。

任务响应的格式
“序列到标签”分类分类标量
“序列到单个”回归

标量

“序列到向量”回归

数值行向量

“序列到序列”分类

  • 1×s 分类标签序列,其中 s 是对应预测变量序列的序列长度。

  • h×w×s 分类标签序列,其中 h、w 和 s 分别是对应预测变量序列的高度、宽度和序列长度。

  • h×w×d×s 分类标签序列,其中 h、w、d 和 s 分别是对应预测变量序列的高度、宽度、深度和序列长度。

小批量中的每个序列必须具有相同的序列长度。

“序列到序列”回归
  • R×s 矩阵,其中 R 是响应的数量,s 是对应预测变量序列的序列长度。

  • h×w×R×s 数值响应序列,其中 R 是响应的数量,h、w 和 s 分别是对应预测变量序列的高度、宽度和序列长度。

  • h×w×d×R×s 数值响应序列,其中 R 是响应的数量,h、w、d 和 s 分别是对应预测变量序列的高度、宽度、深度和序列长度。

小批量中的每个序列必须具有相同的序列长度。

对于表中返回的响应,元素必须为分类标量、数值标量、数值行向量或包含数值数组的 1×1 元胞数组。

有关详细信息,请参阅Datastores for Deep Learning

数值数组或元胞数组

对于可放入内存并且不需要自定义变换等额外处理的数据,可以将单个序列指定为数值数组,或将序列数据集指定为由数值数组组成的元胞数组。如果将序列指定为元胞或数值数组,则还必须指定 responses 参数。

对于元胞数组输入,元胞数组必须为由数值数组组成的 N×1 元胞数组,其中 N 是观测值数目。表示序列的数值数组的大小和形状取决于序列数据的类型。

输入描述
向量序列c×s 矩阵,其中 c 是序列的特征数,s 是序列长度。
一维图像序列h×c×s 数组,其中 h 和 c 分别对应于图像的高度和通道数,而 s 是序列长度。
二维图像序列h×w×c×s 数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数,s 是序列长度。
三维图像序列h×w×d×c×s,其中 h、w、d 和 c 分别对应于三维图像的高度、宽度、深度和通道数,s 是序列长度。

trainNetwork 函数支持最多具有一个序列输入层的神经网络。

提示

  • 如果预测变量或响应包含 NaN,则它们在训练期间会通过神经网络传播。在这些情况下,训练通常无法收敛。

  • 对于回归任务,将响应归一化通常有助于稳定和加速训练。有关详细信息,请参阅针对回归训练卷积神经网络

  • 要将复数值数据输入到神经网络中,输入层的 SplitComplexInputs 选项必须为 1

数据类型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | cell
复数支持:

特征数据,指定为下列各项之一:

数据类型描述用法示例
数据存储TransformedDatastore这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。

  • 训练具有多个输入的神经网络。

  • 变换具有不受 trainNetwork 支持的输出的数据存储。

  • 将自定义变换应用于数据存储输出。

CombinedDatastore从两个或多个基础数据存储中读取数据的数据存储。

  • 训练具有多个输入的神经网络。

  • 合并来自不同数据源的预测变量和响应。

自定义小批量数据存储返回小批量数据的自定义数据存储。

使用其他数据存储不支持的格式的数据训练神经网络。

有关详细信息,请参阅Develop Custom Mini-Batch Datastore

指定为表的特征数据。如果您将特征指定为表,则还可以使用 responses 参数指定哪些列包含响应。使用存储在表中的数据训练神经网络。
数值数组指定为数值数组的特征数据。如果将特征指定为数值数组,则还必须指定 responses 参数。使用可放入内存且不需要自定义变换等额外处理的数据训练神经网络。

数据存储

数据存储读取小批量的特征数据和响应。当您有无法放入内存的数据或要对数据应用变换时,最适合使用数据存储。

对于特征数据,下表列出了直接与 trainNetwork 兼容的数据存储。

通过使用 transformcombine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的表或元胞数组格式。有关详细信息,请参阅Datastores for Deep Learning

对于具有多个输入的神经网络,数据存储必须为 TransformedDatastoreCombinedDatastore 对象。

数据存储必须以表或元胞数组的形式返回数据。自定义小批量数据存储必须输出表。数据存储输出的格式取决于神经网络架构。

神经网络架构数据存储输出示例输出
单个输入层

包含两列的表或元胞数组。

第一列和第二列分别指定预测变量和响应。

表元素必须为标量、行向量或包含数值数组的 1×1 元胞数组。

自定义小批量数据存储必须输出表。

对于具有一个输入和一个输出的神经网络,输出的表为:

data = read(ds)
data =

  4×2 table

        Predictors        Response
    __________________    ________

    {24×1 double}            2    
    {24×1 double}            7    
    {24×1 double}            9    
    {24×1 double}            9  

对于具有一个输入和一个输出的神经网络,输出的元胞数组为:

data = read(ds)
data =

  4×2 cell array

    {24×1 double}    {[2]}
    {24×1 double}    {[7]}
    {24×1 double}    {[9]}
    {24×1 double}    {[9]}

多个输入层

具有 (numInputs + 1) 列的元胞数组,其中 numInputs 是神经网络输入的数目。

numInputs 个列指定每个输入的预测变量,最后一列指定响应。

输入的顺序由层图 layersInputNames 属性给出。

对于具有双输入和单输出的神经网络,输出的元胞数组为:

data = read(ds)
data =

  4×3 cell array

    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[9]}
    {24×1 double}    {28×1 double}    {[9]}

预测变量必须为 c×1 列向量,其中 c 是特征数。

响应的格式取决于任务的类型。

任务响应的格式
分类分类标量
回归

  • 标量

  • 数值向量

有关详细信息,请参阅Datastores for Deep Learning

对于可放入内存且不需要自定义变换等其他处理的特征数据,可以将特征数据和响应指定为表。

表中的每行对应一个观测值。预测变量和响应在表列中的排列取决于任务的类型。

任务预测变量响应
特征分类

在一列或多列中指定为标量的特征。

如果未指定 responses 参数,则预测变量必须在表的前 numFeatures 个列中,其中 numFeatures 是输入数据的特征数。

分类标签

特征回归

一列或多列标量值

对于具有特征输入的分类神经网络,如果未指定 responses 参数,则默认情况下,该函数将 tbl 的前 (numColumns - 1) 个列用作预测变量,最后一列用作标签,其中 numFeatures 是输入数据中的特征数。

对于具有特征输入的回归神经网络,如果未指定 responseNames 参数,则默认情况下,该函数将前 numFeatures 个列用于预测变量,后续列用于响应,其中 numFeatures 是输入数据中的特征数。

数值数组

对于可放入内存且不需要自定义变换等额外处理的特征数据,可以将特征数据指定为数值数组。如果将特征数据指定为数值数组,则还必须指定 responses 参数。

数值数组必须为 N×numFeatures 的数值数组,其中 N 是观测值数目,numFeatures 是输入数据的特征数。

提示

  • 将响应归一化通常有助于稳定和加速神经网络的回归训练。有关详细信息,请参阅针对回归训练卷积神经网络

  • 响应不能包含 NaN。如果预测变量数据包含 NaN,则它们将通过训练进行传播。但在大多数情况下,训练将无法收敛。

  • 要将复数值数据输入到神经网络中,输入层的 SplitComplexInputs 选项必须为 1

数据类型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | table
复数支持:

混合数据和响应,指定为以下项之一:

数据类型描述用法示例
TransformedDatastore这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。

  • 训练具有多个输入的神经网络。

  • trainNetwork 不支持的数据存储的输出变换为具有要求的格式。

  • 将自定义变换应用于数据存储输出。

CombinedDatastore从两个或多个基础数据存储中读取数据的数据存储。

  • 训练具有多个输入的神经网络。

  • 合并来自不同数据源的预测变量和响应。

自定义小批量数据存储返回小批量数据的自定义数据存储。

使用其他数据存储不支持的格式的数据训练神经网络。

有关详细信息,请参阅Develop Custom Mini-Batch Datastore

通过使用 transformcombine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的表或元胞数组格式。有关详细信息,请参阅Datastores for Deep Learning

数据存储必须以表或元胞数组的形式返回数据。自定义小批量数据存储必须输出表。数据存储输出的格式取决于神经网络架构。

数据存储输出示例输出

具有 (numInputs + 1) 列的元胞数组,其中 numInputs 是神经网络输入的数目。

numInputs 个列指定每个输入的预测变量,最后一列指定响应。

输入的顺序由层图 layersInputNames 属性给出。

data = read(ds)
data =

  4×3 cell array

    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[9]}
    {24×1 double}    {28×1 double}    {[9]}

对于图像、序列和特征预测变量输入,预测变量的格式必须分别与 imagessequencesfeatures 参数描述中所述的格式匹配。同样,响应的格式必须与 imagessequencesfeatures 参数描述中所述的与任务类型对应的格式匹配。

trainNetwork 函数支持最多具有一个序列输入层的神经网络。

有关如何训练具有多个输入的神经网络的示例,请参阅Train Network on Image and Feature Data

提示

  • 要将数值数组转换为数据存储,请使用 ArrayDatastore

  • 当合并具有混合类型数据的神经网络中的层时,您可能需要在将数据传递给合并层(如串联层或相加层)之前重新格式化数据。要重新格式化数据,您可以使用扁平化层将空间维度扁平化为通道维度,或创建一个 FunctionLayer 对象或自定义层来重新格式化和重构。

响应。

当输入数据是数值数组或元胞数组时,请将响应指定为以下项之一。

  • 由标签组成的分类向量

  • 由数值响应组成的数值数组

  • 由分类序列或由数值序列组成的元胞数组

当输入数据是表时,您可以选择指定表中的哪些列包含以下响应之一:

  • 字符向量

  • 字符向量元胞数组

  • 字符串数组

当输入数据是数值数组或元胞数组时,响应的格式取决于任务的类型。

任务格式
分类图像分类N×1 标签分类向量,其中 N 是观测值数目。
特征分类
“序列到标签”分类
“序列到序列”分类

由分类标签序列组成的 N×1 元胞数组,其中 N 是观测值数目。每个序列必须具有与对应预测变量序列相同的时间步数。

对于具有一个观测值的“序列到序列”分类任务,sequences 也可以是向量。在这种情况下,responses 必须为由标签组成的分类行向量。

回归二维图像回归
  • N×R 矩阵,其中 N 是图像的数量,R 是响应的数量。

  • h×w×c×N 数值数组,其中 h、w 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。

三维图像回归
  • N×R 矩阵,其中 N 是图像的数量,R 是响应的数量。

  • h×w×d×c×N 数值数组,其中 h、w、d 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。

特征回归

N×R 矩阵,其中 N 是观测值数目,R 是响应数。

“序列到单个”回归N×R 矩阵,其中 N 是序列数,R 是响应数。
“序列到序列”回归

由数值序列组成的 N×1 元胞数组,其中 N 是序列数,序列由以下项之一给出:

  • R×s 矩阵,其中 R 是响应的数量,s 是对应预测变量序列的序列长度。

  • h×w×R×s 数组,其中 h 和 w 分别是输出的高度和宽度,R 是响应数,s 是对应预测变量序列的序列长度。

  • h×w×d××R×s 数组,其中 h、w 和 d 分别是输出的高度、宽度和深度,R 是响应数,s 是对应预测变量序列的序列长度。

对于具有一个观测值的“序列到序列”回归任务,sequences 可以是数值数组。在这种情况下,responses 必须为由响应组成的数值数组。

提示

将响应归一化通常有助于稳定和加速神经网络的回归训练。有关详细信息,请参阅针对回归训练卷积神经网络

提示

响应不能包含 NaN。如果预测变量数据包含 NaN,则它们将通过训练进行传播。但在大多数情况下,训练将无法收敛。

神经网络层,指定为 Layer 数组或 LayerGraph 对象。

要创建一个所有层都按顺序连接的神经网络,您可以使用 Layer 数组作为输入参数。在这种情况下,返回的神经网络是一个 SeriesNetwork 对象。

有向无环图 (DAG) 神经网络具有复杂的结构,其中各层可以有多个输入和输出。要创建 DAG 神经网络,请将神经网络架构指定为 LayerGraph 对象,然后使用该层图作为 trainNetwork 的输入参数。

trainNetwork 函数支持最多具有一个序列输入层的神经网络。

有关内置层的列表,请参阅深度学习层列表

训练选项,指定为由 trainingOptions 函数返回的 TrainingOptionsSGDMTrainingOptionsRMSPropTrainingOptionsADAM 对象。

输出参数

全部折叠

经过训练的神经网络,以 SeriesNetwork 对象或 DAGNetwork 对象形式返回。

如果您使用 Layer 数组来训练神经网络,则 net 是一个 SeriesNetwork 对象。如果您使用 LayerGraph 对象来训练神经网络,则 net 是一个 DAGNetwork 对象。

训练信息,以结构体形式返回,其中每个字段是一个标量,或一个数值向量,向量中的每个元素对应一次训练迭代。

对于分类任务,info 包含以下字段:

  • TrainingLoss - 损失函数值

  • TrainingAccuracy - 训练准确度

  • ValidationLoss - 损失函数值

  • ValidationAccuracy - 验证准确度

  • BaseLearnRate - 学习率

  • FinalValidationLoss - 返回的神经网络的验证损失

  • FinalValidationAccuracy - 返回的神经网络的验证准确度

  • OutputNetworkIteration - 返回的神经网络的迭代序号

对于回归任务,info 包含以下字段:

  • TrainingLoss - 损失函数值

  • TrainingRMSE - 训练 RMSE 值

  • ValidationLoss - 损失函数值

  • ValidationRMSE - 验证 RMSE 值

  • BaseLearnRate - 学习率

  • FinalValidationLoss - 返回的神经网络的验证损失

  • FinalValidationRMSE - 返回的神经网络的验证 RMSE

  • OutputNetworkIteration - 返回的神经网络的迭代序号

如果 options 指定了验证数据,则该结构体仅包含 ValidationLossValidationAccuracyValidationRMSEFinalValidationLossFinalValidationAccuracyFinalValidationRMSE 字段。ValidationFrequency 训练选项确定软件在哪些迭代中计算验证度量。最终的验证度量是标量。该结构体的其他字段是行向量,其中每个元素对应一次训练迭代。对于软件不计算验证度量的迭代,结构体中的对应值是 NaN

对于包含批量归一化层的神经网络,如果 BatchNormalizationStatistics 训练选项为 'population',则最终验证度量通常不同于在训练期间评估出的验证度量。这是因为最终神经网络中的批量归一化层执行的操作与训练过程中执行的操作不同。有关详细信息,请参阅 batchNormalizationLayer

详细信息

全部折叠

保存检查点神经网络并继续训练

Deep Learning Toolbox™ 使您能够在训练期间将神经网络保存为 .mat 文件。当您有大型神经网络或大型数据集并且训练需要很长时间时,这种定期保存特别有用。如果训练因某种原因中断,您可以从上次保存的检查点神经网络继续训练。如果希望 trainnettrainNetwork 函数保存检查点神经网络,则您必须使用 trainingOptionsCheckpointPath 选项指定路径名称。如果您指定的路径不存在,则 trainingOptions 会返回错误。

软件自动为检查点神经网络文件分配唯一名称。在示例名称 net_checkpoint__351__2018_04_12__18_09_52.mat 中,351 是迭代序号,2018_04_12 是日期,18_09_52 是软件保存神经网络的时间。您可以通过双击检查点神经网络文件或在命令行中使用 load 命令来加载该文件。例如:

load net_checkpoint__351__2018_04_12__18_09_52.mat
然后,您可以使用神经网络的层作为 trainnettrainNetwork 的输入参数,继续进行训练。例如:

trainNetwork(XTrain,TTrain,net.Layers,options)
您必须手动指定训练选项和输入数据,因为检查点神经网络不包含此信息。有关示例,请参阅Resume Training from Checkpoint Network

浮点算术

当您使用 trainnettrainNetwork 函数训练神经网络时,或当您对 DAGNetworkSeriesNetwork 对象使用预测或验证函数时,软件会使用单精度浮点算术来执行这些计算。用于预测和验证的函数包括 predictclassifyactivations。当您使用 CPU 和 GPU 来训练神经网络时,软件将使用单精度算术。

可再现性

为了提供最优性能,在 MATLAB® 中使用 GPU 的深度学习不保证是确定性的。根据您的网络架构,在某些情况下,当使用 GPU 训练两个相同的网络或使用相同的网络和数据进行两次预测时,您可能会得到不同结果。

扩展功能

版本历史记录

在 R2016a 中推出

全部展开