Main Content

本页翻译不是最新的。点击此处可查看最新英文版本。

训练条件生成对抗网络 (CGAN)

此示例说明如何训练条件生成对抗网络来生成图像。

生成对抗网络 (GAN) 是一种深度学习网络,它能够生成与输入训练数据具有相似特征的数据。

一个 GAN 由两个一起训练的网络组成:

  1. 生成器 - 给定随机值向量作为输入,此网络可生成与训练数据具有相同结构的数据。

  2. 判别器 - 给定包含来自训练数据和来自生成器的生成数据的观测值的数据批量,此网络尝试将观测值划分为 "real""generated"

条件生成对抗网络 (CGAN) 是一种类型的 GAN,它也在训练过程中利用标签。

  1. 生成器 - 给定标签和随机数组作为输入,此网络生成与对应于相同标签的训练数据观测值具有相同结构的数据。

  2. 判别器 - 给定包含来自训练数据和来自生成器的生成数据的观测值的标注数据批量,此网络尝试将观测值划分为 "real""generated"

要训练条件 GAN,需要同时训练两个网络以最大化两个网络的性能:

  • 训练生成器以生成“欺骗”判别器的数据。

  • 训练判别器以区分真实数据和生成的数据。

为了最大化生成器的性能,当给定生成的标注数据时,最大化判别器的损失。也就是说,生成器的目标是生成判别器分类为 "real" 的标注数据。

为了最大化判别器的性能,当给定真实数据和生成的标注数据批量时,最小化判别器的损失。即判别器的目标是不被生成器“欺骗”。

理想情况下,这些策略会得到一个能够生成令人信服的对应于输入标签的真实数据的生成器,以及一个已学习到针对每个标签的训练数据特有的强特征表示的判别器。

加载训练数据

下载并提取 Flowers 数据集 [1]。

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flowers data set (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

创建一个包含花卉照片的图像数据库。

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder,IncludeSubfolders=true,LabelSource="foldernames");

查看类数。

classes = categories(imds.Labels);
numClasses = numel(classes)
numClasses = 5

增强数据以包括随机水平翻转,并将图像大小调整为 64×64。

augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);

定义生成器网络

定义以下双输入网络,在给定大小为 100 的随机向量和对应标签的情况下,该网络生成图像。

此网络:

  • 使用一个全连接层后跟重构运算,将大小为 100 的随机向量转换为 4×4×1024 数组。

  • 将分类标签转换为嵌入向量,并将它们重构为一个 4×4 数组。

  • 沿通道维度串联来自两个输入的结果图像。输出是一个 4×4×1025 数组。

  • 使用一系列带批量归一化和 ReLU 层的转置卷积层,将生成的数组扩增到 64×64×3 数组。

将此网络架构定义为一个层图,并指定以下网络属性。

  • 对于分类输入,使用嵌入维度 50。

  • 对于转置卷积层,指定 5×5 滤波器,每一层的滤波器数量递减,步幅为 2,并对输出进行 "same" 裁剪。

  • 对于最终的转置卷积层,指定与生成图像的三个 RGB 通道对应的三个 5×5 滤波器。

  • 在网络末尾,包括一个 tanh 层。

要投影和重构噪声输入,请使用一个全连接层后跟一个重构运算函数层,该函数层的函数为此示例随附支持文件中提供的 feature2image。要嵌入分类标签,请使用以支持文件形式包含在此示例中的自定义层 embeddingLayer。要访问这些支持文件,请以实时脚本形式打开示例。

numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;

filterSize = 5;
projectionSize = [4 4 1024];

layersGenerator = [
    featureInputLayer(numLatentInputs)
    fullyConnectedLayer(prod(projectionSize))
    functionLayer(@(X) feature2image(X,projectionSize),Formattable=true)
    concatenationLayer(3,2,Name="cat");
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];

lgraphGenerator = layerGraph(layersGenerator);

layers = [
    featureInputLayer(1)
    embeddingLayer(embeddingDimension,numClasses)
    fullyConnectedLayer(prod(projectionSize(1:2)))
    functionLayer(@(X) feature2image(X,[projectionSize(1:2) 1]),Formattable=true,Name="emb_reshape")];

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,"emb_reshape","cat/in2");

要使用自定义训练循环训练网络并支持自动微分,请将层图转换为 dlnetwork 对象。

netG = dlnetwork(lgraphGenerator)
netG = 
  dlnetwork with properties:

         Layers: [19×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'input'  'input_1'}
    OutputNames: {'layer_2'}
    Initialized: 1

  View summary with summary.

定义判别器网络

定义以下双输入网络,在给定一组图像和对应标签的情况下,该网络对真实图像和生成的 64×64 图像进行分类。

创建一个网络,该网络将 64×64×1 图像和对应的标签作为输入,并使用一系列具有批量归一化和泄漏 ReLU 层的卷积层输出一个标量预测分数。使用丢弃法给输入图像添加噪声。

  • 对于丢弃层,指定丢弃概率为 0.75。

  • 对于卷积层,指定 5×5 滤波器,每一层的滤波器数量递增。同时指定步幅为 2 以及在每页上对输出进行填充。

  • 对于泄漏 ReLU 层,指定 0.2 的尺度。

  • 对于最终层,指定具有一个 4×4 滤波器的卷积层。

dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    concatenationLayer(3,2,Name="cat")
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)];

lgraphDiscriminator = layerGraph(layersDiscriminator);

layers = [
    featureInputLayer(1)
    embeddingLayer(embeddingDimension,numClasses)
    fullyConnectedLayer(prod(inputSize(1:2)))
    functionLayer(@(X) feature2image(X,[inputSize(1:2) 1]),Formattable=true,Name="emb_reshape")];

lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,"emb_reshape","cat/in2");

要使用自定义训练循环训练网络并支持自动微分,请将层图转换为 dlnetwork 对象。

netD = dlnetwork(lgraphDiscriminator)
netD = 
  dlnetwork with properties:

         Layers: [19×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'  'input'}
    OutputNames: {'conv_5'}
    Initialized: 1

  View summary with summary.

定义模型损失函数

创建在示例的模型损失函数部分列出的函数 modelLoss,该函数接受生成器和判别器网络、小批量输入数据和随机值数组作为输入,并返回损失关于网络中可学习参数的梯度和一个由生成的图像组成的数组。

指定训练选项

使用小批量大小 128 进行 500 轮训练。

numEpochs = 500;
miniBatchSize = 128;

指定 Adam 优化的选项。对于这两个网络,请使用:

  • 学习率为 0.0002

  • 梯度衰减因子为 0.5

  • 梯度平方衰减因子为 0.999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

每 100 次迭代更新一次训练进度图。

validationFrequency = 100;

如果判别器过快地学会了如何判别真实图像和生成的图像,则生成器可能无法进行训练。为了更好地平衡判别器和生成器的学习,随机翻转一部分真实图像的标签。将翻转因子指定为 0.5。

flipFactor = 0.5;

训练模型

使用自定义训练循环训练模型。在每次迭代中遍历训练数据并更新网络参数。为了监控训练进度,以保留的随机值数组作为生成器输入来显示一批生成图像,同时显示网络分数。

使用 minibatchqueue 在训练过程中处理和管理小批量图像。对于每个小批量:

  • 使用自定义小批量预处理函数 preprocessMiniBatch(在此示例末尾定义)在 [-1,1] 范围内重新缩放图像。

  • 丢弃观测值少于 128 个的任何不完整小批量。

  • 用维度标签 "SSCB"(空间、空间、通道、批量)格式化图像数据。

  • 使用维度标签 "BC"(批量、通道)格式化标签数据。

  • 在 GPU 上(如果有)进行训练。当 minibatchqueueOutputEnvironment 选项为 "auto" 时,minibatchqueue 将每个输出转换为 gpuArray(如果 GPU 可用)。使用 GPU 需要 Parallel Computing Toolbox™ 和支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)

默认情况下,minibatchqueue 对象将数据转换为基础类型为 singledlarray 对象。

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessData, ...
    MiniBatchFormat=["SSCB" "BC"], ...
    OutputEnvironment=executionEnvironment);    

初始化 Adam 优化器的参数。

velocityD = [];
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvgD = [];
trailingAvgSqD = [];

为了监控训练进度,创建一批包含 25 个随机向量的留出数据,以及一组由标签值 1 到 5 (对应于类)重复五次所得的标签数据。

numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,"single");

TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));

将数据转换为 dlarray 对象,并指定维度标签 "CB"(通道、批量)。

ZValidation = dlarray(ZValidation,"CB");
TValidation = dlarray(TValidation,"CB");

对于 GPU 训练,将数据转换为 gpuArray 对象。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZValidation = gpuArray(ZValidation);
    TValidation = gpuArray(TValidation);
end

要跟踪生成器和判别器的分数,请使用 trainingProgressMonitor 对象。计算监视器对象的总迭代次数。

numObservationsTrain = numel(imds.Files);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;

初始化 TrainingProgressMonitor 对象。由于计时器在您创建监视器对象时启动,请确保您创建的对象靠近训练循环。

monitor = trainingProgressMonitor( ...
    Metrics=["GeneratorScore","DiscriminatorScore"], ...
    Info=["Epoch","Iteration"], ...
    XLabel="Iteration");

groupSubPlot(monitor,Score=["GeneratorScore","DiscriminatorScore"])

训练条件 GAN。对于每轮训练,对数据进行乱序处理,并遍历小批量数据。

对于每个小批量:

  • 如果 TrainingProgressMonitor 对象的 Stop 属性为 true,则停止。当您点击停止按钮时,Stop 属性会更改为 true

  • 使用 dlfevalmodelLoss 函数,评估损失对于可学习参数的梯度、生成器状态和网络分数。

  • 使用 adamupdate 函数更新网络参数。

  • 绘制两个网络的分数。

  • 在每 validationFrequency 次迭代后,显示一批基于固定保留生成器输入的生成图像。

训练可能需要一些时间来运行。

epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Reset and shuffle data.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels "CB" (channel, batch).
        % If training on a GPU, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,T,Z,flipFactor);
        netG.State = stateG;

        % Update the discriminator network parameters.
        [netD,trailingAvgD,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvgD, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = ...
            adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation,TValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation), ...
                GridSize=[numValidationImagesPerClass numClasses]);
            I = rescale(I);
            
            % Display the images.
            image(I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the training progress monitor.
        recordMetrics(monitor,iteration, ...
            GeneratorScore=scoreG, ...
            DiscriminatorScore=scoreD);

        updateInfo(monitor,Epoch=epoch,Iteration=iteration);
        monitor.Progress = 100*iteration/numIterations;
    end
end

此时,判别器已学会在生成的图像中识别真实图像的强特征表示。顺带,生成器已学会类似的强特征表示,能够生成类似于训练数据的图像。每列对应于一个类。

训练图显示生成器和判别器网络的分数。要了解有关如何解释网络分数的详细信息,请参阅Monitor GAN Training Progress and Identify Common Failure Modes

生成新图像

要生成某特定类的新图像,请将生成器上的 predict 函数与一个 dlarray 对象结合使用,该对象包含一批随机向量和一个由对应于所需类的标签组成的数组。将数据转换为 dlarray 对象,并指定维度标签 "CB"(通道、批量)。要进行 GPU 预测,请将数据转换为 gpuArray 对象。要一起显示图像,请使用 imtile 函数,并使用 rescale 函数重新缩放图像。

创建一个由对应于第一个类的 36 个随机值向量组成的数组。

numObservationsNew = 36;
idxClass = 1;
ZNew = randn(numLatentInputs,numObservationsNew,"single");
TNew = repmat(single(idxClass),[1 numObservationsNew]);

将数据转换为具有维度标签 "SSCB"(空间、空间、通道、批量)的 dlarray 对象。

ZNew = dlarray(ZNew,"CB");
TNew = dlarray(TNew,"CB");

要使用 GPU 生成图像,还要将数据转换为 gpuArray 对象。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZNew = gpuArray(ZNew);
    TNew = gpuArray(TNew);
end

通过生成器网络使用 predict 函数生成图像。

XGeneratedNew = predict(netG,ZNew,TNew);

在绘图中显示生成的图像。

figure
I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
imshow(I)
title("Class: " + classes(idxClass))

此处,生成器网络根据指定的类生成图像。

模型损失函数

函数 modelLoss 接受生成器和判别器 dlnetwork 对象 netGnetD、小批量输入数据 X、对应的标签 T 以及随机值数组 Z 作为输入,并返回损失关于网络中可学习参数的梯度、生成器状态和网络分数。

如果判别器过快地学会了如何判别真实图像和生成的图像,则生成器可能无法进行训练。为了更好地平衡判别器和生成器的学习,随机翻转一部分真实图像的标签。

function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    modelLoss(netG,netD,X,T,Z,flipFactor)

% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X,T);

% Calculate the predictions for generated data with the discriminator network.
[XGenerated,stateG] = forward(netG,Z,T);
YGenerated = forward(netD,XGenerated,T);

% Calculate probabilities.
probGenerated = sigmoid(YGenerated);
probReal = sigmoid(YReal);

% Calculate the generator and discriminator scores.
scoreG = mean(probGenerated);
scoreD = (mean(probReal) + mean(1-probGenerated)) / 2;

% Flip labels.
numObservations = size(YReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossG, lossD] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);

end

GAN 损失函数

生成器的目标是生成判别器分类为 "real" 的数据。为了最大化判别器将生成器生成的图像判别为真实图像的概率,最小化负对数似然函数。

给定判别器的输出 Y

  • Yˆ=σ(Y) 是输入图像属于 "real" 类的概率。

  • 1-Yˆ 是输入图像属于 "generated" 类的概率。

请注意,sigmoid 运算 σ 发生在 modelLoss 函数中。生成器的损失函数由下式给出

lossGenerator=-mean(log(YˆGenerated)),

其中 YˆGenerated 包含生成图像的判别器输出概率。

判别器的目标是不被生成器“欺骗”。为了最大化判别器成功判别真实图像和生成图像的概率,最小化对应的负对数似然函数之和。判别器的损失函数由下式给出

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

其中 YˆReal 包含真实图像的判别器输出概率。

function [lossG, lossD] = ganLoss(scoresReal,scoresGenerated)

% Calculate losses for the discriminator network.
lossGenerated = -mean(log(1 - scoresGenerated));
lossReal = -mean(log(scoresReal));

% Combine the losses for the discriminator network.
lossD = lossReal + lossGenerated;

% Calculate the loss for the generator network.
lossG = -mean(log(scoresGenerated));

end

小批量预处理函数

preprocessMiniBatch 函数使用以下步骤预处理数据:

  1. 从输入元胞数组中提取图像和标签数据,并将它们串联成数值数组。

  2. 将图像重新缩放到 [-1,1] 范围内。

function [X,T] = preprocessData(XCell,TCell)

% Extract image data from cell and concatenate
X = cat(4,XCell{:});

% Extract label data from cell and concatenate
T = cat(1,TCell{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);

end

参考资料

另请参阅

| | | | | |

相关主题