Make Predictions Using dlnetwork Object
This example shows how to make predictions using a dlnetwork object by looping over mini-batches.
For large data sets, or when predicting on hardware with limited memory, make predictions by looping over mini-batches of the data using the minibatchpredict function.
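The idea of looping over mini-batches can be sketched without a network at all. The following is a minimal illustration, not the internals of minibatchpredict: the random data X and the function handle predictFcn are hypothetical stand-ins for a data set and a trained model, and the mini-batch size of 128 is an arbitrary choice for the sketch.

```matlab
% Hypothetical stand-in data: 1000 single-channel 28-by-28 "images".
X = rand(28,28,1,1000,"single");

% Hypothetical stand-in for a model: mean pixel intensity per observation,
% returned as a column vector with one element per observation in the batch.
predictFcn = @(XBatch) reshape(mean(XBatch,[1 2 3]),[],1);

miniBatchSize = 128;               % arbitrary size chosen for this sketch
numObservations = size(X,4);
Y = zeros(numObservations,1,"single");

% Process the data one mini-batch at a time so that only one batch
% needs to pass through the model at once.
for startIdx = 1:miniBatchSize:numObservations
    endIdx = min(startIdx + miniBatchSize - 1,numObservations);
    idxBatch = startIdx:endIdx;
    Y(idxBatch) = predictFcn(X(:,:,:,idxBatch));
end
```

The final mini-batch can be smaller than the others, which is why the end index is clipped to the number of observations.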
Load dlnetwork Object
Load a trained dlnetwork object and the corresponding class names into the workspace. The neural network has one input and two outputs. It takes images of handwritten digits as input, and predicts the digit label and angle of rotation.
load dlnetDigits
Load Data for Prediction
Load the digits test data for prediction.
load DigitsDataTest
View the class names.
classNames
classNames = 10x1 cell
{'0'}
{'1'}
{'2'}
{'3'}
{'4'}
{'5'}
{'6'}
{'7'}
{'8'}
{'9'}
View some of the images and the corresponding labels and angles of rotation.
numObservations = size(XTest,4);
numPlots = 9;
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = labelsTest(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i)))
end
Make Predictions
Make predictions using the minibatchpredict function, and convert the classification scores to labels using the scores2label function. By default, the minibatchpredict function uses a GPU if one is available. Otherwise, the function uses the CPU. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To select the execution environment manually, use the ExecutionEnvironment argument of the minibatchpredict function.
[scoresTest,Y2Test] = minibatchpredict(net,XTest);
Y1Test = scores2label(scoresTest,classNames);
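As an illustration of selecting the execution environment manually, the following sketch runs the same prediction on the CPU with a smaller mini-batch size. It assumes that net, XTest, and classNames are already in the workspace from the previous steps. The variable names scoresCPU, anglesCPU, and labelsCPU and the mini-batch size of 64 are choices made for this sketch only.

```matlab
% Assumes net, XTest, and classNames are in the workspace (loaded earlier).
% Force prediction on the CPU and use a smaller mini-batch size, which
% can help when memory is limited. The value 64 is an arbitrary choice.
[scoresCPU,anglesCPU] = minibatchpredict(net,XTest, ...
    MiniBatchSize=64, ...
    ExecutionEnvironment="cpu");

% Convert the classification scores to labels, as before.
labelsCPU = scores2label(scoresCPU,classNames);
```

A smaller mini-batch size lowers peak memory use at the cost of more iterations, so the best value depends on the hardware.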
Visualize some of the predictions.
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = Y1Test(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i)))
end
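To complement the visual check, one possible way to summarize the predictions numerically is to compare them with the labelsTest and anglesTest variables loaded earlier. This is a sketch rather than part of the original workflow. It assumes that Y1Test and labelsTest are categorical arrays with one element per observation and that Y2Test and anglesTest are numeric with one angle per observation. The names accuracy and rmse are introduced here for illustration.

```matlab
% Assumes Y1Test, Y2Test, labelsTest, and anglesTest are in the workspace.
% Reshape everything to column vectors so that the comparison does not
% depend on whether the arrays are stored as rows or columns.

% Fraction of observations where the predicted digit matches the true label.
accuracy = mean(Y1Test(:) == labelsTest(:))

% Root-mean-squared error of the predicted angles of rotation.
angleErrors = double(Y2Test(:)) - double(anglesTest(:));
rmse = sqrt(mean(angleErrors.^2))
```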
See Also
dlarray | dlnetwork | predict | minibatchqueue | onehotdecode
Related Topics
- Train Generative Adversarial Network (GAN)
- Train Network Using Custom Training Loop
- Define Model Loss Function for Custom Training Loop
- Update Batch Normalization Statistics in Custom Training Loop
- Define Custom Training Loops, Loss Functions, and Networks
- Make Predictions Using Model Function
- Specify Training Options in Custom Training Loop
- List of Deep Learning Layers
- Deep Learning Tips and Tricks
- Automatic Differentiation Background