Make Predictions Using dlnetwork Object

This example shows how to make predictions using a dlnetwork object by looping over mini-batches.

For large data sets, or when predicting on hardware with limited memory, make predictions by looping over mini-batches of the data using the minibatchpredict function.
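
The idea of looping over mini-batches can be sketched without a network. The following is only an illustration of the batching pattern, not the internal implementation of minibatchpredict. The data X, the function handle predictFcn, and the mini-batch size are hypothetical stand-ins chosen so that the snippet runs on its own.

```matlab
% Hypothetical stand-in data: 300 grayscale images of size 28-by-28.
X = rand(28,28,1,300,"single");

% Hypothetical stand-in for a model: returns one value per observation
% (the mean pixel intensity) as a column vector.
predictFcn = @(batch) reshape(mean(batch,[1 2 3]),[],1);

miniBatchSize = 128;                 % illustrative value
numObservations = size(X,4);
Y = zeros(numObservations,1,"single");

% Process the data one mini-batch at a time so that only
% miniBatchSize observations are passed to the model at once.
for startIdx = 1:miniBatchSize:numObservations
    endIdx = min(startIdx + miniBatchSize - 1,numObservations);
    batchIdx = startIdx:endIdx;
    Y(batchIdx) = predictFcn(X(:,:,:,batchIdx));
end
```

The last mini-batch is smaller than the others when the number of observations is not a multiple of the mini-batch size, which the min call handles.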

Load dlnetwork Object

Load a trained dlnetwork object and the corresponding class names into the workspace. The neural network has one input and two outputs. It takes images of handwritten digits as input, and predicts the digit label and angle of rotation.

load dlnetDigits

Load Data for Prediction

Load the digits test data for prediction.

load DigitsDataTest

View the class names.

classNames
classNames = 10x1 cell
    {'0'}
    {'1'}
    {'2'}
    {'3'}
    {'4'}
    {'5'}
    {'6'}
    {'7'}
    {'8'}
    {'9'}

View some of the images and the corresponding labels and angles of rotation.

numObservations = size(XTest,4);
numPlots = 9;
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = labelsTest(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i)))
end

[Figure: nine randomly selected test images, each titled with its true label and angle of rotation.]

Make Predictions

Make predictions using the minibatchpredict function, and convert the classification scores to labels using the scores2label function. By default, the minibatchpredict function uses a GPU if one is available and uses the CPU otherwise. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To select the execution environment manually, use the ExecutionEnvironment argument of the minibatchpredict function.

[scoresTest,Y2Test] = minibatchpredict(net,XTest);
Y1Test = scores2label(scoresTest,classNames);
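
To control the hardware and memory use explicitly, the same call can take name-value arguments. The following sketch assumes that net, XTest, and classNames are already in the workspace from the previous steps. The mini-batch size of 64 is an arbitrary illustrative value, not a recommendation.

```matlab
% Assumes net, XTest, and classNames are loaded as in the steps above.
miniBatchSize = 64;   % illustrative value; smaller batches use less memory

% Run prediction on the CPU regardless of whether a GPU is available.
[scoresTest,Y2Test] = minibatchpredict(net,XTest, ...
    MiniBatchSize=miniBatchSize, ...
    ExecutionEnvironment="cpu");

% Convert the classification scores to categorical labels.
Y1Test = scores2label(scoresTest,classNames);
```

Setting ExecutionEnvironment to "gpu" instead requests GPU execution. Changing the mini-batch size affects memory use and speed, and should not change which observations are processed.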

Visualize some of the predictions.

idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = Y1Test(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i)))
end

[Figure: nine randomly selected test images, each titled with its predicted label and predicted angle of rotation.]
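
Because the test data includes labels and angles, the predictions can also be compared with them numerically. This is a minimal sketch, assuming that labelsTest and anglesTest from DigitsDataTest hold the true values for XTest. The variable names accuracy, angleErrors, and rmse are introduced here only for illustration.

```matlab
% Assumes Y1Test, Y2Test, labelsTest, and anglesTest exist in the workspace.

% Fraction of test images whose predicted digit matches the true label.
accuracy = mean(Y1Test(:) == labelsTest(:));

% Root-mean-squared error of the predicted angles,
% in the same units as anglesTest.
angleErrors = double(Y2Test(:)) - double(anglesTest(:));
rmse = sqrt(mean(angleErrors.^2));
```

The (:) indexing reshapes each array to a column so that the comparison and subtraction are elementwise even if the predictions and targets have different orientations.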
