trainNetwork
(不推荐)训练神经网络
语法
说明
示例
训练一个用于“序列到标签”分类的深度学习 LSTM 网络。
从 WaveformData.mat 加载示例数据。数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numChannels×-numTimeSteps 数值数组,其中 numChannels 是序列的通道数,numTimeSteps 是序列的时间步数。
load WaveformData在绘图中可视化一些序列。
numChannels = size(data{1},1);
idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
nexttile
stackedplot(data{idx(i)}', ...
DisplayLabels="Channel " + string(1:numChannels))
xlabel("Time Step")
title("Class: " + string(labels(idx(i))))
end
留出测试数据。将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据)。要划分数据,请使用 trainingPartitions 函数,此函数作为支持文件包含在此示例中。要访问此文件,请以实时脚本形式打开此示例。
numObservations = numel(data); [idxTrain,idxTest] = trainingPartitions(numObservations, [0.9 0.1]); XTrain = data(idxTrain); TTrain = labels(idxTrain); XTest = data(idxTest); TTest = labels(idxTest);
定义 LSTM 网络架构。将输入大小指定为输入数据的通道数量。指定一个 LSTM 层要具有 120 个隐藏单元,并输出序列的最后一个元素。最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层和一个分类层。
numHiddenUnits = 120; numClasses = numel(categories(TTrain)); layers = [ ... sequenceInputLayer(numChannels) lstmLayer(numHiddenUnits,OutputMode="last") fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers =
5×1 Layer array with layers:
1 '' Sequence Input Sequence input with 3 dimensions
2 '' LSTM LSTM with 120 hidden units
3 '' Fully Connected 4 fully connected layer
4 '' Softmax softmax
5 '' Classification Output crossentropyex
指定训练选项。使用 Adam 求解器进行训练,学习率为 0.01,梯度阈值为 1。将最大训练轮数设置为 150,并对每轮执行乱序。默认情况下,软件会在 GPU 上(如果有)进行训练。使用 GPU 需要 Parallel Computing Toolbox™ 和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU 计算要求 (Parallel Computing Toolbox)。
options = trainingOptions("adam", ... MaxEpochs=150, ... InitialLearnRate=0.01,... Shuffle="every-epoch", ... GradientThreshold=1, ... Verbose=false, ... Plots="training-progress");
使用指定的训练选项训练 LSTM 网络。
net = trainNetwork(XTrain,TTrain,layers,options);

对测试数据进行分类。指定用于训练的相同小批量大小。
YTest = classify(net,XTest);
计算预测值的分类准确度。
acc = mean(YTest == TTest)
acc = 0.8400
在混淆图中显示分类结果。
figure confusionchart(TTest,YTest)

输入参数
图像数据,指定为下列值之一:
| 数据类型 | 描述 | 用法示例 | |
|---|---|---|---|
| 数据存储 | ImageDatastore | 保存在磁盘上的图像数据存储。 | 使用保存在磁盘上的图像训练图像分类神经网络,其中图像的大小相同。 当图像大小不同时,使用
|
AugmentedImageDatastore | 应用随机仿射几何变换(包括调整大小、旋转、翻转、剪切和平移)的数据存储。 |
| |
TransformedDatastore | 这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。 |
| |
CombinedDatastore | 从两个或多个基础数据存储中读取数据的数据存储。 |
| |
RandomPatchExtractionDatastore (Image Processing Toolbox) | 数据存储,它从图像或像素标注图像中提取随机补片对组,并选择性地对这些补片对组应用相同的随机仿射几何变换。 | 训练用于目标检测的神经网络。 | |
DenoisingImageDatastore (Image Processing Toolbox) | 应用随机生成的高斯噪声的数据存储。 | 训练用于图像去噪的神经网络。 | |
| 自定义小批量数据存储 | 返回小批量数据的自定义数据存储。 | 使用其他数据存储不支持的格式的数据训练神经网络。 有关详细信息,请参阅Develop Custom Mini-Batch Datastore。 | |
| 数值数组 | 指定为数值数组的图像。如果将图像指定为数值数组,则还必须指定 responses 参量。 | 使用可放入内存且不需要增强等额外处理的数据训练神经网络。 | |
| 表 | 指定为表的图像。如果您将图像指定为表,则您还可以使用 responses 参量指定哪些列包含响应。 | 使用存储在表中的数据训练神经网络。 | |
对于具有多个输入的神经网络,数据存储必须为 TransformedDatastore 或 CombinedDatastore 对象。
提示
对于图像序列(例如视频数据),请使用 sequences 输入参量。
数据存储
数据存储用于读取小批量的图像和响应值。当您有无法放入内存的数据或要对数据应用增强或变换时,最适合使用数据存储。
对于图像数据,下表列出了直接与 trainNetwork 兼容的数据存储。
RandomPatchExtractionDatastore(Image Processing Toolbox)DenoisingImageDatastore(Image Processing Toolbox)自定义小批量数据存储。有关详细信息,请参阅Develop Custom Mini-Batch Datastore。
例如,您可以使用 imageDatastore 函数创建一个图像数据存储,并通过将 'LabelSource' 选项设置为 'foldernames' 来使用包含图像的文件夹的名称作为标签。您也可以使用图像数据存储的 Labels 属性手动指定标签。
提示
使用 augmentedImageDatastore 对要用于深度学习的图像进行高效预处理,包括调整图像大小。不要使用 ImageDatastore 对象的 ReadFcn 选项。
ImageDatastore 允许使用预取功能批量读取 JPG 或 PNG 图像文件。如果您将 ReadFcn 选项设置为自定义函数,则 ImageDatastore 不会预取,并且通常会明显变慢。
通过使用 transform 和 combine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的格式。
对于具有多个输入的神经网络,数据存储必须为 TransformedDatastore 或 CombinedDatastore 对象。
数据存储输出所需的格式取决于神经网络架构。
| 神经网络架构 | 数据存储输出 | 示例输出 |
|---|---|---|
| 单个输入层 | 包含两列的表或元胞数组。 第一列和第二列分别指定预测变量和目标。 表元素必须为标量、行向量或包含数值数组的 1×1 元胞数组。 自定义小批量数据存储必须输出表。 | 对于具有一个输入和一个输出的神经网络,输出的表为: data = read(ds) data =
4×2 table
Predictors Response
__________________ ________
{224×224×3 double} 2
{224×224×3 double} 7
{224×224×3 double} 9
{224×224×3 double} 9
|
对于具有一个输入和一个输出的神经网络,输出的元胞数组为: data = read(ds) data =
4×2 cell array
{224×224×3 double} {[2]}
{224×224×3 double} {[7]}
{224×224×3 double} {[9]}
{224×224×3 double} {[9]} | ||
| 多个输入层 | 具有 ( 前 输入的顺序由层图 | 对于具有双输入和单输出的神经网络,输出以下元胞数组。 data = read(ds) data =
4×3 cell array
{224×224×3 double} {128×128×3 double} {[2]}
{224×224×3 double} {128×128×3 double} {[2]}
{224×224×3 double} {128×128×3 double} {[9]}
{224×224×3 double} {128×128×3 double} {[9]} |
预测变量的格式取决于数据的类型。
| 数据 | 格式 |
|---|---|
| 二维图像 | h×w×c× 数值数组,其中 h、w 和 c 分别是图像的高度、宽度和通道数。 |
| 三维图像 | h×w×d×c 数值数组,其中 h、w、d 和 c 分别是图像的高度、宽度、深度和通道数。 |
对于表中返回的预测变量,元素必须包含数值标量、数值行向量或包含数值数组的 1×1 元胞数组。
响应的格式取决于任务的类型。
| 任务 | 响应格式 |
|---|---|
| 图像分类 | 分类标量 |
| 图像回归 |
|
对于表中返回的响应,元素必须为分类标量、数值标量、数值行向量或包含数值数组的 1×1 元胞数组。
有关详细信息,请参阅Datastores for Deep Learning。
数值数组
对于可放入内存并且不需要增强等额外处理的数据,您可以将图像数据集指定为数值数组。如果将图像指定为数值数组,则还必须指定 responses 参量。
数值数组的大小和形状取决于图像数据的类型。
| 数据 | 格式 |
|---|---|
| 二维图像 | h×w×c×N 数值数组,其中 h、w 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。 |
| 三维图像 | h×w×d×c×N 数值数组,其中 h、w、d 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。 |
表
作为数据存储或数值数组的替代方法,您还可以在表中指定图像和响应。如果您将图像指定为表,则您还可以使用 responses 参量指定哪些列包含响应。
当在表中指定图像和响应时,表中的每行对应一个观测值。
对于图像输入,预测变量必须位于表的第一列,指定为以下项之一:
图像的绝对或相对文件路径,指定为字符向量
1×1 元胞数组,包含表示二维图像的 h×w×c 数值数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数。
响应的格式取决于任务的类型。
| 任务 | 响应格式 |
|---|---|
| 图像分类 | 分类标量 |
| 图像回归 |
|
对于具有图像输入的神经网络,如果不指定 responses,则默认情况下,该函数使用 tbl 的第一列作为预测变量,后续列作为响应。
提示
如果预测变量或响应包含
NaN,则它们在训练期间会通过神经网络传播。在这些情况下,训练通常无法收敛。对于回归任务,将响应归一化通常有助于稳定和加速神经网络的回归训练。有关详细信息,请参阅针对回归训练卷积神经网络。
此参量支持复数值预测变量。要使用
trainNetwork函数训练具有复数值预测变量的网络,输入层的SplitComplexInputs选项必须为1(true)。
序列或时间序列数据,指定为下列各项之一:
| 数据类型 | 描述 | 用法示例 | |
|---|---|---|---|
| 数据存储 | TransformedDatastore | 这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。 |
|
CombinedDatastore | 从两个或多个基础数据存储中读取数据的数据存储。 | 合并来自不同数据源的预测变量和响应。 | |
| 自定义小批量数据存储 | 返回小批量数据的自定义数据存储。 | 使用其他数据存储不支持的格式的数据训练神经网络。 有关详细信息,请参阅Develop Custom Mini-Batch Datastore。 | |
| 数值数组或元胞数组 | 指定为数值数组的单个序列,或指定为由数值数组组成的元胞数组的序列数据集。如果将序列指定为数值或元胞数组,则还必须指定 responses 参量。 | 使用可放入内存且不需要自定义变换等额外处理的数据训练神经网络。 | |
数据存储
数据存储读取若干小批量序列和响应。当您有无法放入内存的数据或要对数据应用变换时,最适合使用数据存储。
对于序列数据,下表列出了直接与 trainNetwork 兼容的数据存储。
自定义小批量数据存储。有关详细信息,请参阅Develop Custom Mini-Batch Datastore。
通过使用 transform 和 combine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的表或元胞数组格式。例如,您可以分别使用 ArrayDatastore 和 TabularTextDatastore 对象变换和合并从内存数组与 CSV 文件中读取的数据。
数据存储必须以表或元胞数组的形式返回数据。自定义小批量数据存储必须输出表。
| 数据存储输出 | 示例输出 |
|---|---|
| 表 | data = read(ds) data =
4×2 table
Predictors Response
__________________ ________
{12×50 double} 2
{12×50 double} 7
{12×50 double} 9
{12×50 double} 9
|
| 元胞数组 | data = read(ds) data =
4×2 cell array
{12×50 double} {[2]}
{12×50 double} {[7]}
{12×50 double} {[9]}
{12×50 double} {[9]} |
预测变量的格式取决于数据的类型。
| 数据 | 预测变量的格式 |
|---|---|
| 向量序列 | c×s 矩阵,其中 c 是序列的特征数,s 是序列长度。 |
| 一维图像序列 | h×c×s 数组,其中 h 和 c 分别对应图像的高度和通道数,s 是序列长度。 小批量中的每个序列必须具有相同的序列长度。 |
| 二维图像序列 | h×w×c×s 数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数,s 是序列长度。 小批量中的每个序列必须具有相同的序列长度。 |
| 三维图像序列 | h×w×d×c×s 数组,其中 h、w、d 和 c 分别对应于图像的高度、宽度、深度和通道数,而 s 是序列长度。 小批量中的每个序列必须具有相同的序列长度。 |
对于表中返回的预测变量,元素必须包含数值标量、数值行向量或包含数值数组的 1×1 元胞数组。
响应的格式取决于任务的类型。
| 任务 | 响应的格式 |
|---|---|
| “序列到标签”分类 | 分类标量 |
| “序列到单个”回归 | 标量 |
| “序列到向量”回归 | 数值行向量 |
| “序列到序列”分类 |
小批量中的每个序列必须具有相同的序列长度。 |
| “序列到序列”回归 |
小批量中的每个序列必须具有相同的序列长度。 |
对于表中返回的响应,元素必须为分类标量、数值标量、数值行向量或包含数值数组的 1×1 元胞数组。
有关详细信息,请参阅Datastores for Deep Learning。
数值数组或元胞数组
对于可放入内存并且不需要自定义变换等额外处理的数据,可以将单个序列指定为数值数组,或将序列数据集指定为由数值数组组成的元胞数组。如果将序列指定为元胞或数值数组,则还必须指定 responses 参量。
对于元胞数组输入,元胞数组必须为由数值数组组成的 N×1 元胞数组,其中 N 是观测值数目。表示序列的数值数组的大小和形状取决于序列数据的类型。
| 输入 | 描述 |
|---|---|
| 向量序列 | c×s 矩阵,其中 c 是序列的特征数,s 是序列长度。 |
| 一维图像序列 | h×c×s 数组,其中 h 和 c 分别对应于图像的高度和通道数,而 s 是序列长度。 |
| 二维图像序列 | h×w×c×s 数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数,s 是序列长度。 |
| 三维图像序列 | h×w×d×c×s,其中 h、w、d 和 c 分别对应于三维图像的高度、宽度、深度和通道数,s 是序列长度。 |
trainNetwork 函数支持最多具有一个序列输入层的神经网络。
提示
如果预测变量或响应包含
NaN,则它们在训练期间会通过神经网络传播。在这些情况下,训练通常无法收敛。对于回归任务,将响应归一化通常有助于稳定和加速训练。有关详细信息,请参阅针对回归训练卷积神经网络。
此参量支持复数值预测变量。要使用
trainNetwork函数训练具有复数值预测变量的网络,输入层的SplitComplexInputs选项必须为1(true)。
特征数据,指定为下列各项之一:
| 数据类型 | 描述 | 用法示例 | |
|---|---|---|---|
| 数据存储 | TransformedDatastore | 这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。 |
|
CombinedDatastore | 从两个或多个基础数据存储中读取数据的数据存储。 |
| |
| 自定义小批量数据存储 | 返回小批量数据的自定义数据存储。 | 使用其他数据存储不支持的格式的数据训练神经网络。 有关详细信息,请参阅Develop Custom Mini-Batch Datastore。 | |
| 表 | 指定为表的特征数据。如果您将特征指定为表,则还可以使用 responses 参量指定哪些列包含响应。 | 使用存储在表中的数据训练神经网络。 | |
| 数值数组 | 指定为数值数组的特征数据。如果将特征指定为数值数组,则还必须指定 responses 参量。 | 使用可放入内存且不需要自定义变换等额外处理的数据训练神经网络。 | |
数据存储
数据存储读取小批量的特征数据和响应。当您有无法放入内存的数据或要对数据应用变换时,最适合使用数据存储。
对于特征数据,下表列出了直接与 trainNetwork 兼容的数据存储。
自定义小批量数据存储。有关详细信息,请参阅Develop Custom Mini-Batch Datastore。
通过使用 transform 和 combine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的表或元胞数组格式。有关详细信息,请参阅Datastores for Deep Learning。
对于具有多个输入的神经网络,数据存储必须为 TransformedDatastore 或 CombinedDatastore 对象。
数据存储必须以表或元胞数组的形式返回数据。自定义小批量数据存储必须输出表。数据存储输出的格式取决于神经网络架构。
| 神经网络架构 | 数据存储输出 | 示例输出 |
|---|---|---|
| 单个输入层 | 包含两列的表或元胞数组。 第一列和第二列分别指定预测变量和响应。 表元素必须为标量、行向量或包含数值数组的 1×1 元胞数组。 自定义小批量数据存储必须输出表。 | 对于具有一个输入和一个输出的神经网络,输出的表为: data = read(ds) data =
4×2 table
Predictors Response
__________________ ________
{24×1 double} 2
{24×1 double} 7
{24×1 double} 9
{24×1 double} 9
|
对于具有一个输入和一个输出的神经网络,输出的元胞数组为:
data = read(ds) data =
4×2 cell array
{24×1 double} {[2]}
{24×1 double} {[7]}
{24×1 double} {[9]}
{24×1 double} {[9]} | ||
| 多个输入层 | 具有 ( 前 输入的顺序由层图 | 对于具有双输入和单输出的神经网络,输出的元胞数组为: data = read(ds) data =
4×3 cell array
{24×1 double} {28×1 double} {[2]}
{24×1 double} {28×1 double} {[2]}
{24×1 double} {28×1 double} {[9]}
{24×1 double} {28×1 double} {[9]} |
预测变量必须为 c×1 列向量,其中 c 是特征数。
响应的格式取决于任务的类型。
| 任务 | 响应的格式 |
|---|---|
| 分类 | 分类标量 |
| 回归 |
|
有关详细信息,请参阅Datastores for Deep Learning。
表
对于可放入内存且不需要自定义变换等其他处理的特征数据,可以将特征数据和响应指定为表。
表中的每行对应一个观测值。预测变量和响应在表列中的排列取决于任务的类型。
| 任务 | 预测变量 | 响应 |
|---|---|---|
| 特征分类 | 在一列或多列中指定为标量的特征。 如果未指定 | 分类标签 |
| 特征回归 | 一列或多列标量值 |
对于具有特征输入的分类神经网络,如果未指定 responses 参量,则默认情况下,该函数将 tbl 的前 (numColumns - 1) 个列用作预测变量,最后一列用作标签,其中 numFeatures 是输入数据中的特征数。
对于具有特征输入的回归神经网络,如果未指定 responseNames 参量,则默认情况下,该函数将前 numFeatures 个列用于预测变量,后续列用于响应,其中 numFeatures 是输入数据中的特征数。
数值数组
对于可放入内存且不需要自定义变换等额外处理的特征数据,可以将特征数据指定为数值数组。如果将特征数据指定为数值数组,则还必须指定 responses 参量。
数值数组必须为 N×numFeatures 的数值数组,其中 N 是观测值数目,numFeatures 是输入数据的特征数。
提示
将响应归一化通常有助于稳定和加速神经网络的回归训练。有关详细信息,请参阅针对回归训练卷积神经网络。
响应不能包含
NaN。如果预测变量数据包含NaN,则它们将通过训练进行传播。但在大多数情况下,训练将无法收敛。此参量支持复数值预测变量。要使用
trainNetwork函数训练具有复数值预测变量的网络,输入层的SplitComplexInputs选项必须为1(true)。
混合数据和响应,指定为以下项之一:
| 数据类型 | 描述 | 用法示例 |
|---|---|---|
TransformedDatastore | 这类数据存储使用自定义变换函数变换从基础数据存储中读取的批量数据。 |
|
CombinedDatastore | 从两个或多个基础数据存储中读取数据的数据存储。 |
|
| 自定义小批量数据存储 | 返回小批量数据的自定义数据存储。 | 使用其他数据存储不支持的格式的数据训练神经网络。 有关详细信息,请参阅Develop Custom Mini-Batch Datastore。 |
通过使用 transform 和 combine 函数,您可以使用其他内置数据存储来训练深度学习神经网络。这些函数可以将从数据存储中读取的数据转换为 trainNetwork 所需的表或元胞数组格式。有关详细信息,请参阅Datastores for Deep Learning。
数据存储必须以表或元胞数组的形式返回数据。自定义小批量数据存储必须输出表。数据存储输出的格式取决于神经网络架构。
| 数据存储输出 | 示例输出 |
|---|---|
具有 ( 前 输入的顺序由层图 | data = read(ds) data =
4×3 cell array
{24×1 double} {28×1 double} {[2]}
{24×1 double} {28×1 double} {[2]}
{24×1 double} {28×1 double} {[9]}
{24×1 double} {28×1 double} {[9]} |
对于图像、序列和特征预测变量输入,预测变量的格式必须分别与 images、sequences 或 features 参量描述中所述的格式匹配。同样,响应的格式必须与 images、sequences 或 features 参量描述中所述的与任务类型对应的格式匹配。
trainNetwork 函数支持最多具有一个序列输入层的神经网络。
有关如何训练具有多个输入的神经网络的示例,请参阅基于图像和特征数据训练网络。
提示
要将数值数组转换为数据存储,请使用
ArrayDatastore。当合并具有混合类型数据的神经网络中的层时,您可能需要在将数据传递给合并层(如串联层或加法层)之前重新格式化数据。要重新格式化数据,您可以使用展平层将空间维度展平为通道维度,或创建一个
FunctionLayer对象或自定义层来重新格式化和重构。此参量支持复数值预测变量。要使用
trainNetwork函数训练具有复数值预测变量的网络,输入层的SplitComplexInputs选项必须为1(true)。
响应。
当输入数据是数值数组或元胞数组时,请将响应指定为以下项之一。
由标签组成的分类向量
由数值响应组成的数值数组
由分类序列或由数值序列组成的元胞数组
当输入数据是表时,您可以选择指定表中的哪些列包含以下响应之一:
字符向量
字符向量元胞数组
字符串数组
当输入数据是数值数组或元胞数组时,响应的格式取决于任务的类型。
| 任务 | 格式 | |
|---|---|---|
| 分类 | 图像分类 | N×1 标签分类向量,其中 N 是观测值数目。 |
| 特征分类 | ||
| “序列到标签”分类 | ||
| “序列到序列”分类 | 由分类标签序列组成的 N×1 元胞数组,其中 N 是观测值数目。每个序列必须具有与对应预测变量序列相同的时间步数。 对于具有一个观测值的“序列到序列”分类任务, | |
| 回归 | 二维图像回归 |
|
| 三维图像回归 |
| |
| 特征回归 | N×R 矩阵,其中 N 是观测值数目,R 是响应数。 | |
| “序列到单个”回归 | N×R 矩阵,其中 N 是序列数,R 是响应数。 | |
| “序列到序列”回归 | 由数值序列组成的 N×1 元胞数组,其中 N 是序列数,序列由以下项之一给出:
对于具有一个观测值的“序列到序列”回归任务, | |
提示
将响应归一化通常有助于稳定和加速神经网络的回归训练。有关详细信息,请参阅针对回归训练卷积神经网络。
提示
响应不能包含 NaN。如果预测变量数据包含 NaN,则它们将通过训练进行传播。但在大多数情况下,训练将无法收敛。
神经网络层,指定为 Layer 数组或 LayerGraph 对象。
要创建一个所有层都按顺序连接的神经网络,您可以使用 Layer 数组作为输入参量。在这种情况下,返回的神经网络是一个 SeriesNetwork 对象。
有向无环图 (DAG) 神经网络具有复杂的结构,其中各层可以有多个输入和输出。要创建 DAG 神经网络,请将神经网络架构指定为 LayerGraph 对象,然后使用该层图作为 trainNetwork 的输入参量。
trainNetwork 函数支持最多具有一个序列输入层的神经网络。
有关内置层的列表,请参阅深度学习层列表。
训练选项,指定为由 trainingOptions 函数返回的 TrainingOptionsSGDM、TrainingOptionsRMSProp 或 TrainingOptionsADAM 对象。
输出参量
经过训练的神经网络,以 SeriesNetwork 对象或 DAGNetwork 对象形式返回。
如果您使用 Layer 数组来训练神经网络,则 net 是一个 SeriesNetwork 对象。如果您使用 LayerGraph 对象来训练神经网络,则 net 是一个 DAGNetwork 对象。
训练信息,以结构体形式返回,其中每个字段是一个标量,或一个数值向量,向量中的每个元素对应一次训练迭代。
对于分类任务,info 包含以下字段:
TrainingLoss- 损失函数值TrainingAccuracy- 训练准确度ValidationLoss- 损失函数值ValidationAccuracy- 验证准确度BaseLearnRate- 学习率FinalValidationLoss- 返回的神经网络的验证损失FinalValidationAccuracy- 返回的神经网络的验证准确度OutputNetworkIteration- 返回的神经网络的迭代序号
对于回归任务,info 包含以下字段:
TrainingLoss- 损失函数值TrainingRMSE- 训练 RMSE 值ValidationLoss- 损失函数值ValidationRMSE- 验证 RMSE 值BaseLearnRate- 学习率FinalValidationLoss- 返回的神经网络的验证损失FinalValidationRMSE- 返回的神经网络的验证 RMSEOutputNetworkIteration- 返回的神经网络的迭代序号
如果 options 指定了验证数据,则该结构体仅包含 ValidationLoss、ValidationAccuracy、ValidationRMSE、FinalValidationLoss、FinalValidationAccuracy 和 FinalValidationRMSE 字段。ValidationFrequency 训练选项确定软件在哪些迭代中计算验证度量。最终的验证度量是标量。该结构体的其他字段是行向量,其中每个元素对应一次训练迭代。对于软件不计算验证度量的迭代,结构体中的对应值是 NaN。
对于包含批量归一化层的神经网络,如果 BatchNormalizationStatistics 训练选项为 'population',则最终验证度量通常不同于在训练期间评估出的验证度量。这是因为最终神经网络中的批量归一化层执行的操作与训练过程中执行的操作不同。有关详细信息,请参阅 batchNormalizationLayer。
详细信息
Deep Learning Toolbox™ 使您能够在训练期间将神经网络保存为 .mat 文件。当您有大型神经网络或大型数据集并且训练需要很长时间时,这种定期保存特别有用。如果训练因某种原因中断,您可以从上次保存的检查点神经网络继续训练。如果希望 trainNetwork 函数保存检查点神经网络,则您必须使用 trainingOptions 的 CheckpointPath 选项指定路径名称。如果您指定的路径不存在,则 trainingOptions 会返回错误。
软件自动为检查点神经网络文件分配唯一名称。在示例名称 net_checkpoint__351__2018_04_12__18_09_52.mat 中,351 是迭代序号,2018_04_12 是日期,18_09_52 是软件保存神经网络的时间。您可以通过双击检查点神经网络文件或在命令行中使用 load 命令来加载该文件。例如:
load net_checkpoint__351__2018_04_12__18_09_52.mat
trainNetwork 的输入参量,继续进行训练。例如:trainNetwork(XTrain,TTrain,net.Layers,options)
当您使用 trainnet 或 trainNetwork 函数训练神经网络时,或当您对 DAGNetwork 和 SeriesNetwork 对象使用预测或验证函数时,软件会使用单精度浮点算术来执行这些计算。用于预测和验证的函数包括 predict、classify 和 activations。当您使用 CPU 和 GPU 来训练神经网络时,软件将使用单精度算术。
为了提供最优性能,在 MATLAB® 中使用 GPU 的深度学习不保证是确定性的。根据您的网络架构,在某些情况下,当使用 GPU 训练两个相同的网络或使用相同的网络和数据进行两次预测时,您可能会得到不同结果。
扩展功能
要并行运行计算,请将 ExecutionEnvironment 训练选项设置为 "multi-gpu" 或 "parallel"。
使用 trainingOptions 设置 ExecutionEnvironment 训练选项并向 trainNetwork 提供选项。如果未设置 ExecutionEnvironment,则 trainNetwork 在 GPU 上运行(如果可用)。
有关详细信息,请参阅Scale Up Deep Learning in Parallel, on GPUs, and in the Cloud。
为了防止出现内存不足错误,建议不要将大型训练数据集移到 GPU 上。更好的做法是,通过使用
trainingOptions将ExecutionEnvironment设置为"auto"或"gpu"并向trainNetwork提供选项,在 GPU 上训练您的神经网络。当输入数据为以下值时,
ExecutionEnvironment选项必须为"auto"或"gpu":gpuArray包含
gpuArray对象的元胞数组包含
gpuArray对象的表输出包含
gpuArray对象的元胞数组的数据存储输出包含
gpuArray对象的表的数据存储
有关详细信息,请参阅在 GPU 上运行 MATLAB 函数 (Parallel Computing Toolbox)。
版本历史记录
在 R2016a 中推出从 R2024a 开始,不推荐使用 trainNetwork 函数,请改用 trainnet 函数。
目前没有停止支持 trainNetwork 函数的计划。但是,推荐改用 trainnet 函数,该函数具有以下优势:
trainnet支持dlnetwork对象,这些对象支持更广泛的网络架构,您可以创建或从外部平台导入这些网络架构。trainnet使您能够轻松指定损失函数。您可以从内置损失函数中进行选择或指定自定义损失函数。trainnet输出dlnetwork对象,这是一种统一的数据类型,支持网络构建、预测、内置训练、可视化、压缩、验证和自定义训练循环。trainnet通常比trainNetwork快。
下表显示了 trainNetwork 函数的一些典型用法,以及如何更新您的代码以改用 trainnet 函数。
| 不推荐 | 推荐 |
|---|---|
net = trainNetwork(data,layers,options); | net = trainnet(data,layers,lossFcn,options); |
net = trainNetwork(X,T,layers,options); | net = trainnet(X,T,layers,lossFcn,options); |
不使用输出层,而是使用 lossFcn 指定损失函数。
从 R2022b 开始,如果您使用 trainNetwork 函数基于序列数据训练神经网络,并且 SequenceLength 选项设置为整数,则软件会将序列填充到每个小批量中最长序列的长度,然后再将这些序列拆分为具有指定序列长度的小批量。如果 SequenceLength 未均分小批量的序列长度,则最后拆分的小批量的长度短于 SequenceLength。此行为会阻止神经网络在仅包含填充值的时间步上进行训练。
在以前的版本中,软件会填充小批量序列,使其长度与大于或等于小批量长度的 SequenceLength 的最邻近倍数匹配,然后拆分数据。要重现此行为,请使用自定义训练循环,并在预处理小批量数据时实现此行为。
当使用 trainNetwork 函数训练神经网络时,训练会在损失为 NaN 时自动停止。通常,NaN 的损失值会将 NaN 值引入到神经网络可学习参数中,这又会导致神经网络无法训练或无法作出有效的预测。这一更改有助于在训练完成之前识别神经网络的问题。
在以前的版本中,神经网络在损失为 NaN 时会继续训练。
在以后的版本中,将不再支持在为 trainNetwork 函数指定序列数据时指定 MAT 文件路径表。
要使用无法放入内存的序列训练神经网络,请使用数据存储。您可以使用任何数据存储来读取数据,然后使用 transform 函数将数据存储输出变换为 trainNetwork 函数要求的格式。例如,您可以使用 FileDatastore 或 TabularTextDatastore 对象读取数据,然后使用 transform 函数变换输出。
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
选择网站
选择网站以获取翻译的可用内容,以及查看当地活动和优惠。根据您的位置,我们建议您选择:。
您也可以从以下列表中选择网站:
如何获得最佳网站性能
选择中国网站(中文或英文)以获得最佳网站性能。其他 MathWorks 国家/地区网站并未针对您所在位置的访问进行优化。
美洲
- América Latina (Español)
- Canada (English)
- United States (English)
欧洲
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)