主要内容

本页翻译不是最新的。点击此处可查看最新英文版本。

使用 dlnetwork 对象进行预测

此示例说明如何通过遍历小批量,使用 dlnetwork 对象进行预测。

对于大型数据集,或者在内存有限的硬件上进行预测时,使用 minibatchpredict 函数通过遍历小批量数据进行预测。

加载 dlnetwork 对象

将经过训练的 dlnetwork 对象和相应的类名称加载到工作区中。该神经网络有一个输入和两个输出。它将手写数字图像作为输入,并预测数字标签和旋转角度。

load dlnetDigits

加载数据进行预测

加载数字测试数据进行预测。

load DigitsDataTest

查看类名称。

classNames
classNames = 10×1 cell
    {'0'}
    {'1'}
    {'2'}
    {'3'}
    {'4'}
    {'5'}
    {'6'}
    {'7'}
    {'8'}
    {'9'}

查看其中一些图像以及相应的标签和旋转角度。

numObservations = size(XTest,4);
numPlots = 9;
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = labelsTest(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i)))
end

Figure contains 9 axes objects. Hidden axes object 1 with title Label: 8 Angle: 5 contains an object of type image. Hidden axes object 2 with title Label: 9 Angle: -45 contains an object of type image. Hidden axes object 3 with title Label: 1 Angle: -11 contains an object of type image. Hidden axes object 4 with title Label: 9 Angle: -40 contains an object of type image. Hidden axes object 5 with title Label: 6 Angle: -42 contains an object of type image. Hidden axes object 6 with title Label: 0 Angle: -18 contains an object of type image. Hidden axes object 7 with title Label: 2 Angle: -9 contains an object of type image. Hidden axes object 8 with title Label: 5 Angle: -17 contains an object of type image. Hidden axes object 9 with title Label: 9 Angle: -27 contains an object of type image.

进行预测

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分类分数转换为标签。默认情况下,minibatchpredict 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU 计算要求 (Parallel Computing Toolbox)。否则,该函数使用 CPU。要手动选择执行环境,请使用 minibatchpredict 函数的 ExecutionEnvironment 参量。

[scoresTest,Y2Test] = minibatchpredict(net,XTest);
Y1Test = scores2label(scoresTest,classNames);

可视化一些预测值。

idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = Y1Test(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i)))
end

Figure contains 9 axes objects. Hidden axes object 1 with title Label: 9 Angle: 20.3954 contains an object of type image. Hidden axes object 2 with title Label: 1 Angle: 3.7015 contains an object of type image. Hidden axes object 3 with title Label: 9 Angle: 23.5494 contains an object of type image. Hidden axes object 4 with title Label: 9 Angle: -36.4954 contains an object of type image. Hidden axes object 5 with title Label: 4 Angle: 16.428 contains an object of type image. Hidden axes object 6 with title Label: 7 Angle: 3.0644 contains an object of type image. Hidden axes object 7 with title Label: 1 Angle: 33.1356 contains an object of type image. Hidden axes object 8 with title Label: 4 Angle: 30.7531 contains an object of type image. Hidden axes object 9 with title Label: 9 Angle: 0.55887 contains an object of type image.

另请参阅

| | | |

主题