主要内容

多输入和多输出网络

在 Deep Learning Toolbox™ 中,您可以定义具有多个输入(例如,在多个数据源和类型的数据上训练的网络)或多个输出(例如,同时预测分类响应和回归响应的网络)的网络架构。

多输入网络

当网络需要来自多个数据源或具有不同格式的数据时,应定义具有多个输入的网络。例如,需要从多个传感器捕获的具有不同分辨率的图像数据的网络。

要定义和训练具有多个输入的深度学习网络,请使用 dlnetwork 对象指定网络架构,并使用 trainnet 函数进行训练。

要对经过训练的具有多个输入的深度学习网络进行预测,请使用 minibatchpredict 函数。使用下列项之一指定多个输入:

  • combinedDatastore 对象

  • transformedDatastore 对象

  • 多个数值数组

有关说明如何训练具有图像输入和特征输入的网络的示例,请参阅基于图像和特征数据训练网络

多输出网络

对于需要不同格式的多个响应的任务,应定义具有多个输出的网络。例如,既需要分类输出又需要数值输出的任务。

要训练具有多个输出的深度学习网络,请结合使用 trainnet 函数与自定义损失函数。例如,要定义一个损失,该损失对应于预测标签和目标标签的交叉熵损失加上预测数值响应和目标数值响应的均方误差,请使用以下损失函数:

lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + mse(Y2,T2);

使用 trainnet 函数结合自定义损失函数训练神经网络。

net = trainnet(dsTrain,net,lossFcn,options);

要对经过训练的具有多个输出的深度学习网络进行预测,请使用 minibatchpredict 函数。

有关示例,请参阅训练具有多个输出的网络

对多输入和多输出网络使用数据存储

要训练具有多个输入层或多个输出的网络,请使用 combinetransform 函数创建一个数据存储,该数据存储输出一个包含 (numInputs + numOutputs) 列的元胞数组,其中 numInputs 是网络输入的数量,numOutputs 是网络输出的数量。前 numInputs 列指定每个输入的预测变量,后 numOutputs 列指定响应。神经网络的 InputNamesOutputNames 属性分别确定输入和输出的顺序。

对于使用 minibatchpredict 函数进行的推断,只要数据存储的读取函数返回与预测变量对应的列,就表示该数据存储是有效的。minibatchpredict 函数使用前 numInputs 列,而忽略后续列,其中 numInputs 是网络输入层的数量。

另请参阅

| | | | |

主题