Train neural network for deep learning
Use trainNetwork to train a convolutional neural network (ConvNet, CNN), a long short-term memory (LSTM) network, or a bidirectional LSTM (BiLSTM) network for deep learning classification and regression problems. You can train a network on either a CPU or a GPU. For image classification and image regression, you can train using multiple GPUs or in parallel. Using GPU, multi-GPU, and parallel options requires Parallel Computing Toolbox™. To use a GPU for deep learning, you must also have a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher. Specify training options, including options for the execution environment, by using trainingOptions.
net = trainNetwork(sequences,Y,layers,options) trains a network for sequence classification and regression problems (for example, an LSTM or BiLSTM network), where sequences contains sequence or time series predictors and Y contains the responses. For classification problems, Y is a categorical vector or a cell array of categorical sequences. For regression problems, Y is a matrix of targets or a cell array of numeric sequences.
net = trainNetwork(tbl,responseName,layers,options) trains a network for classification and regression problems. The predictors must be in the first column of tbl. The responseName argument specifies the response variables in tbl.
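As a sketch of the table layout this syntax expects, the following builds a two-column table for a sequence-to-one regression problem: MAT-file paths as predictors in the first column and a numeric response in the second. The file names and the variable names Sequence and Angle are hypothetical, storing the paths as a cell array of character vectors is an assumption, and the trainNetwork call is left commented out because the files do not exist.

```matlab
% Hypothetical MAT-file paths: each file would hold one c-by-s sequence.
seqFiles = {'seq001.mat'; 'seq002.mat'; 'seq003.mat'};
% Hypothetical numeric responses, one per observation (row).
angle = [12.5; -3.0; 40.1];

% Predictors in the first column, responses in the remaining column.
tbl = table(seqFiles, angle, 'VariableNames', {'Sequence','Angle'});

% Name the response variable explicitly (not run here; files are placeholders):
% net = trainNetwork(tbl,'Angle',layers,options);
```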
Load the data as an ImageDatastore object.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
The datastore contains 10,000 synthetic images of digits from 0 to 9. The images are generated by applying random transformations to digit images created with different fonts. Each digit image is 28-by-28 pixels. The datastore contains an equal number of images per category.
Display some of the images in the datastore.
figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
end
Divide the datastore so that each category in the training set has 750 images and the testing set has the remaining images from each label.
numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');
splitEachLabel splits the image files in imds into two new datastores, imdsTrain and imdsTest.
Define the convolutional neural network architecture.
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
Set the options to the default settings for stochastic gradient descent with momentum. Set the maximum number of epochs to 20, and start the training with an initial learning rate of 0.0001.
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
Train the network.
net = trainNetwork(imdsTrain,layers,options);
Run the trained network on the test set, which was not used to train the network, and predict the image labels (digits).
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;
Calculate the accuracy. The accuracy is the ratio of the number of true labels in the test data matching the classifications from classify to the number of images in the test data.
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9404
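A single overall accuracy can hide classes that the network handles poorly. The following minimal sketch computes per-class accuracy with base MATLAB only; the short label vectors yTrueDemo and yPredDemo are hypothetical stand-ins for YTest and YPred.

```matlab
% Hypothetical true and predicted labels (stand-ins for YTest and YPred).
yTrueDemo = categorical(["0";"0";"1";"1";"2"]);
yPredDemo = categorical(["0";"1";"1";"1";"2"]);

classNames  = categories(yTrueDemo);          % {'0';'1';'2'}
perClassAcc = zeros(numel(classNames), 1);
for k = 1:numel(classNames)
    isClass = (yTrueDemo == classNames{k});   % observations whose true label is class k
    perClassAcc(k) = mean(yPredDemo(isClass) == yTrueDemo(isClass));
end

% Overall accuracy, computed the same way as above.
overallAcc = sum(yPredDemo == yTrueDemo) / numel(yTrueDemo);
```

Here perClassAcc is [0.5; 1; 1] while overallAcc is 0.8, which shows how one weak class is averaged away in the overall figure.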
Train a convolutional neural network using augmented image data. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.
Load the sample data, which consists of synthetic images of handwritten digits.
[XTrain,YTrain] = digitTrain4DArrayData;
digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array, where:
28 is the height and width of the images.
1 is the number of channels.
5000 is the number of synthetic images of handwritten digits.
YTrain is a categorical vector containing the labels for each observation.
Set aside 1000 of the images for network validation.
idx = randperm(size(XTrain,4),1000); XValidation = XTrain(:,:,:,idx); XTrain(:,:,:,idx) = []; YValidation = YTrain(idx); YTrain(idx) = [];
Create an imageDataAugmenter object that specifies preprocessing options for image augmentation, such as resizing, rotation, translation, and reflection. Randomly translate the images up to three pixels horizontally and vertically, and rotate the images by an angle of up to 20 degrees.
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-3 3], ...
    'RandYTranslation',[-3 3])
imageAugmenter = 
  imageDataAugmenter with properties:

           FillValue: 0
     RandXReflection: 0
     RandYReflection: 0
        RandRotation: [-20 20]
           RandScale: [1 1]
          RandXScale: [1 1]
          RandYScale: [1 1]
          RandXShear: [0 0]
          RandYShear: [0 0]
    RandXTranslation: [-3 3]
    RandYTranslation: [-3 3]
Create an augmentedImageDatastore object to use for network training and specify the image output size. During training, the datastore performs image augmentation and resizes the images. The datastore augments the images without saving any images to memory. trainNetwork updates the network parameters and then discards the augmented images.
imageSize = [28 28 1];
augimds = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);
Specify the convolutional neural network architecture.
layers = [
    imageInputLayer(imageSize)

    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
Specify training options for stochastic gradient descent with momentum.
opts = trainingOptions('sgdm', ...
    'MaxEpochs',15, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XValidation,YValidation});
Train the network. Because the validation images are not augmented, the validation accuracy is higher than the training accuracy.
net = trainNetwork(augimds,layers,opts);
Load the sample data, which consists of synthetic images of handwritten digits. The third output contains the corresponding angles in degrees by which each image has been rotated.
Load the training images as 4-D arrays using digitTrain4DArrayData. The output XTrain is a 28-by-28-by-1-by-5000 array, where:
28 is the height and width of the images.
1 is the number of channels.
5000 is the number of synthetic images of handwritten digits.
YTrain contains the rotation angles in degrees.
[XTrain,~,YTrain] = digitTrain4DArrayData;
Display 20 random training images using imshow.
figure
numTrainImages = numel(YTrain);
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)
    imshow(XTrain(:,:,:,idx(i)))
end
Specify the convolutional neural network architecture. For regression problems, include a regression layer at the end of the network.
layers = [ ...
imageInputLayer([28 28 1])
convolution2dLayer(12,25)
reluLayer
fullyConnectedLayer(1)
regressionLayer];
Specify the network training options. Set the initial learning rate to 0.001.
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'Plots','training-progress');
Train the network.
net = trainNetwork(XTrain,YTrain,layers,options);
Test the performance of the network by evaluating the prediction accuracy of the test data. Use predict to predict the angles of rotation of the test images.
[XTest,~,YTest] = digitTest4DArrayData;
YPred = predict(net,XTest);
Evaluate the performance of the model by calculating the root-mean-square error (RMSE) of the predicted and actual angles of rotation.
rmse = sqrt(mean((YTest - YPred).^2))
rmse = single
    6.0388
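RMSE is one summary of regression error; another is the fraction of predictions that fall within an acceptable tolerance. The sketch below computes both from hypothetical angle vectors. The tolerance thr of 10 degrees is an assumption chosen for illustration, not a value from this page.

```matlab
% Hypothetical true and predicted rotation angles, in degrees.
yTrueDemo = [10; -20; 35; 0];
yPredDemo = [12; -31; 30; 4];

err      = yTrueDemo - yPredDemo;      % prediction errors: [-2; 11; 5; -4]
rmseDemo = sqrt(mean(err.^2));         % same formula as rmse above

thr = 10;                              % assumed tolerance, in degrees
fracWithinThr = mean(abs(err) < thr);  % fraction of predictions within thr
```

With these values, rmseDemo is sqrt(41.5) (about 6.44) and three of the four predictions are within 10 degrees, so fracWithinThr is 0.75.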
Train a deep learning LSTM network for sequence-to-label classification.
Load the Japanese Vowels data set as described in [1] and [2]. XTrain is a cell array containing 270 sequences of varying length with a feature dimension of 12. YTrain is a categorical vector of labels 1,2,...,9. The entries in XTrain are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step).
[XTrain,YTrain] = japaneseVowelsTrainData;
Visualize the first time series in a plot. Each line corresponds to a feature.
figure
plot(XTrain{1}')
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')
Define the LSTM network architecture. Specify the input size as 12 (the number of features of the input data). Specify an LSTM layer to have 100 hidden units and to output the last element of the sequence. Finally, specify nine classes by including a fully connected layer of size 9, followed by a softmax layer and a classification layer.
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   LSTM                    LSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex
Specify the training options. Specify the solver as 'adam' and 'GradientThreshold' as 1. Set the mini-batch size to 27 and set the maximum number of epochs to 100.
Because the mini-batches are small with short sequences, the CPU is better suited for training. Set 'ExecutionEnvironment' to 'cpu'. To train on a GPU, if available, set 'ExecutionEnvironment' to 'auto' (the default value).
maxEpochs = 100;
miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');
Train the LSTM network with the specified training options.
net = trainNetwork(XTrain,YTrain,layers,options);
Load the test set and classify the sequences into speakers.
[XTest,YTest] = japaneseVowelsTestData;
Classify the test data. Specify the same mini-batch size used for training.
YPred = classify(net,XTest,'MiniBatchSize',miniBatchSize);
Calculate the classification accuracy of the predictions.
acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9541
imds — Image datastore (ImageDatastore object)
Image datastore, specified as an ImageDatastore object.
ImageDatastore allows batch reading of JPG or PNG image files using prefetching. If you use a custom function for reading the images, then ImageDatastore does not prefetch.
Use augmentedImageDatastore for efficient preprocessing of images for deep learning, including image resizing.
Do not use the readFcn option of imageDatastore, as this option is usually significantly slower.
ds — Datastore
Datastore for out-of-memory data and preprocessing.
For networks with a single input, the table or cell array returned by the datastore has two columns that specify the network inputs and expected responses, respectively.
For networks with multiple inputs, the datastore must be a combined or transformed datastore that returns a cell array with (numInputs+1) columns containing the predictors and the responses, where numInputs is the number of network inputs. For i less than or equal to numInputs, the ith element of the cell array corresponds to the input layers.InputNames(i), where layers is the layer graph defining the network architecture. The last column of the cell array corresponds to the responses.
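To make the multi-input column layout concrete, the sketch below reorders a hypothetical cell array read from an underlying datastore so that the network inputs come first (in the assumed InputNames order) and the response comes last. The source column order {label, image, feature vector} and the two-input network are assumptions for illustration; with a real datastore, the same function handle could be passed to transform.

```matlab
% Hypothetical read result: one row per observation, columns ordered as
% {label, image, featureVector}. Values are random placeholders.
dataIn = {categorical("3"), rand(28,28), rand(1,3); ...
          categorical("7"), rand(28,28), rand(1,3)};

numInputs = 2;   % assumed: network with an image input and a feature input

% Move the two predictors to the front (matching the assumed InputNames
% order) and the response to the last column: numInputs + 1 columns total.
toNetworkLayout = @(data) data(:, [2 3 1]);
dataOut = toNetworkLayout(dataIn);
assert(size(dataOut,2) == numInputs + 1);

% With a real datastore ds (not run here):
% tds = transform(ds, toNetworkLayout);
```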
The table below lists the datastores that are directly compatible with trainNetwork. You can use other built-in datastores for training deep learning networks by using the transform and combine functions. These functions can convert the data read from datastores to the table or cell array format required by trainNetwork. For more information, see Datastores for Deep Learning.
| Type of Datastore | Description |
| --- | --- |
| CombinedDatastore | Horizontally concatenate the data read from two or more underlying datastores. |
| TransformedDatastore | Transform batches of read data from an underlying datastore according to your own preprocessing pipeline. |
| AugmentedImageDatastore | Apply random affine geometric transformations, including resizing, rotation, reflection, shear, and translation, for training deep neural networks. |
| PixelLabelImageDatastore | Apply identical affine geometric transformations to images and corresponding ground truth labels for training semantic segmentation networks (requires Computer Vision Toolbox™). |
| RandomPatchExtractionDatastore | Extract pairs of random patches from images or pixel label images (requires Image Processing Toolbox™). You optionally can apply identical random affine geometric transformations to the pairs of patches. |
| DenoisingImageDatastore | Apply randomly generated Gaussian noise for training denoising networks (requires Image Processing Toolbox). |
| Custom mini-batch datastore | Create mini-batches of sequence, time series, or text data. For details, see Develop Custom Mini-Batch Datastore. |
X — Image data
Image data, specified as a numeric array. The size of the array depends on the type of image input:
| Input | Description |
| --- | --- |
| 2-D images | An h-by-w-by-c-by-N numeric array, where h, w, and c are the height, width, and number of channels of the images, respectively, and N is the number of images. |
| 3-D images | An h-by-w-by-d-by-c-by-N numeric array, where h, w, d, and c are the height, width, depth, and number of channels of the images, respectively, and N is the number of images. |
If the array contains NaNs, then they are propagated through the network.
sequences — Sequence or time series data
Sequence or time series data, specified as an N-by-1 cell array of numeric arrays, where N is the number of observations, a numeric array representing a single sequence, or a datastore.
For cell array or numeric array input, the dimensions of the numeric arrays containing the sequences depend on the type of data.
| Input | Description |
| --- | --- |
| Vector sequences | c-by-s matrices, where c is the number of features of the sequences and s is the sequence length. |
| 2-D image sequences | h-by-w-by-c-by-s arrays, where h, w, and c correspond to the height, width, and number of channels of the images, respectively, and s is the sequence length. |
| 3-D image sequences | h-by-w-by-d-by-c-by-s arrays, where h, w, d, and c correspond to the height, width, depth, and number of channels of the 3-D images, respectively, and s is the sequence length. |
For datastore input, the datastore must return data as a cell array of sequences or a table whose first column contains sequences. The dimensions of the sequence data must correspond to the table above.
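For vector sequences, the expected input is an N-by-1 cell array whose elements are c-by-s matrices that share c but may differ in s. The sketch below builds such an array from random placeholder data and checks its layout; numFeatures = 12 mirrors the Japanese Vowels example above, and the sequence lengths are arbitrary.

```matlab
numFeatures = 12;            % c: features per time step (same for every sequence)
seqLengths  = [20 7 15];     % s: arbitrary, varying sequence lengths

% Build an N-by-1 cell array of c-by-s matrices.
sequences = arrayfun(@(s) rand(numFeatures, s), seqLengths, ...
    'UniformOutput', false).';

% Check the layout: every sequence must have the same number of rows (features).
featureDims = cellfun(@(x) size(x,1), sequences);
lengths     = cellfun(@(x) size(x,2), sequences);   % time steps per sequence
assert(all(featureDims == numFeatures), 'Inconsistent feature dimension.');
```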
Y — Responses
Responses, specified as a categorical vector of labels, a numeric array, a cell array of categorical sequences, or a cell array of numeric sequences. The format of Y depends on the type of task. Responses must not contain NaNs.
| Task | Format |
| --- | --- |
| Image classification | N-by-1 categorical vector of labels, where N is the number of observations. |
| Sequence-to-label classification | N-by-1 categorical vector of labels, where N is the number of observations. |
| Sequence-to-sequence classification | N-by-1 cell array of categorical sequences of labels, where N is the number of observations. Each sequence has the same number of time steps as the corresponding input sequence after applying the |
For sequence-to-sequence classification problems with one observation, sequences can also be a vector. In this case, Y must be a categorical sequence of labels.
| Task | Format |
| --- | --- |
| 2-D image regression | |
| 3-D image regression | |
| Sequence-to-one regression | N-by-R matrix, where N is the number of sequences and R is the number of responses. |
| Sequence-to-sequence regression | N-by-1 cell array of numeric sequences, where N is the number of sequences. The sequences are matrices with R rows, where R is the number of responses. Each sequence has the same number of time steps as the corresponding input sequence after applying the |
For sequence-to-sequence regression problems with one observation, sequences can be a matrix. In this case, Y must be a matrix of responses.
Normalizing the responses often helps to stabilize and speed up training of neural networks for regression. For more information, see Train Convolutional Neural Network for Regression.
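One common way to normalize regression responses is z-scoring with statistics from the training set, then undoing the transform on the network's predictions. This page does not prescribe a method, so treat the z-score choice as an assumption. The target values below are hypothetical, and yPredNorm is a stand-in for the output of predict.

```matlab
yTrainDemo = [10; -20; 35; 0; 15];      % hypothetical regression targets

% Statistics from the training responses only.
muY    = mean(yTrainDemo);
sigmaY = std(yTrainDemo);

% Train on the normalized responses instead of the raw ones, e.g.
% net = trainNetwork(XTrain, yTrainNorm, layers, options);   (not run here)
yTrainNorm = (yTrainDemo - muY) ./ sigmaY;

% Stand-in for predict(net, X): here the normalized targets are reused.
yPredNorm = yTrainNorm;

% Map predictions back to the original units before computing errors.
yPred = yPredNorm .* sigmaY + muY;
```

Reusing muY and sigmaY from the training set (rather than recomputing them on test data) keeps the inverse mapping consistent with what the network learned.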
tbl — Input data (table)
Input data, specified as a table containing predictors in the first column and responses in the remaining column or columns. Each row in the table corresponds to an observation.
The arrangement of predictors and responses in the table columns depends on the type of problem.
Classification
| Task | Predictors | Responses |
| --- | --- | --- |
| Image classification | | Categorical label |
| Sequence-to-label classification | Absolute or relative file path to a MAT file containing sequence or time series data. The MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps. | Categorical label |
| Sequence-to-sequence classification | | Absolute or relative file path to a MAT file. The MAT file must contain a time series represented by a categorical vector, with entries corresponding to labels for each time step. |
For classification problems, if you do not specify responseName, then the function, by default, uses the responses in the second column of tbl.
Regression
| Task | Predictors | Responses |
| --- | --- | --- |
| Image regression | | |
| Sequence-to-one regression | Absolute or relative file path to a MAT file containing sequence or time series data. The MAT file must contain a time series represented by a matrix with rows corresponding to data points and columns corresponding to time steps. | |
| Sequence-to-sequence regression | | Absolute or relative file path to a MAT file. The MAT file must contain a time series represented by a matrix, where rows correspond to responses and columns correspond to time steps. |
For regression problems, if you do not specify responseName, then the function, by default, uses the remaining columns of tbl. Normalizing the responses often helps to stabilize and speed up training of neural networks for regression. For more information, see Train Convolutional Neural Network for Regression.
Responses cannot contain NaNs. If the predictor data contains NaNs, then they are propagated through the training. However, in most cases, the training fails to converge.
Data Types: table
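Because NaN responses are not allowed and NaN predictors usually stop training from converging, it can help to screen the data before calling trainNetwork. The sketch below does this for 4-D image arrays with hypothetical data; dropping the affected observations is one option (imputation is another), not a step this page prescribes.

```matlab
% Hypothetical predictors (h-by-w-by-c-by-N) and responses (N-by-1).
XDemo = rand(4,4,1,3);
XDemo(2,2,1,3) = NaN;          % corrupt one pixel of observation 3
YDemo = [1.5; NaN; 0.2];       % observation 2 has a missing response

N = size(XDemo,4);
badResponse  = isnan(YDemo);                                % N-by-1 logical
badPredictor = any(isnan(reshape(XDemo, [], N)), 1).';      % N-by-1 logical

% Keep only observations with non-NaN responses and NaN-free predictors.
keep   = ~(badResponse | badPredictor);
XClean = XDemo(:,:,:,keep);
YClean = YDemo(keep);
```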
responseName — Names of response variables in the input table
Names of the response variables in the input table, specified as a character vector, cell array of character vectors, or a string array. For problems with one response, responseName is the corresponding variable name in tbl. For regression problems with multiple response variables, responseName is an array of the corresponding variable names in tbl.
Data Types: char | cell | string
layers — Network layers (Layer array | LayerGraph object)
Network layers, specified as a Layer array or a LayerGraph object.
To create a network with all layers connected sequentially, you can use a Layer array as the input argument. In this case, the returned network is a SeriesNetwork object.
A directed acyclic graph (DAG) network has a complex structure in which layers can have multiple inputs and outputs. To create a DAG network, specify the network architecture as a LayerGraph object and then use that layer graph as the input argument to trainNetwork.
For a list of built-in layers, see List of Deep Learning Layers.
options — Training options (TrainingOptionsSGDM | TrainingOptionsRMSProp | TrainingOptionsADAM)
Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions function. To specify the solver and other options for network training, use trainingOptions.
net — Trained network (SeriesNetwork object | DAGNetwork object)
Trained network, returned as a SeriesNetwork object or a DAGNetwork object.
If you train the network using a Layer array as the layers input argument, then net is a SeriesNetwork object. If you train the network using a LayerGraph object as the input argument, then net is a DAGNetwork object.
info — Training information
Training information, returned as a structure, where each field is a numeric vector with one element per training iteration.
For classification problems, info contains the following fields:
TrainingLoss — Training loss function values
TrainingAccuracy — Training accuracies
ValidationLoss — Validation loss function values
ValidationAccuracy — Validation accuracies
BaseLearnRate — Learning rates
For regression problems, info contains the following fields:
TrainingLoss — Training loss function values
TrainingRMSE — Training RMSE values
ValidationLoss — Validation loss function values
ValidationRMSE — Validation RMSE values
BaseLearnRate — Learning rates
The structure only contains the fields ValidationLoss, ValidationAccuracy, and ValidationRMSE when options specifies validation data. The 'ValidationFrequency' option of trainingOptions determines at which iterations the software calculates validation metrics. For iterations when the software does not calculate validation metrics, the corresponding values in the structure are NaN.
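Since validation entries are NaN at iterations where no validation was run, filtering them out is usually the first step when post-processing info. The sketch below uses a hypothetical structure in place of the second output of [net,info] = trainNetwork(...); it is named infoDemo to avoid shadowing the built-in info function.

```matlab
% Hypothetical training information for a classification run (6 iterations,
% validation computed at iterations 1, 4, and 6).
infoDemo.TrainingLoss       = [2.3 1.9 1.4 1.1 0.9 0.7];
infoDemo.ValidationLoss     = [2.2 NaN NaN 1.2 NaN 0.8];
infoDemo.ValidationAccuracy = [12  NaN NaN 61  NaN 78];

% Iterations at which validation metrics were actually calculated.
valIter = find(~isnan(infoDemo.ValidationLoss));
valLoss = infoDemo.ValidationLoss(valIter);

% Iteration with the highest validation accuracy.
[bestAcc, k] = max(infoDemo.ValidationAccuracy(valIter));
bestIter = valIter(k);

% To compare curves (not run here):
% plot(1:numel(infoDemo.TrainingLoss), infoDemo.TrainingLoss, valIter, valLoss, 'o-')
```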
Deep Learning Toolbox™ enables you to save networks as .mat files after each epoch during training. This periodic saving is especially useful when you have a large network or a large data set, and training takes a long time. If the training is interrupted for some reason, you can resume training from the last saved checkpoint network. If you want trainNetwork to save checkpoint networks, then you must specify the name of the path by using the 'CheckpointPath' name-value pair argument of trainingOptions. If the path that you specify does not exist, then trainingOptions returns an error.
trainNetwork automatically assigns unique names to checkpoint network files. In the example name, net_checkpoint__351__2018_04_12__18_09_52.mat, 351 is the iteration number, 2018_04_12 is the date, and 18_09_52 is the time at which trainNetwork saves the network. You can load a checkpoint network file by double-clicking it or using the load command at the command line. For example:
load net_checkpoint__351__2018_04_12__18_09_52.mat
You can then resume training by using the layers of the network as an input argument to trainNetwork. For example:
trainNetwork(XTrain,YTrain,net.Layers,options)
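When several checkpoint files exist, the iteration number embedded in each name identifies the most recent one. The sketch below parses the naming pattern described above with regexp. The file list is hypothetical (in practice it might come from d = dir(fullfile(checkpointPath,'net_checkpoint__*.mat')) followed by {d.name}), and the regular expression is written against the single example name on this page, so treat it as an assumption about the format.

```matlab
% Hypothetical checkpoint file names.
files = {'net_checkpoint__117__2018_04_12__17_58_03.mat', ...
         'net_checkpoint__351__2018_04_12__18_09_52.mat', ...
         'net_checkpoint__234__2018_04_12__18_04_11.mat'};

% Tokens: iteration number, date (yyyy_MM_dd), time (HH_mm_ss).
pattern = '^net_checkpoint__(\d+)__(\d{4}_\d{2}_\d{2})__(\d{2}_\d{2}_\d{2})\.mat$';

iters = zeros(size(files));
for k = 1:numel(files)
    tok = regexp(files{k}, pattern, 'tokens', 'once');
    iters(k) = str2double(tok{1});
end

% Most recent checkpoint = highest iteration number.
[~, latest] = max(iters);
latestFile = files{latest};

% Recover the save time as a datetime from the date and time tokens.
tok = regexp(latestFile, pattern, 'tokens', 'once');
savedAt = datetime(str2double(strsplit([tok{2} '_' tok{3}], '_')));

% load(latestFile) would then restore the variable net from that checkpoint.
```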
All functions for deep learning training, prediction, and validation in Deep Learning Toolbox perform computations using single-precision, floating-point arithmetic. Functions for deep learning include trainNetwork, predict, classify, and activations. The software uses single-precision arithmetic when you train networks using both CPUs and GPUs.
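To see what single-precision arithmetic means for displayed results (for example, the rmse value reported as single in the regression example), the short sketch below rounds a double to single and compares the error with eps('single'). It uses only base MATLAB and illustrates floating-point precision in general, not the toolbox internals.

```matlab
x  = 0.1;                   % double precision
xs = single(x);             % nearest single-precision value

relErr    = abs(double(xs) - x) / x;   % relative rounding error
epsSingle = eps('single');             % spacing of singles near 1, equal to 2^-23

% Mixing single and double operands gives a single result.
mixedClass = class(xs * 2);
```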
[1] Kudo, M., J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pp. 1103–1111.
[2] Kudo, M., J. Toyama, and M. Shimbo. Japanese Vowels Data Set. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
To run computation in parallel, set the 'ExecutionEnvironment' option to 'multi-gpu' or 'parallel'.
Use trainingOptions to set the 'ExecutionEnvironment' and supply the options to trainNetwork. If you do not set 'ExecutionEnvironment', then trainNetwork runs on a GPU if available.
For details, see Scale Up Deep Learning in Parallel and in the Cloud.
DAGNetwork | LayerGraph | SeriesNetwork | analyzeNetwork | assembleNetwork | classify | predict | trainingOptions