主要内容

本页翻译不是最新的。点击此处可查看最新英文版本。

SeriesNetwork

(不推荐)用于深度学习的串行网络

不推荐使用 SeriesNetwork 对象。请改用 dlnetwork 对象。有关详细信息,请参阅版本历史记录

说明

串行网络是一种用于深度学习的神经网络,具有依次排列的各个层。它有一个输入层和一个输出层。

创建对象

创建 SeriesNetwork 对象有多种方法:

注意

要了解其他预训练网络,如 googlenetresnet50,请参阅预训练的深度神经网络

属性

全部展开

此 属性 为只读。

网络层,指定为 Layer 数组。

此 属性 为只读。

输入层的名称,指定为字符向量元胞数组。

数据类型: cell

此 属性 为只读。

输出层的名称,指定为字符向量元胞数组。

数据类型: cell

对象函数

activations(不推荐)计算深度学习网络层激活值
classify(不推荐)使用经过训练的深度学习神经网络对数据进行分类
predict(不推荐)使用经过训练的深度学习神经网络预测响应
predictAndUpdateState(Not recommended) Predict responses using a trained recurrent neural network and update the network state
classifyAndUpdateState(Not recommended) Classify data using a trained recurrent neural network and update the network state
resetStateReset state parameters of neural network
plot绘制神经网络架构

示例

全部折叠

训练用于图像分类的网络。

将数据作为 ImageDatastore 对象加载。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

该数据存储包含 10,000 个数字 0 至 9 的合成图像。这些图像是通过对使用不同字体创建的数字图像应用随机变换生成的。每个数字图像为 28×28 像素。该数据存储包含的每个类别都有相同数量的图像。

显示数据存储中的部分图像。

figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
    drawnow;
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

划分数据存储,使训练集中的每个类别包含 750 个图像,测试集包含对应每个标签的其余图像。

numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles, ...
    'randomize');

splitEachLabeldigitData 中的图像文件拆分为两个新的数据存储,imdsTrainimdsTest

定义卷积神经网络架构。

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

将选项设置为带动量的随机梯度下降的默认设置。将最大训练轮数设置为 20,以 0.0001 的初始学习率开始训练。

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(imdsTrain,layers,options);

Figure Training Progress (15-Aug-2023 20:45:20) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 6 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 6 objects of type patch, text, line.

基于未用于训练网络的测试集运行经过训练的网络,并预测图像标签(数字)。

YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;

计算准确度。准确度是测试数据中与来自 classify 的分类匹配的真实标签数量与测试数据中图像数量的比率。

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9416

扩展功能

全部展开

版本历史记录

在 R2016a 中推出

全部折叠