Main Content

Predict

使用经过训练的深度学习神经网络预测响应

自 R2020b 起

  • Predict block

库:
Deep Learning Toolbox / Deep Neural Networks

描述

Predict 模块使用通过模块参数指定的已训练网络来预测输入端数据的响应。此模块允许从 MAT 文件或使用 MATLAB® 函数将预训练网络加载到 Simulink® 模型中。

注意

在 Simulink 中使用 Predict 模块进行预测。要使用 MATLAB 代码以编程方式进行预测,请使用 classifypredict 函数。

端口

输入

全部展开

Predict 模块的输入端口接受所加载网络的输入层名称。例如,如果您为 MATLAB function 指定 googlenet,则 Predict 模块的输入端口将标注为数据。根据加载的网络,Predict 模块的输入可以是图像、序列或时间序列数据。

输入的格式取决于数据的类型。

数据预测变量的格式
二维图像h×w×c×N 数值数组,其中 h、w 和 c 分别是图像的高度、宽度和通道数,N 是图像的数量。
向量序列c×s 矩阵,其中 c 是序列的特征数,s 是序列长度。
二维图像序列h×w×c×s 数组,其中 h、w 和 c 分别对应于图像的高度、宽度和通道数,s 是序列长度。
特征numFeatures 数值数组,其中 N 是观测值数目,numFeatures 是输入数据的特征数。

如果数组包含 NaN,则它们会通过网络传播。

输出

全部展开

Predict 模块的输出端口接受所加载网络的输出层名称。例如,如果您为 MATLAB function 指定 googlenet,则 Predict 模块的输出端口标注为输出。基于加载的网络,Predict 模块的输出可以表示预测的分数或响应。

预测的分数或响应,以 N×K 数组形式返回,其中 N 是观测值数目,K 是类数。

如果为一个网络层启用 Activations,则 Predict 模块会使用所选网络层的名称创建一个新输出端口。此端口输出来自所选网络层的激活值。

网络层的激活值以数值数组形式返回。输出格式取决于输入数据的类型和层输出的类型。

对于二维图像输出,activations 是一个 h×w×c×n 数组,其中 h、w 和 c 分别是所选层输出的高度、宽度和通道数,n 是图像的数量。

对于包含向量数据的单个时间步,激活值是一个 c×n 矩阵,其中 n 是序列的数量,c 是序列中特征的数量。

对于包含二维图像数据的单个时间步,activations 是一个 h×w×c×n 数组,其中 n 是序列数,h、w 和 c 分别是图像的高度、宽度和通道数。

参数

全部展开

指定经过训练的网络的源。选择下列项之一:

  • 从 MAT 文件创建网络 - 从包含 SeriesNetworkDAGNetworkdlnetwork 对象的 MAT 文件中导入经过训练的网络。

  • 从 MATLAB 函数创建网络 - 从 MATLAB 函数导入预训练网络。例如,通过使用 googlenet 函数。

编程用法

模块参数:Network
类型:字符向量、字符串
值:'Network from MAT-file' | 'Network from MATLAB function'
默认值: 'Network from MAT-file'

此参数指定包含要加载的经过训练的深度学习网络的 MAT 文件名称。如果该文件不在 MATLAB 路径中,请使用浏览按钮找到该文件。

依存关系

要启用此参数,请将网络参数设置为从 MAT 文件创建网络

编程用法

模块参数:NetworkFilePath
类型:字符向量、字符串
值:MAT 文件路径或名称
默认值: 'untitled.mat'

此参数指定预训练深度学习网络的 MATLAB 函数的名称。例如,使用 googlenet 函数导入预训练的 GoogLeNet 模型。

依存关系

要启用此参数,请将网络参数设置为从 MATLAB 函数创建网络

编程用法

模块参数:NetworkFunction
类型:字符向量、字符串
值:MATLAB 函数名称
默认值: 'squeezenet'

用于预测的小批量的大小,指定为正整数。小批量大小越大,需要的内存越多,但预测速度可能更快。

编程用法

模块参数:MiniBatchSize
类型:字符向量、字符串
值:正整数
默认值: '128'

启用返回预测的分数或响应的输出端口。

编程用法

模块参数:Predictions
类型:字符向量、字符串
值:'off' | 'on'
默认值: 'on'

此参数指定经过训练的 dlnetwork 要求的输入数据格式。

数据格式是字符串,其中每个字符描述数据的对应维度的类型。例如,对于包含一批序列的数组,其中第一个、第二个和第三个维度分别对应于通道、观测值和时间步,您可以指定其格式为 "CBT"。有关详细信息,请参阅Deep Learning Data Formats

依存关系

要启用此参数,请将网络参数设置为从 MAT 文件创建网络,以便从 MAT 文件导入经过训练的 dlnetwork 对象。

编程用法

模块参数:InputDataFormats
类型:字符向量、字符串
值:对于具有一个或多个输入的网络,请使用 {'inputlayerName1', 'SSC'; 'inputlayerName2', 'SSCB'; ...}' 形式的字符向量。对于没有输入层但有多个输入端口的网络,请使用 '{'inputportName1/inport1, 'SSC'; 'inputportName2/inport2, 'SSCB'; ...}' 形式的字符向量。
默认值: ''

使用激活列表选择要从中提取特征的层。所选层显示为 Predict 模块的输出端口。

编程用法

模块参数:Activations
类型:字符向量、字符串
值:'{'layerName1',layerName2',...}' 形式的字符向量
默认值: ''

扩展功能

版本历史记录

在 R2020b 中推出