Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

训练变分自编码器 (VAE) 以生成图像

此示例说明如何在 MATLAB 中创建一个变分自编码器 (VAE) 以生成数字图像。VAE 生成 MNIST 数据集样式的手写数字。

VAE 不同于常规的自编码器,因为它们不使用编码解码过程来重新构造输入。它们在潜在空间上施加一个概率分布,并且学习该分布,使得来自解码器的输出的分布与观测到的数据的分布相匹配。然后,它们从这个分布中采样以生成新数据。

在此示例中,您构造一个 VAE 网络,基于 MNIST 数据集对其进行训练,并生成与该数据集非常相似的新图像。

加载数据

http://yann.lecun.com/exdb/mnist/ 下载 MNIST 文件,并将 MNIST 数据集加载到工作区中 [1]。调用此示例附带的 processImagesMNISTprocessLabelsMNIST 辅助函数,将文件中的数据加载到 MATLAB 数组中。

由于 VAE 是将重新构造的数字与输入进行比较,而不是与分类标签进行比较,因此您不需要使用 MNIST 数据集中的训练标签。

trainImagesFile = 'train-images-idx3-ubyte.gz';
testImagesFile = 't10k-images-idx3-ubyte.gz';
testLabelsFile = 't10k-labels-idx1-ubyte.gz';

XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data...
Number of images in the dataset:  60000 ...
numTrainImages = size(XTrain,4);
XTest = processImagesMNIST(testImagesFile);
Read MNIST image data...
Number of images in the dataset:  10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data...
Number of labels in the dataset:  10000 ...

构造网络

自编码器有两个部分:编码器和解码器。编码器接受一个图像输入并输出一个压缩表示(编码),该压缩表示是大小为 latentDim 的向量,在此示例中等于 20。解码器接受该压缩表示,对其进行解码,并重新创建原始图像。

为了使计算在数值上更加稳定,通过让网络从方差的对数中学习,将可能值的范围从 [0,1] 增加到 [-inf, 0]。定义大小为 latent_dim 的两个向量:一个用于均值 μ,另一个用于方差的对数 log(σ2)。然后使用这两个向量创建采样分布。

使用二维卷积,后跟一个全连接层,将图像从 28×28×1 MNIST 图像下采样至潜在空间中的编码。然后,使用转置的二维卷积将 1×1×20 编码放大回 28×28×1 图像。

latentDim = 20;
imageSize = [28 28 1];

encoderLG = layerGraph([
    imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')
    convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder')
    ]);

decoderLG = layerGraph([
    imageInputLayer([1 1 latentDim],'Name','i','Normalization','none')
    transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
    ]);

要使用自定义训练循环训练两个网络并支持自动微分,请将层次图转换为 dlnetwork 对象。

encoderNet = dlnetwork(encoderLG);
decoderNet = dlnetwork(decoderLG);

定义模型梯度函数

辅助函数 modelGradients 接受编码器和解码器 dlnetwork 对象以及输入数据 X 的一个小批量,并返回损失关于网络中可学习参数的梯度。此辅助函数在此示例的末尾定义。

该函数分两步执行此过程:采样和损失。采样步骤对均值向量和方差向量进行采样,以创建要传递给解码器网络的最终编码。但是,由于无法通过随机采样操作进行反向传播,您必须使用重参数化方法。此方法将随机采样运算移至辅助变量 ε,然后通过按均值 μi 进行移位,并按标准差 σi 进行缩放。其想法是从 N(μi,σi2) 采样与从 μi+εσi 采样相同,其中 εN(0,1)。下图形象地说明了此想法。

损失步骤将在采样步骤生成的编码传入解码器网络进行处理,确定损失,然后使用损失计算梯度。VAE 中的损失,也称为证据下界 (ELBO) 损失,定义为两个独立损失项之和:

ELBOloss=reconstructionloss+KLloss

重建损失通过使用均方误差 (MSE) 来测量解码器输出与原始输入的接近程度:

reconstructionloss=MSE(decoderoutput,originalimage)

KL 损失,即 Kullback–Leibler 散度,测量两个概率分布之间的差异。在本例中,最小化 KL 损失意味着确保学习的均值和方差尽可能接近目标(正态)分布的均值和方差。对于大小为 n 的潜在维度,KL 损失的计算公式如下

KLloss=-0.5i=1n(1+log(σi2)-μi2-σi2)

包含 KL 损失项的实际效果是将由于重建损失而学习到的聚类紧密地聚集在潜在空间中心周围,形成连续的采样空间。

指定训练选项

在 GPU 上(如果有)训练(需要 Parallel Computing Toolbox™)。

executionEnvironment = "auto";

指定网络的训练选项。在使用 Adam 优化器时,您需要用空数组初始化每个网络的尾部平均梯度和尾部平均梯度平方衰减率。

numEpochs = 50;
miniBatchSize = 512;
lr = 1e-3;
numIterations = floor(numTrainImages/miniBatchSize);
iteration = 0;

avgGradientsEncoder = [];
avgGradientsSquaredEncoder = [];
avgGradientsDecoder = [];
avgGradientsSquaredDecoder = [];

训练模型

使用自定义训练循环训练模型。

对于一轮训练中的每次迭代:

  • 从训练集中获取下一个小批量。

  • 将小批量转换为 dlarray 对象,确保指定维度标签 'SSCB'(空间、空间、通道、批量)。

  • 对于 GPU 训练,将 dlarray 转换为 gpuArray 对象。

  • 使用 dlfeval modelGradients 函数计算模型梯度。

  • 使用 adamupdate 函数更新两个网络的网络可学习参数和平均梯度。

在每轮训练结束时,传入测试集图像,使之通过自编码器,并显示该轮训练的损失和训练时间。

for epoch = 1:numEpochs
    tic;
    for i = 1:numIterations
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        XBatch = XTrain(:,:,:,idx);
        XBatch = dlarray(single(XBatch), 'SSCB');
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XBatch = gpuArray(XBatch);           
        end 
            
        [infGrad, genGrad] = dlfeval(...
            @modelGradients, encoderNet, decoderNet, XBatch);
        
        [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
            adamupdate(decoderNet.Learnables, ...
                genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
        [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
            adamupdate(encoderNet.Learnables, ...
                infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
    end
    elapsedTime = toc;
    
    [z, zMean, zLogvar] = sampling(encoderNet, XTest);
    xPred = sigmoid(forward(decoderNet, z));
    elbo = ELBOloss(XTest, xPred, zMean, zLogvar);
    disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+...
        ". Time taken for epoch = "+ elapsedTime + "s")    
end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s
Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s
Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s
Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s
Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s
Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s
Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s
Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s
Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s
Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s
Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s
Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s
Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s
Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s
Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s
Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s
Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s
Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s
Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s
Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s
Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s
Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s
Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s
Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s
Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s
Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s
Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s
Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s
Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s
Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s
Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s
Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s
Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s
Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s
Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s
Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s
Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s
Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s
Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s
Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s
Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s
Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s
Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s
Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s
Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s
Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s
Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s
Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s
Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s
Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s

可视化结果

要可视化和解释结果,请使用辅助可视化函数。这些辅助函数在此示例的末尾定义。

VisualizeReconstruction 函数显示从每个类中随机选择的一个数字,并伴随显示数字通过自编码器后的重建版本。

VisualizeLatentSpace 函数接受测试图像通过编码器网络后生成的均值和方差编码(每一个的维度都为 20),并对包含每个图像编码的矩阵执行主成分分析 (PCA)。然后,您可以在由两个第一主成分表征的两个维度中可视化由均值和方差定义的潜在空间。

Generate 函数初始化从正态分布采样的新编码,并输出这些编码通过解码器网络时生成的图像。

visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

generate(decoderNet, latentDim)

后续步骤

变分自编码器只是用于执行生成式任务的众多可用模型之一。它们适用于图像较小且具有清晰定义的特征的数据集(如 MNIST)。对于包含较大图像的更复杂的数据集,生成对抗网络 (GAN) 往往表现更好,生成的图像噪声更低。有关如何实现 GAN 以生成 64×64 RGB 图像的示例,请参阅训练生成对抗网络 (GAN)

参考资料

  1. LeCun, Y., C. Cortes, and C. J. C. Burges."The MNIST Database of Handwritten Digits." http://yann.lecun.com/exdb/mnist/.

辅助函数

模型梯度函数

modelGradients 函数接受编码器和解码器 dlnetwork 对象以及输入数据 X 的一个小批量,并返回损失关于网络中可学习参数的梯度。该函数执行三项操作:

  1. 对传入编码器网络进行处理的小批量图像调用 sampling 函数来获取编码。

  2. 传入编码,使之通过解码器网络并调用 ELBOloss 函数来获得损失。

  3. 通过调用 dlgradient 函数,计算损失关于两个网络的可学习参数的梯度。

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

采样和损失函数

sampling 函数从输入图像中获取编码。最初,它将一个小批量图像传入编码器网络进行处理,并将大小为 (2*latentDim)*miniBatchSize 的输出拆分为均值矩阵和方差矩阵,每个矩阵的大小为 latentDim*batchSize。然后,它使用这些矩阵来实现重参数化方法并计算编码。最后,它将此编码转换为 SSCB 格式的 dlarray 对象。

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

ELBOloss 函数接受 sampling 函数返回的均值和方差的编码,并使用它们来计算 ELBO 损失。

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

可视化函数

VisualizeReconstruction 函数为 MNIST 数据集的每个数字随机选择两个图像,将它们传入 VAE 进行处理,并排绘制原始输入的图和重建的图。请注意,要绘制 dlarray 对象中包含的信息,您需要首先使用 extractdatagather 函数提取它。

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

VisualizeLatentSpace 函数可视化由构成编码器网络输出的均值和方差矩阵定义的潜在空间,并定位由每个数字的潜在空间表示形成的簇。

该函数首先从 dlarray 对象中提取均值矩阵和方差矩阵。由于无法转置具有通道/批量维度(C 和 B)的矩阵,函数会在转置矩阵之前调用 stripdims。然后,它对两个矩阵进行主成分分析 (PCA)。为了在两个维度上可视化潜在空间,该函数保留前两个主成分,并绘制彼此对照的图。最后,该函数对数字类进行着色,以便您可以观察簇。

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

generate 函数测试 VAE 的生成能力。它初始化一个包含 25 个随机生成的编码的 dlarray 对象,将这些编码传入解码器网络进行处理,并绘制输出。

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end

另请参阅

| | | | | |

相关主题