使用 dlnetwork
对象进行预测
此示例说明如何通过遍历小批量,使用 dlnetwork
对象进行预测。
对于大型数据集,或者在内存有限的硬件上进行预测时,使用 minibatchpredict
函数通过遍历小批量数据进行预测。
加载 dlnetwork
对象
将经过训练的 dlnetwork
对象和相应的类名称加载到工作区中。该神经网络有一个输入和两个输出。它将手写数字图像作为输入,并预测数字标签和旋转角度。
load dlnetDigits
加载数据进行预测
加载数字测试数据进行预测。
load DigitsDataTest
查看类名称。
classNames
classNames = 10×1 cell
{'0'}
{'1'}
{'2'}
{'3'}
{'4'}
{'5'}
{'6'}
{'7'}
{'8'}
{'9'}
查看其中一些图像以及相应的标签和旋转角度。
numObservations = size(XTest,4); numPlots = 9; idx = randperm(numObservations,numPlots); figure for i = 1:numPlots nexttile(i) I = XTest(:,:,:,idx(i)); label = labelsTest(idx(i)); imshow(I) title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i))) end
进行预测
使用 minibatchpredict
函数进行预测,并使用 scores2label
函数将分类分数转换为标签。默认情况下,minibatchpredict
函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU 计算要求 (Parallel Computing Toolbox)。否则,该函数使用 CPU。要手动选择执行环境,请使用 minibatchpredict
函数的 ExecutionEnvironment
参量。
[scoresTest,Y2Test] = minibatchpredict(net,XTest); Y1Test = scores2label(scoresTest,classNames);
可视化一些预测值。
idx = randperm(numObservations,numPlots); figure for i = 1:numPlots nexttile(i) I = XTest(:,:,:,idx(i)); label = Y1Test(idx(i)); imshow(I) title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i))) end
另请参阅
dlarray
| dlnetwork
| predict
| minibatchqueue
| onehotdecode
主题
- 训练生成对抗网络 (GAN)
- 使用自定义训练循环训练网络
- Define Model Loss Function for Custom Training Loop
- Update Batch Normalization Statistics in Custom Training Loop
- 定义自定义训练循环、损失函数和网络
- Make Predictions Using Model Function
- Specify Training Options in Custom Training Loop
- 深度学习层列表
- Deep Learning Tips and Tricks
- Automatic Differentiation Background