Main Content

SeriesNetwork

(不推荐)用于深度学习的串行网络

不推荐使用 SeriesNetwork 对象。请改用 dlnetwork 对象。有关详细信息,请参阅版本历史记录

说明

串行网络是一种用于深度学习的神经网络,具有依次排列的各个层。它有一个输入层和一个输出层。

创建对象

创建 SeriesNetwork 对象有多种方法:

注意

要了解其他预训练网络,如 googlenetresnet50,请参阅预训练的深度神经网络

属性

全部展开

此 属性 为只读。

网络层,指定为 Layer 数组。

此 属性 为只读。

输入层的名称,指定为字符向量元胞数组。

数据类型: cell

此 属性 为只读。

输出层的名称,指定为字符向量元胞数组。

数据类型: cell

对象函数

activations(不推荐)计算深度学习网络层激活值
classify(Not recommended) Classify data using trained deep learning neural network
predict(Not recommended) Predict responses using trained deep learning neural network
predictAndUpdateState(Not recommended) Predict responses using a trained recurrent neural network and update the network state
classifyAndUpdateState(Not recommended) Classify data using a trained recurrent neural network and update the network state
resetStateReset state parameters of neural network
plot绘制神经网络架构

示例

全部折叠

训练用于图像分类的网络

将数据作为 ImageDatastore 对象加载。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

该数据存储包含 10,000 个数字 0 至 9 的合成图像。这些图像是通过对使用不同字体创建的数字图像应用随机变换生成的。每个数字图像为 28×28 像素。该数据存储包含的每个类别都有相同数量的图像。

显示数据存储中的部分图像。

figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
    drawnow;
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

划分数据存储,使训练集中的每个类别包含 750 个图像,测试集包含对应每个标签的其余图像。

numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles, ...
    'randomize');

splitEachLabeldigitData 中的图像文件拆分为两个新的数据存储,imdsTrainimdsTest

定义卷积神经网络架构。

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

将选项设置为带动量的随机梯度下降的默认设置。将最大训练轮数设置为 20,以 0.0001 的初始学习率开始训练。

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(imdsTrain,layers,options);

Figure Training Progress (15-Aug-2023 20:45:20) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 6 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 6 objects of type patch, text, line.

基于未用于训练网络的测试集运行经过训练的网络,并预测图像标签(数字)。

YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;

计算准确度。准确度是测试数据中与来自 classify 的分类匹配的真实标签数量与测试数据中图像数量的比率。

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9416

扩展功能

版本历史记录

在 R2016a 中推出

全部折叠

R2024a: 不推荐

从 R2024a 开始,不推荐使用 SeriesNetwork 对象,请改用 dlnetwork 对象。

目前没有停止支持 SeriesNetwork 对象的计划。但是,推荐改用 dlnetwork 对象,此类对象具有以下优势:

  • dlnetwork 对象是一种统一的数据类型,支持网络构建、预测、内置训练、可视化、压缩、验证和自定义训练循环。

  • dlnetwork 对象支持更广泛的网络架构,您可以创建或从外部平台导入这些网络架构。

  • trainnet 函数支持 dlnetwork 对象,这使您能够轻松指定损失函数。您可以从内置损失函数中进行选择或指定自定义损失函数。

  • 使用 dlnetwork 对象进行训练和预测通常比使用 LayerGraphtrainNetwork 工作流更快。

要将已训练的 SeriesNetwork 对象转换为 dlnetwork 对象,请使用 dag2dlnetwork 函数。

下表显示了 SeriesNetwork 对象的一些典型用法,以及如何更新您的代码以改用 dlnetwork 对象函数。

不推荐推荐
Y = predict(net,X);
Y = minibatchpredict(net,X);
Y = classify(net,X);
scores = minibatchpredict(net,X);
Y = scores2label(scores,classNames);
plot(net);
plot(net);
Y = activations(net,X,layerName);
Y = predict(net,X,Outputs=layerName);
[net,Y] = predictAndUpdateState(net,X);
[Y,state] = predict(net,X);
net.State = state;
[net,Y] = classifyAndUpdateState(net,X);
[scores,state] = predict(net,X);
Y = scores2label(scores,classNames);
net.State = state;