Train a Semantic Segmentation Network

Load the training data.

dataSetDir = fullfile(toolboxdir("vision"),"visiondata","triangleImages");
imageDir = fullfile(dataSetDir,"trainingImages");
labelDir = fullfile(dataSetDir,"trainingLabels");

Create an image datastore for the training images.

imds = imageDatastore(imageDir);

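Create a pixelLabelDatastore for the ground truth pixel labels. Pixel value 255 maps to the "triangle" class and pixel value 0 maps to the "background" class.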

classNames = ["triangle" "background"];
labelIDs   = [255 0];
pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);
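To illustrate the mapping that labelIDs defines, the following plain-MATLAB sketch applies categorical to a small hypothetical label image: the value 255 becomes "triangle", 0 becomes "background", and any value not listed in labelIDs (here the hypothetical value 128) becomes undefined. This is assumed to mirror how pixelLabelDatastore interprets label images; it is not the datastore's implementation.

```matlab
% Hypothetical 3-by-3 label image; 128 stands for a value with no class
labelImage = uint8([0 255 0; 255 255 128; 0 0 0]);
classNames = ["triangle" "background"];
labelIDs   = [255 0];

% Map each listed pixel value to the class name at the same position;
% values not in labelIDs become <undefined>
C = categorical(double(labelImage), labelIDs, classNames);
disp(C)
```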

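Visualize a training image and its ground truth pixel labels. Both are enlarged by a factor of 5 with nearest-neighbor interpolation so that the small images are easier to see.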

I = read(imds);
C = read(pxds);

I = imresize(I,5,"nearest");
L = imresize(uint8(C{1}),5,"nearest");
imshowpair(I,L,"montage")
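The visualization relies on two conversions. uint8 applied to a categorical array returns each element's category index (1 for "triangle" and 2 for "background" here), and nearest-neighbor enlargement by an integer factor copies each pixel into a block without creating new label values. The sketch below uses repelem as a stand-in for imresize(...,"nearest") on a hypothetical 2-by-2 label image; treating the two as equivalent for integer factors is an assumption.

```matlab
% Hypothetical 2-by-2 categorical label image
C = categorical(["triangle" "background"; "background" "triangle"], ...
                ["triangle" "background"]);

% uint8 returns the category index of each pixel: 1 = triangle, 2 = background
L = uint8(C);

% Integer-factor nearest-neighbor enlargement: each pixel becomes a 5-by-5 block
Lbig = repelem(L, 5, 5);
disp(size(Lbig))
```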

Combine the image and pixel label datastores for training.

trainingData = combine(imds,pxds);
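A combined datastore reads from each underlying datastore and returns the results side by side, so each read yields an image together with its labels. The sketch below shows this behavior with two small hypothetical arrayDatastore objects standing in for imds and pxds.

```matlab
% Hypothetical stand-ins: three "images" and their three "labels"
imgDs = arrayDatastore([10; 20; 30]);
lblDs = arrayDatastore([1; 2; 3]);

% Pair the datastores: each read returns one element from each, in order
ds = combine(imgDs, lblDs);
first  = read(ds);   % image 10 with label 1
second = read(ds);   % image 20 with label 2
disp(first)
```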

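Create a semantic segmentation network. This simple network uses a downsampling and upsampling design: a max pooling layer halves the spatial size of the 32-by-32 input, a transposed convolution layer upsamples it back to the input size, and a final 1-by-1 convolution followed by a softmax layer produces class probabilities for each pixel.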

numFilters = 64;
filterSize = 3;
numClasses = 2;
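layers = [
    imageInputLayer([32 32 1])                                % 32-by-32 grayscale images
    convolution2dLayer(filterSize,numFilters,Padding=1)      % 3-by-3 conv, size stays 32-by-32
    reluLayer()
    maxPooling2dLayer(2,Stride=2)                             % downsample to 16-by-16
    convolution2dLayer(filterSize,numFilters,Padding=1)      % 3-by-3 conv, size stays 16-by-16
    reluLayer()
    transposedConv2dLayer(4,numFilters,Stride=2,Cropping=1)  % upsample back to 32-by-32
    convolution2dLayer(1,numClasses)                          % one score per class per pixel
    softmaxLayer()                                            % per-pixel class probabilities
    ];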
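The spatial sizes through the network can be checked with the standard output-size formulas, assuming symmetric padding and cropping. The sketch traces a 32-by-32 input: padding 1 keeps the 3-by-3 convolutions size-preserving, max pooling halves the size to 16, and the transposed convolution restores it to (16 - 1)*2 + 4 - 2*1 = 32, so the final 1-by-1 convolution produces a 32-by-32-by-numClasses score map that matches the size of the label images. The helper names convOut and tconvOut are hypothetical.

```matlab
% Output size of a convolution or pooling layer
% (n: input size, f: filter size, p: padding per side, s: stride)
convOut = @(n, f, p, s) floor((n + 2*p - f)/s) + 1;
% Output size of a transposed convolution with cropping c removed from each side
tconvOut = @(n, f, c, s) (n - 1)*s + f - 2*c;

s1 = 32;                       % imageInputLayer([32 32 1])
s2 = convOut(s1, 3, 1, 1);     % 3-by-3 convolution, Padding=1
s3 = convOut(s2, 2, 0, 2);     % 2-by-2 max pooling, Stride=2
s4 = convOut(s3, 3, 1, 1);     % 3-by-3 convolution, Padding=1
s5 = tconvOut(s4, 4, 1, 2);    % 4-by-4 transposed convolution, Stride=2, Cropping=1
s6 = convOut(s5, 1, 0, 1);     % 1-by-1 convolution to numClasses channels
fprintf("%d -> %d -> %d -> %d -> %d -> %d\n", s1, s2, s3, s4, s5, s6);
```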

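Specify the training options: use stochastic gradient descent with momentum (SGDM) with an initial learning rate of 0.001, train for at most 100 epochs, and use mini-batches of 64 observations.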

opts = trainingOptions("sgdm", ...
    InitialLearnRate=1e-3, ...
    MaxEpochs=100, ...
    MiniBatchSize=64);
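The "sgdm" solver updates each parameter with a velocity that accumulates past gradient steps. The toy sketch below applies that update rule to the hypothetical loss f(x) = x^2. The learning rate 0.1 is chosen only so that the toy problem converges quickly, and the momentum value 0.9 is assumed to match the trainingOptions default for "sgdm".

```matlab
learnRate = 0.1;     % hypothetical; larger than 1e-3 so the toy converges quickly
momentum  = 0.9;     % assumed default momentum for "sgdm"
gradFcn = @(x) 2*x;  % gradient of the toy loss f(x) = x^2

x = 1;               % initial parameter value
v = 0;               % velocity: the previous parameter step
for iter = 1:200
    v = momentum*v - learnRate*gradFcn(x);  % blend previous step with new gradient step
    x = x + v;                              % apply the step
end
fprintf("x after 200 iterations: %g\n", x);
```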

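Define a loss function suited to pixel classification. The function excludes target values that are NaN (for example, unlabeled pixels) from the cross-entropy loss and normalizes the loss by the number of included elements. In a script, local functions such as this one must be defined at the end of the file.

function loss = modelLoss(Y,T)
    % Include only the elements whose target value is defined
    mask = ~isnan(T);
    % Replace NaN targets with 0 so that they do not propagate NaN into the loss
    T(isnan(T)) = 0;
    % Masked cross-entropy, normalized by the number of included mask elements
    loss = crossentropy(Y,T,Mask=mask,NormalizationFactor="mask-included");
end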
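To make the masking concrete, the following sketch computes a masked cross-entropy with ordinary arrays for three hypothetical pixels, one of which is unlabeled (NaN targets). It averages over labeled pixels, which is a simplifying assumption; the exact normalization that crossentropy applies with NormalizationFactor="mask-included" on formatted dlarray data may count elements differently.

```matlab
% Hypothetical softmax outputs: 3 pixels (rows) by 2 classes (columns)
Y = [0.9 0.1; 0.2 0.8; 0.5 0.5];
% One-hot targets; the third pixel is unlabeled, so its targets are NaN
T = [1 0; 0 1; NaN NaN];

mask = ~isnan(T);                 % true where a target value is defined
T(~mask) = 0;                     % zero the NaN targets, as modelLoss does
terms = -T .* log(Y) .* mask;     % per-element cross-entropy terms
numLabeled = nnz(any(mask, 2));   % pixels with at least one defined target
loss = sum(terms(:)) / numLabeled;
fprintf("masked loss = %.4f\n", loss);
```

The unlabeled third pixel contributes nothing, so the result depends only on the predicted probabilities of the true classes for the first two pixels.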

Train the network using trainnet with the custom loss function modelLoss.

net = trainnet(trainingData,layers,@modelLoss,opts);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss
    _________    _____    ___________    _________    ____________
            1        1       00:00:13        0.001         0.65456
           50       17       00:00:52        0.001        0.071766
          100       34       00:01:07        0.001         0.04846
          150       50       00:01:24        0.001        0.028925
          200       67       00:01:54        0.001        0.028831
          250       84       00:02:30        0.001        0.029716
          300      100       00:03:05        0.001        0.019433
Training stopped: Max epochs completed

Read and display a test image.

testImage = imread("triangleTest.jpg");
imshow(testImage)

Segment the test image and display the predicted labels overlaid on it.

C = semanticseg(testImage,net,Classes=classNames);
B = labeloverlay(testImage,C);
imshow(B)
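Each pixel in the segmentation result is assigned the class with the highest network score. The sketch below reproduces that final labeling step on a hypothetical 2-by-2 score map with two classes; it illustrates the per-pixel argmax and is not semanticseg's implementation.

```matlab
% Hypothetical 2-by-2 score map with 2 classes (H-by-W-by-numClasses)
scores = cat(3, [0.9 0.3; 0.6 0.1], [0.1 0.7; 0.4 0.9]);
classNames = ["triangle" "background"];

% For every pixel, pick the class with the highest score along dimension 3
[~, idx] = max(scores, [], 3);
C = categorical(idx, 1:numel(classNames), classNames);
disp(C)
```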