predict
计算深度学习网络输出以进行推断
说明
一些深度学习层在训练期间的行为和其在推断(预测)期间的行为是不同的。例如,在训练期间,丢弃层会随机将输入元素设置为零以帮助防止过拟合,但在推断期间,丢弃层不会更改输入。
示例
将预训练的 SqueezeNet 神经网络加载到工作区中。
[net,classNames] = imagePretrainedNetwork;
从 PNG 文件中读取图像并对其进行分类。要对图像进行分类,请先将其数据类型转换为 single。
im = imread("peppers.png");
figure
imshow(im)
X = single(im); scores = predict(net,X); [label,score] = scores2label(scores,classNames);
显示具有预测标签和对应分数的图像。
figure imshow(im) title(string(label) + " (Score: " + score + ")")

输入参数
神经网络,指定为下列值之一:
dlnetwork对象 - 神经网络。TaylorPrunableNetwork对象 - 用于自定义剪枝循环的神经网络。
要对深度神经网络进行剪枝,您需要 Deep Learning Toolbox™ Model Compression Library 支持包。该支持包是一项免费附加功能,您可以使用附加功能资源管理器进行下载。或者,请参阅 Deep Learning Toolbox Model Compression Library。
输入数据,每个输入数据指定为下列值之一:
格式化的
dlarray对象未格式化的
dlarray对象 (自 R2023b 起)数值数组 (自 R2023b 起)
提示
神经网络需要输入数据具有特定的布局。例如,向量序列分类网络通常需要向量序列表示形式为 t×c 数组,其中 t 和 c 分别是序列的时间步数和通道数。神经网络通常具有一个输入层,用于指定数据的预期布局。
大多数数据存储和函数会以网络需要的布局输出数据。如果您的数据布局与网络预期的布局不同,则可通过使用 InputDataFormats 选项或将输入数据指定为格式化的 dlarray 对象来表明您的数据具有不同的布局。通常,调整 InputDataFormats 训练选项比预处理输入数据更为便捷。
有关详细信息,请参阅Deep Learning Data Formats。
要创建接收未格式化数据的神经网络,请使用 inputLayer 对象并且不指定格式。要将未格式化的数据直接输入网络,请不要指定 InputDataFormats 参量。 (自 R2025a 起)
在 R2025a 之前的版本中: 对于没有输入层的神经网络,您必须使用 InputDataFormats 参量指定一种格式。
名称-值参数
将可选参量对组指定为 Name1=Value1,...,NameN=ValueN,其中 Name 是参量名称,Value 是对应的值。名称-值参量必须出现在其他参量之后,但对各个参量对组的顺序没有要求。
如果使用的是 R2021a 之前的版本,请使用逗号分隔每个名称和值,并用引号将 Name 引起来。
示例: Y = predict(net,X,InputDataFormats="CBT") 使用格式为 "CBT"(通道、批量、时间)的序列数据进行预测。
神经网络输出,指定为层名称或层输出路径的字符串数组或字符向量元胞数组。使用以下形式之一指定输出:
"layerName",其中layerName是具有单个输出的层的名称。"layerName/outputName",其中layerName是层的名称,outputName是层输出的名称。对于具有多个输出的层,请使用此选项。
要使用 networkLayer 对象内部的层的输出,请先使用 expandLayers 函数展开嵌套网络。有关详细信息,请参阅网络层提示。
如果您未指定要从中提取输出的层,则默认情况下,软件会使用 net.Outputs 指定的输出。
自 R2023b 起
输入数据维度的描述,指定为字符串数组、字符向量或字符向量元胞数组。
如果 InputDataFormats 为 "auto",则软件会使用网络输入所需要的格式。否则,软件会为对应的网络输入使用指定的格式。
数据格式是一个字符串,其中每个字符描述对应数据的维度的类型。
这些字符是:
"S"- 空间"C"- 通道"B"- 批量"T"- 时间"U"- 未指定
例如,假设有一个表示一批序列的数组,其中第一个、第二个和第三个维度分别对应于通道、观测值和时间步。您可以将该数据描述为具有格式 "CBT"(通道、批量、时间)。
您可以指定多个标注为 "S" 或 "U" 的维度。每个 "C"、"B" 和 "T" 标签最多可以使用一次。该软件忽略第二个维度后的单一尾部 "U" 维度。
对于具有多个输入的神经网络 net,请指定一个输入数据格式数组,其中 InputDataFormats(i) 对应于输入 net.InputNames(i)。
有关详细信息,请参阅Deep Learning Data Formats。
要创建接收未格式化数据的神经网络,请使用 inputLayer 对象并且不指定格式。要将未格式化的数据直接输入网络,请不要指定 InputDataFormats 参量。 (自 R2025a 起)
在 R2025a 之前的版本中: 对于没有输入层的神经网络,您必须使用 InputDataFormats 参量指定一种格式。
数据类型: char | string | cell
自 R2023b 起
输出数据维度的描述,指定为下列值之一:
"auto"- 如果输出数据与输入数据的维数相同,则predict函数使用InputDataFormats指定的格式。如果输出数据与输入数据的维数不同,则predict函数会自动对输出数据的维度进行置换,使其与网络输入层或InputDataFormats值保持一致。字符串、字符向量或字符向量元胞数组 -
predict函数使用指定的数据格式。
数据格式是一个字符串,其中每个字符描述对应数据的维度的类型。
这些字符是:
"S"- 空间"C"- 通道"B"- 批量"T"- 时间"U"- 未指定
例如,假设有一个表示一批序列的数组,其中第一个、第二个和第三个维度分别对应于通道、观测值和时间步。您可以将该数据描述为具有格式 "CBT"(通道、批量、时间)。
您可以指定多个标注为 "S" 或 "U" 的维度。每个 "C"、"B" 和 "T" 标签最多可以使用一次。该软件忽略第二个维度后的单一尾部 "U" 维度。
有关详细信息,请参阅Deep Learning Data Formats。
数据类型: char | string | cell
性能优化,指定为下列值之一:
"auto"- 自动应用适用于输入网络和硬件资源的多项优化。"mex"- 编译并执行 MEX 函数。此选项仅在使用 GPU 时可用。您必须将输入数据或网络可学习参数存储为gpuArray对象。使用 GPU 需要 Parallel Computing Toolbox™ 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU 计算要求 (Parallel Computing Toolbox)。如果 Parallel Computing Toolbox 或合适的 GPU 不可用,则软件会返回错误。"none"- 禁用所有加速。
当您使用 "auto" 或 "mex" 选项时,软件可以提供性能优势,但会增加初始运行时间。后续调用该函数通常会更快。当您使用不同的输入数据多次调用该函数时,请使用性能优化。
当 Acceleration 为 "mex" 时,软件会根据您在函数调用中指定的模型和参数生成并执行 MEX 函数。单个模型可以同时关联多个 MEX 函数。清除模型变量也会清除与该模型关联的所有 MEX 函数。
当 Acceleration 为 "auto" 时,软件不会生成 MEX 函数。
"mex" 选项仅在使用 GPU 时可用。您必须安装 C/C++ 编译器和 GPU Coder™ Interface for Deep Learning 支持包。使用 MATLAB® 中的附加功能资源管理器安装该支持包。有关设置说明,请参阅 Set Up Compiler (GPU Coder)。GPU Coder 不是必需的。
"mex" 选项具有以下限制:
不支持
state输出参量。仅支持
single精度。输入数据或网络可学习参数的基础类型必须为single。不支持输入未连接到输入层的网络。
并非所有层都受支持。要查看支持的层的列表,请参阅Supported Layers (GPU Coder)。
当您使用
"mex"选项时,MATLAB Compiler™ 不支持部署网络。
对于量化网络,"mex" 选项需要具有 6.1、6.3 或更高计算能力的支持 CUDA® 的 NVIDIA® GPU。
输出参量
具有多个输出的网络的输出数据,以下列值之一的形式返回:
格式化的
dlarray对象未格式化的
dlarray对象 (自 R2023b 起)数值数组 (自 R2023b 起)
数据类型与输入数据的数据类型匹配。
输出 Y1, …, YN 的顺序与 Outputs 参量指定的输出顺序匹配。
对于分类神经网络,输出的元素对应于每个类的分数。分数的顺序与训练数据中类别的顺序匹配。例如,如果您使用分类标签 TTrain 训练神经网络,则分数的顺序与 categories(TTrain) 给出的类别顺序匹配。
算法
为了提供最优性能,在 MATLAB 中使用 GPU 的深度学习不保证是确定性的。根据您的网络架构,在某些情况下,当使用 GPU 训练两个相同的网络或使用相同的网络和数据进行两次预测时,您可能会得到不同结果。如果在使用 GPU 执行深度学习运算时需要确定性,请使用 deep.gpu.deterministicAlgorithms 函数 (自 R2024b 起)。
如果您使用 rng 函数设置相同的随机数生成器和种子,则使用 CPU 进行的预测是可重现的。
扩展功能
用法说明和限制:
C++ 代码生成支持以下语法:
Y = predict(net,X)Y = predict(net,X1,...,XM)[Y1,...,YN] = predict(__)[Y1,...,YK] = predict(__,'Outputs',layerNames)
对于语法
[__,state] = predict(__),您可以生成不依赖于任何第三方库的通用 C/C++ 代码。代码生成支持调整
State属性的Value变量。代码生成不支持修改State属性的Layer和Parameter变量。代码生成支持以下用于
State属性的函数:对于 Simulink 仿真,代码生成不支持在 MATLAB Function 模块中提取和更新
dlnetwork的State。请改用 Stateful Predict 或 Stateful Classify 模块。输入数据
X只能在时间 ("T") 维度上具有可变大小。输入数据X的其他数据维度不能具有可变大小。大小必须在代码生成时固定。代码生成不支持向
dlnetwork对象的predict方法传递复数值输入。predict方法的dlarray输入必须是single数据类型。
用法说明和限制:
GPU 代码生成支持以下语法:
Y = predict(net,X)Y = predict(net,X1,...,XM)[Y1,...,YN] = predict(__)[Y1,...,YK] = predict(__,'Outputs',layerNames)
对于语法
[__,state] = predict(__),您可以生成独立于深度学习库的纯 CUDA 代码。代码生成支持调整
State属性的Value变量。代码生成不支持修改State属性的Layer和Parameter变量。代码生成支持以下用于
State属性的函数:对于 Simulink 仿真,代码生成不支持在 MATLAB Function 模块中提取和更新
dlnetwork的State。请改用 Stateful Predict 或 Stateful Classify 模块。输入数据
X只能在时间 ("T") 维度上具有可变大小。输入数据X的其他数据维度不能具有可变大小。大小必须在代码生成时固定。针对 TensorRT 库的代码生成不支持使用
[Y1,...,YK] = predict(__,'Outputs',layerNames)语法将输入层标记为输出。代码生成不支持向
dlnetwork对象的predict方法传递复数值输入。predict方法的dlarray输入必须是single数据类型。
predict 函数支持 GPU 数组输入,但有以下用法说明和限制:
如果满足以下至少一个条件,此函数将在 GPU 上运行:
net.Learnables.Value中的任何网络可学习参数值是基础数据类型为gpuArray的dlarray对象。输入参量
X是基础数据类型为gpuArray的dlarray。
有关详细信息,请参阅在 GPU 上运行 MATLAB 函数 (Parallel Computing Toolbox)。
版本历史记录
在 R2019b 中推出如果您将未格式化数据指定为神经网络的输入且未指定 InputDataFormats 参量,则该函数会直接将未格式化数据传递给网络。
要创建接收未格式化数据的神经网络,请使用 inputLayer 对象并且不指定格式。
使用数值数组和未格式化的 dlarray 对象进行预测。
分别使用 InputDataFormats 选项和 OutputDataFormats 选项指定输入数据格式和输出数据格式。
对于 dlnetwork 对象,predict 函数返回的 state 输出参量是一个包含网络中每个层的状态参数名称和值的表。
从 R2021a 开始,状态值是 dlarray 对象。进行此更改后可在使用 AcceleratedFunction 对象时提供更好的支持。为了加速具有频繁更改的输入值(例如包含网络状态的输入)的深度学习函数,必须将频繁更改的值指定为 dlarray 对象。
在以前的版本中,状态值是数值数组。
在大多数情况下,您不需要更新代码。如果您的代码要求状态值是数值数组,要重现以前的行为,请使用 extractdata 函数和 dlupdate 函数手动从状态值中提取数据。
state = dlupdate(@extractdata,net.State);
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
选择网站
选择网站以获取翻译的可用内容,以及查看当地活动和优惠。根据您的位置,我们建议您选择:。
您也可以从以下列表中选择网站:
如何获得最佳网站性能
选择中国网站(中文或英文)以获得最佳网站性能。其他 MathWorks 国家/地区网站并未针对您所在位置的访问进行优化。
美洲
- América Latina (Español)
- Canada (English)
- United States (English)
欧洲
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)