Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

训练生成对抗网络 (GAN)

此示例说明如何训练生成对抗网络 (GAN) 来生成图像。

生成对抗网络 (GAN) 是一种深度学习网络,它能够生成与真实输入数据具有相似特征的数据。

一个 GAN 由两个一起训练的网络组成:

  1. 生成器 - 给定随机值(潜在输入)向量作为输入,此网络可生成与训练数据具有相同结构的数据。

  2. 判别器 - 给定包含来自训练数据和来自生成器的生成数据的观测值的数据批量,此网络尝试将观测值划分为“真实值”或“生成值”。

要训练 GAN,需要同时训练两个网络以最大化两个网络的性能:

  • 训练生成器以生成“欺骗”判别器的数据。

  • 训练判别器以区分真实数据和生成的数据。

为了优化生成器的性能,当给定生成的数据时,最大化判别器的损失。也就是说,生成器的目标是生成判别器判别为“真实”的数据。

为了优化判别器的性能,当给定真实数据和生成的数据批量时,最小化判别器的损失。即判别器的目标是不被生成器“欺骗”。

理想情况下,这些策略会得到能够生成令人信服的真实数据的生成器,以及已学习到训练数据特有的强特征表示的判别器。

加载训练数据

下载并提取 Flowers 数据集 [1]。

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

创建一个包含花卉照片的图像数据库。

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

增强数据以包括随机水平翻转,并将图像大小调整为 64×64。

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);    

定义生成器网络

定义以下网络架构,该架构从 1×1×100 大小的随机值数组生成图像:

此网络:

  • 使用投影和重构层将 1×1×100 噪声数组转换为 7×7×128 数组。

  • 使用一系列带批量归一化和 ReLU 层的转置卷积层,将生成的数组扩增到 64×64×3 数组。

将此网络架构定义为一个层次图,并指定以下网络属性。

  • 对于转置卷积层,指定 5×5 滤波器,每一层的滤波器数量递减,步幅为 2,并在每条边裁剪输出。

  • 对于最终的转置卷积层,指定与生成图像的三个 RGB 通道对应的三个 5×5 滤波器,以及前一层的输出大小。

  • 在网络末尾,包括一个 tanh 层。

要投影和重构噪声输入,请使用自定义层 projectAndReshapeLayer,它以支持文件的形式附加到此示例中。projectAndReshapeLayer 层使用一个全连接操作来扩增输入,并将输出重构为指定的大小。

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bnorm1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bnorm2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bnorm3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

要使用自定义训练循环训练网络并支持自动微分,请将层次图转换为 dlnetwork 对象。

dlnetGenerator = dlnetwork(lgraphGenerator);

定义判别器网络

定义以下网络,它对真实图像和生成的 64×64 图像进行分类。

创建一个网络,该网络接受 64×64×3 图像,并使用一系列具有批量归一化和泄漏 ReLU 层的卷积层返回一个标量预测分数。使用丢弃法给输入图像添加噪声。

  • 对于丢弃层,指定丢弃概率为 0.5。

  • 对于卷积层,指定 5×5 滤波器,每一层的滤波器数量递增。同时指定步幅为 2 以及对输出进行填充。

  • 对于泄漏 ReLU 层,指定 0.2 的尺度。

  • 对于最终层,指定具有一个 4×4 滤波器的卷积层。

要输出 [0,1] 范围内的概率,请使用模型梯度函数中的 sigmoid 函数。

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    dropoutLayer(0.5,'Name','dropout')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

要使用自定义训练循环训练网络并支持自动微分,请将层次图转换为 dlnetwork 对象。

dlnetDiscriminator = dlnetwork(lgraphDiscriminator);

定义模型梯度、损失函数和分数

创建在示例的模型梯度函数部分列出的函数 modelGradients,该函数接受生成器和判别器网络、小批量输入数据、随机值数组和翻转因子作为输入,并返回损失关于网络中可学习参数的梯度和两个网络的分数。

指定训练选项

使用小批量大小 128 进行 500 轮训练。对于较大的数据集,您可能不需要进行这么多轮训练。

numEpochs = 500;
miniBatchSize = 128;

指定 Adam 优化的选项。对于两个网络,都指定

  • 学习率为 0.0002

  • 梯度衰减因子为 0.5

  • 梯度平方衰减因子为 0.999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

如果判别器过快地学会了如何判别真实图像和生成的图像,则生成器可能无法训练。为了更好地平衡判别器和生成器的学习,请通过随机翻转标签向真实数据添加噪声。

指定翻转 30% 的真实标签。这意味着总标签数的 15% 在训练期间翻转。请注意,这不会减损生成器,因为所有生成的图像仍被正确标注。

flipFactor = 0.3;

每经过 100 次迭代就显示生成的验证图像。

validationFrequency = 100;

训练模型

使用 minibatchqueue 处理和管理小批量图像。对于每个小批量:

  • 使用自定义小批量预处理函数 preprocessMiniBatch(在此示例末尾定义)在 [-1,1] 范围内重新缩放图像。

  • 丢弃观测值少于 128 个的任何不完整小批量。

  • 用维度标签 'SSCB'(空间、空间、通道、批量)格式化图像数据。默认情况下,minibatchqueue 对象将数据转换为基础类型为 singledlarray 对象。

  • 在 GPU 上(如果有)进行训练。当 minibatchqueue'OutputEnvironment' 选项为 "auto" 时,minibatchqueue 将每个输出转换为 gpuArray(如果 GPU 可用)。使用 GPU 需要 Parallel Computing Toolbox™ 和具有 3.0 或更高计算能力的支持 CUDA® 的 NVIDIA® GPU。

augimds.MiniBatchSize = miniBatchSize;

executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat','SSCB',...
    'OutputEnvironment',executionEnvironment);

使用自定义训练循环训练模型。在每次迭代中循环使用训练数据并更新网络参数。为了监控训练进度,以保留的随机值数组作为生成器输入,显示一批生成图像,同时显示其相关的分数图。

为 Adam 初始化参数。

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

为了监控训练进度,使用一批保留的随机值固定数组作为输入馈送到生成器,显示一批生成图像并绘制其网络分数图。

创建一个由保留的随机值组成的数组。

numValidationImages = 25;
ZValidation = randn(1,1,numLatentInputs,numValidationImages,'single');

将数据转换为 dlarray 对象,并指定维度标签 'SSCB'(空间、空间、通道、批量)。

dlZValidation = dlarray(ZValidation,'SSCB');

对于 GPU 训练,将数据转换为 gpuArray 对象。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

初始化训练进度图。创建一个图窗,并将其调整为两倍宽度。

f = figure;
f.Position(3) = 2*f.Position(3);

为生成的图像和网络分数创建子图。

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

为分数图初始化动画线条。

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);
legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

训练 GAN。对于每轮训练,对数据存储进行乱序处理,并循环使用小批量数据。

对于每个小批量:

  • 使用 dlfevalmodelGradients 函数计算模型梯度。

  • 使用 adamupdate 函数更新网络参数。

  • 绘制两个网络的分数。

  • 在每 validationFrequency 次迭代后,显示一批基于固定保留生成器输入的生成图像。

训练可能需要一些时间来运行。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle datastore.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        dlX = next(mbq);
        
        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'SSCB' (spatial,
        % spatial, channel, batch). If training on a GPU, then convert
        % latent inputs to gpuArray.
        Z = randn(1,1,numLatentInputs,size(dlX,4),'single');
        dlZ = dlarray(Z,'SSCB');        
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation));
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

此时,判别器已学会在生成的图像中识别真实图像的强特征表示。反过来,生成器已学会类似的强特征表示,能够生成看似真实的数据。

训练图显示生成器和判别器网络的分数。要了解有关如何解释网络分数的详细信息,请参阅Monitor GAN Training Progress and Identify Common Failure Modes

生成新图像

要生成新图像,请使用生成器上的 predict 函数和包含由随机值组成的一批 1×1×100 数组的 dlarray 对象。要一起显示图像,请使用 imtile 函数,并使用 rescale 函数重新缩放图像。

创建一个 dlarray 对象(其中包含随机值组成的 25 个 1×1×100 数组)以输入到生成器网络中。

ZNew = randn(1,1,numLatentInputs,25,'single');
dlZNew = dlarray(ZNew,'SSCB');

要使用 GPU 生成图像,还要将数据转换为 gpuArray 对象。

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

使用 predict 函数以及生成器和输入数据生成新图像。

dlXGeneratedNew = predict(dlnetGenerator,dlZNew);

显示图像。

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

模型梯度函数

函数 modelGradients 接受生成器和判别器 dlnetwork 对象 dlnetGeneratordlnetDiscriminator、小批量输入数据 dlX、随机值数组 dlZ 和要翻转的真实标签的百分比 flipFactor 作为输入,并返回损失关于网络中可学习参数的梯度、生成器状态和两个网络的分数。由于判别器输出不在 [0,1] 范围中,modelGradients 应用 sigmoid 函数将其转换为概率。

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

% Convert the discriminator outputs to probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the score of the discriminator.
scoreDiscriminator = ((mean(probReal)+mean(1-probGenerated))/2);

% Calculate the score of the generator.
scoreGenerator = mean(probGenerated);

% Randomly flip a fraction of the labels of the real images.
numObservations = size(probReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));

% Flip the labels
probReal(:,:,:,idx) = 1-probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

GAN 损失函数和分数

生成器的目标是生成判别器判别为“真实”的数据。为了最大化判别器将生成器生成的图像判别为真实图像的概率,最小化负对数似然函数。

给定判别器的输出 Y

  • Yˆ=σ(Y) 是输入图像属于“真实”类的概率。

  • 1-Yˆ 是输入图像属于“生成”类的概率。

请注意,sigmoid 运算 σ 发生在 modelGradients 函数中。生成器的损失函数由下式给出

lossGenerator=-mean(log(YˆGenerated)),

其中 YˆGenerated 包含生成图像的判别器输出概率。

判别器的目标是不被生成器“欺骗”。为了最大化判别器成功判别真实图像和生成图像的概率,最小化对应的负对数似然函数之和。

判别器的损失函数由下式给出

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

其中 YˆReal 包含真实图像的判别器输出概率。

要按照从 0 到 1 的范围来度量生成器和判别器实现各自目标的程度,可以使用分数的概念。

生成器分数是生成图像判别器输出的概率平均值:

scoreGenerator=mean(YˆGenerated).

判别器分数是真实图像和生成图像判别器输出的概率平均值:

scoreDiscriminator=12mean(YˆReal)+12mean(1-YˆGenerated).

分数与损失成反比,但实际上包含相同的信息。

function [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated)

% Calculate the loss for the discriminator network.
lossDiscriminator =  -mean(log(probReal)) -mean(log(1-probGenerated));

% Calculate the loss for the generator network.
lossGenerator = -mean(log(probGenerated));

end

小批量预处理函数

preprocessMiniBatch 函数使用以下步骤预处理数据:

  1. 从传入的元胞数组中提取图像数据,并串联成一个数值数组。

  2. 将图像重新缩放到 [-1,1] 范围内。

function X = preprocessMiniBatch(data)
    % Concatenate mini-batch
    X = cat(4,data{:});
    
    % Rescale the images in the range [-1 1].
    X = rescale(X,-1,1,'InputMin',0,'InputMax',255);
end

参考资料

  1. The TensorFlow Team.Flowers http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Radford, Alec, Luke Metz, and Soumith Chintala."Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).

另请参阅

| | | | | | |

相关主题