
Train Fast Style Transfer Network

This example shows how to train a network to transfer the style of an image to a second image. It is based on the architecture defined in [1].

This example is similar to Neural Style Transfer Using Deep Learning, but it works faster once you have trained the network on a style image S. This is because, to obtain the stylized image Y, you only need to do a forward pass of the input image X through the network.

The high-level diagram below shows the training algorithm. The algorithm uses three images to calculate the loss: the input image X, the transformed image Y, and the style image S.

Note that the loss function uses the pretrained network VGG-16 to extract features from the images. You can find its implementation and mathematical definition in the Style Transfer Loss section of this example.

Load Training Data

Download the COCO 2014 training images from https://cocodataset.org/#download by clicking the "2014 Train images" link, and extract the images into the folder specified by imageFolder. The COCO 2014 data set was collected by the COCO Consortium.

Create directories to store the COCO data set.

imageFolder = fullfile(tempdir,"coco");
if ~exist(imageFolder,"dir")
    mkdir(imageFolder);
end

Create an image datastore containing the COCO images.

imds = imageDatastore(imageFolder,IncludeSubfolders=true);

Training can take a long time to run. To decrease the training time at the cost of the accuracy of the resulting network, select a subset of the image datastore by setting fraction to a value smaller than 1.

fraction = 1;
numObservations = numel(imds.Files);
imds = subset(imds,1:floor(numObservations*fraction));

To resize the images and convert them all to RGB, create an augmented image datastore.

augimds = augmentedImageDatastore([256 256],imds,ColorPreprocessing="gray2rgb");
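The COCO data set contains some grayscale images, which have only one channel. As a minimal sketch of what ColorPreprocessing="gray2rgb" is assumed to do, the following base MATLAB code replicates the single channel of a small synthetic grayscale array (grayImage and rgbImage are hypothetical names) into three identical color channels, so that every observation has the three-channel shape that the network expects.

```matlab
% Hypothetical 4-by-4 grayscale image with a single channel.
grayImage = uint8(magic(4));

% Replicate the channel three times to obtain a 4-by-4-by-3 "RGB" image,
% which is the effect ColorPreprocessing="gray2rgb" is assumed to have.
rgbImage = repmat(grayImage,[1 1 3]);

% Every color channel is an exact copy of the gray channel.
disp(size(rgbImage))
```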

Read the style image.

styleImage = imread("starryNight.jpg");
styleImage = imresize(styleImage,[256 256]);

Display the chosen style image.

figure
imshow(styleImage)
title("Style Image")

Define Image Transformer Network

Create a dlnetwork object.

netTransform = dlnetwork;

Specify the layers of the image transformer network and add them to the network. This is an image-to-image network. The network consists of three parts:

  1. The first part of the network takes as input an RGB image of size [256x256x3] and downsamples it to a feature map of size [64x64x128].

  2. The second part of the network consists of five identical residual blocks defined in the supporting function residualBlock.

  3. The third and final part of the network upsamples the feature map to the original size of the image and returns the transformed image. This last part uses the upsampleLayer, which is a custom layer attached to this example as a supporting file.
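The sizes in the list above follow from the layer strides. The following sketch traces the height and width of the feature maps through the network. It assumes that a convolution with Padding="same" produces ceil(inputSize/stride) outputs along each spatial dimension, and that each upsampleLayer doubles the height and width. The custom layer is not listed in this example, so the factor of 2 is an assumption inferred from the change from 64 to 256 over two upsampling layers.

```matlab
% Spatial size of the input image (256-by-256-by-3).
inputSize = 256;

% For a convolution with Padding="same", output size = ceil(input/stride).
sameConv = @(sz,stride) ceil(sz/stride);

% Part 1: a stride-1 9x9 convolution, then two stride-2 3x3 convolutions.
sz = sameConv(inputSize,1);   % 256
sz = sameConv(sz,2);          % 128
sz = sameConv(sz,2);          % 64
downsampledSize = sz;

% Part 2: the residual blocks use stride-1 "same" convolutions only,
% so the spatial size does not change.
for k = 1:5
    sz = sameConv(sz,1);
end

% Part 3: assume each upsampleLayer doubles the height and width; the
% remaining convolutions have stride 1.
sz = 2*sz;                    % 128
sz = 2*sz;                    % 256
outputSize = sameConv(sz,1);

fprintf("Downsampled: %dx%dx128, output: %dx%dx3\n", ...
    downsampledSize,downsampledSize,outputSize,outputSize)
```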

layers = [
    
    % First part.
    imageInputLayer([256 256 3],Normalization="none")
    
    convolution2dLayer([9 9],32,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    convolution2dLayer([3 3],64,Stride=2,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    convolution2dLayer([3 3],128,Stride=2,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer(Name="relu_3")
    
    % Second part. 
    residualBlock("1")
    residualBlock("2")
    residualBlock("3")
    residualBlock("4")
    residualBlock("5")
    
    % Third part.
    upsampleLayer
    convolution2dLayer([3 3],64,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    upsampleLayer
    convolution2dLayer([3 3],32,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    convolution2dLayer(9,3,Padding="same")];

netTransform = addLayers(netTransform,layers);

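Add the missing skip connections to the residual blocks. The second input of each addition layer receives the input of its block: the output of relu_3 for the first block, and the output of the previous addition layer for the other blocks.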

netTransform = connectLayers(netTransform,"relu_3","add_1/in2");
netTransform = connectLayers(netTransform,"add_1","add_2/in2");
netTransform = connectLayers(netTransform,"add_2","add_3/in2");
netTransform = connectLayers(netTransform,"add_3","add_4/in2");
netTransform = connectLayers(netTransform,"add_4","add_5/in2");
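The five connectLayers calls follow a fixed pattern. The following sketch, which uses the hypothetical variable names sources and destinations, builds the same source and destination pairs in a loop and prints them. The commented connectLayers line shows how you might use the pairs when netTransform is in the workspace, which can be convenient if you change the number of residual blocks.

```matlab
numBlocks = 5;

% Skip-connection sources: relu_3 feeds block 1, and addition layer k-1
% feeds block k.
sources = ["relu_3", "add_" + (1:numBlocks-1)];

% Destinations: the second input of each addition layer.
destinations = "add_" + (1:numBlocks) + "/in2";

for k = 1:numBlocks
    fprintf("%s -> %s\n",sources(k),destinations(k));
    % With netTransform in the workspace, you could instead call:
    % netTransform = connectLayers(netTransform,sources(k),destinations(k));
end
```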

Initialize the network.

netTransform = initialize(netTransform);

Visualize the image transformer network in a plot.

figure
plot(netTransform)
title("Transform Network")

Style Loss Network

This example uses a pretrained VGG-16 deep neural network to extract the features of the content and style images at different layers. These multilayer features are used to compute respective content and style losses.

To get a pretrained VGG-16 network, use the imagePretrainedNetwork function. If you do not have the required support packages installed, then the software provides a download link.

netLoss = imagePretrainedNetwork("vgg16");

Define Model Loss Function

Create the function modelLoss, listed in the Model Loss Function section of the example. This function takes as input the loss network, the image transformer network, a mini-batch of input images, a cell array containing the Gram matrices of the style image, the weight associated with the content loss, and the weight associated with the style loss. The function returns the total loss, the content loss, the style loss, the gradients of the total loss with respect to the learnable parameters of the image transformer, the state of the image transformer network, and the transformed images.

Specify Training Options

Train with a mini-batch size of 16 for 2 epochs.

numEpochs = 2;
miniBatchSize = 16;

Set the read size of the augmented image datastore to the mini-batch size.

augimds.MiniBatchSize = miniBatchSize;

Specify the options for Adam optimization. Specify a learn rate of 0.001, a gradient decay factor of 0.9, and a squared gradient decay factor of 0.999.

learnRate = 0.001;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
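To see how these three values interact, the following sketch applies one step of the textbook Adam update rule to a single hypothetical scalar parameter theta with gradient g. The epsilon value of 1e-8 is an assumption matching the usual adamupdate default, and adamupdate can differ in implementation details. The sketch shows that, on the first iteration, the bias-corrected step has a magnitude close to the learn rate regardless of the scale of the gradient.

```matlab
% Hyperparameters from this example (local names, so that the example's
% own variables are not overwritten).
lr = 0.001;       % learnRate
beta1 = 0.9;      % gradientDecayFactor
beta2 = 0.999;    % squaredGradientDecayFactor
epsilon = 1e-8;   % assumed small constant

% Hypothetical scalar parameter and its gradient at iteration t = 1.
theta = 1;
g = 0.5;
t = 1;

% The moving averages start at zero (averageGrad = [], averageSqGrad = []).
m = 0;
v = 0;

% Textbook Adam update.
m = beta1*m + (1-beta1)*g;            % average gradient
v = beta2*v + (1-beta2)*g^2;          % average squared gradient
mHat = m/(1-beta1^t);                 % bias correction
vHat = v/(1-beta2^t);
theta = theta - lr*mHat/(sqrt(vHat)+epsilon);

fprintf("theta after one step: %.6f\n",theta)
```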

Specify the weights given to the content loss and to the style loss in the calculation of the total loss.

To find a good balance between content and style loss, you might need to experiment with different combinations of weights.

weightContent = 1e-4;
weightStyle = 3e-8; 

Specify how often to plot the training progress, as the number of iterations between plot updates.

plotFrequency = 100;

Train Model

To compute the style loss during training, you need the Gram matrices of the style image. Calculate them once, before training starts.

Convert the style image to dlarray.

S = dlarray(single(styleImage),"SSC");

To calculate the Gram matrices, pass the style image through the VGG-16 network and extract the activations at four different layers.

[SActivations1,SActivations2,SActivations3,SActivations4] = forward(netLoss,S, ...
    Outputs=["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

Calculate the Gram matrix for each set of activations using the supporting function createGramMatrix.

SGram{1} = createGramMatrix(SActivations1);
SGram{2} = createGramMatrix(SActivations2);
SGram{3} = createGramMatrix(SActivations3);
SGram{4} = createGramMatrix(SActivations4);

The training plots consist of two figures:

  1. A figure showing a plot of the losses during training

  2. A figure containing an input and an output image of the image transformer network

Initialize the training plots. For details of the initialization, see the supporting function initializeStyleTransferPlots. This function returns the axes ax1, where you plot the losses; the axes ax2, where you plot the validation images; and the animated lines lineLossContent, lineLossStyle, and lineLossTotal, which contain the content loss, the style loss, and the total loss, respectively.

[ax1,ax2,lineLossContent,lineLossStyle,lineLossTotal] = initializeStyleTransferPlots;

Initialize the average gradient and the average squared gradient, which the Adam optimizer updates at each iteration.

averageGrad = [];
averageSqGrad = [];

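Calculate the total number of training iterations. The training loop increments the iteration counter for every mini-batch that it reads, including the final partial mini-batch of each epoch, which it then skips. Round the number of mini-batches per epoch up so that the completion percentage reaches 100% at the end of training.

numIterations = numEpochs*ceil(augimds.NumObservations/miniBatchSize);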
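As a worked example, assuming the full COCO 2014 training split of 82,783 images (fraction = 1), the following arithmetic shows how many mini-batches the loop reads per epoch, how many full mini-batches it actually trains on, and how many images the skipped partial mini-batch contains. The variable names are hypothetical and are chosen so that they do not overwrite the example's own variables.

```matlab
% Assumed data set size: the full COCO 2014 training split.
numImages = 82783;
batchSize = 16;
epochs = 2;

% Mini-batches read per epoch, including the final partial one.
readsPerEpoch = ceil(numImages/batchSize);

% Full mini-batches per epoch; the loop skips the partial one.
fullBatchesPerEpoch = floor(numImages/batchSize);

% Number of images in the skipped partial mini-batch.
leftover = numImages - fullBatchesPerEpoch*batchSize;

% The loop increments iteration once per read, so this is its final value.
finalIteration = epochs*readsPerEpoch;

% Number of parameter updates that adamupdate actually applies.
numUpdates = epochs*fullBatchesPerEpoch;

fprintf("reads/epoch: %d, updates/epoch: %d, skipped images/epoch: %d\n", ...
    readsPerEpoch,fullBatchesPerEpoch,leftover)
fprintf("final iteration: %d, total updates: %d\n",finalIteration,numUpdates)
```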

Initialize the iteration counter and the timer before training.

iteration = 0;
start = tic;

Train the model. Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). This could take a long time to run.

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Extract the images from data store into a cell array.
        images = data{:,1};
        
        % Concatenate the images along the 4th dimension.
        X = cat(4,images{:});
        X = single(X);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % "SSCB" (spatial, spatial, channel, batch).
        X = dlarray(X,"SSCB");
        
        % If training on a GPU, then convert data to gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end
        
        % Evaluate the model loss, gradients, and the network state using
        % dlfeval and the modelLoss function listed at the end of the
        % example.
        [loss,lossContent,lossStyle,gradients,state,Y] = dlfeval(@modelLoss, ...
            netLoss,netTransform,X,SGram,weightContent,weightStyle);
        
        netTransform.State = state;
        
        % Update the network parameters.
        [netTransform,averageGrad,averageSqGrad] = ...
            adamupdate(netTransform,gradients,averageGrad,averageSqGrad,iteration,...
            learnRate, gradientDecayFactor,squaredGradientDecayFactor);
              
        % Every plotFrequency iterations, plot the training progress.
        if mod(iteration,plotFrequency) == 0
            addpoints(lineLossTotal,iteration,double(loss))
            addpoints(lineLossContent,iteration,double(lossContent))
            addpoints(lineLossStyle,iteration,double(lossStyle))
            
            % Use the first image of the mini-batch as a validation image.
            XV = X(:,:,:,1);
            % Use the transformed validation image computed previously.
            YV = Y(:,:,:,1);
            
            % To use the function imshow, convert to uint8.
            validationImage = uint8(gather(extractdata(XV)));
            transformedValidationImage = uint8(gather(extractdata(YV)));
            
            % Plot the input image and the output image side by side.
            imshow(imtile({validationImage,transformedValidationImage}),Parent=ax2);
        end
        
        % Display time elapsed since start of training and training completion percentage.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        completionPercentage = round(iteration/numIterations*100,2);
        title(ax1,"Epoch: " + i + ", Iteration: " + iteration + " of " + numIterations + " (" + completionPercentage + "%), Elapsed: " + string(D))
        drawnow
        
    end

end

Stylize an Image

Once training has finished, you can use the image transformer on any image of your choice.

Load the image you would like to transform.

imFilename = "peppers.png";
im = imread(imFilename);

Resize the input image to the input dimensions of the image transformer.

im = imresize(im,[256,256]);

Convert it to a dlarray object with the dimension labels "SSCB" (spatial, spatial, channel, batch).

X = dlarray(single(im),"SSCB");

If a GPU is available, convert the data to a gpuArray object.

if canUseGPU
    X = gpuArray(X);
end

To apply the style to the image, pass it through the image transformer using the predict function.

Y = predict(netTransform,X);

Rescale the image into the range [0 255]. First, use the function tanh to rescale Y to the range [-1 1]. Then, shift and scale the output to rescale into the [0 255] range.

Y = 255*(tanh(Y)+1)/2;
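Because tanh maps any real value into the open interval (-1, 1), this expression maps every output into (0, 255), so the subsequent uint8 conversion only rounds and never needs to saturate. The following sketch applies the same expression to a few hypothetical raw network outputs (rawOutput, scaled, and pixels are illustrative names).

```matlab
% Hypothetical raw outputs spanning a wide range.
rawOutput = [-10 0 10];

% Same rescaling as in the example: tanh maps to (-1,1), then shift and scale.
scaled = 255*(tanh(rawOutput)+1)/2;

% uint8 rounds to the nearest integer (halves round away from zero).
pixels = uint8(scaled);
disp(pixels)
```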

Prepare Y for plotting. Use the function extractdata to extract the data from the dlarray object, and use the function gather to transfer Y from the GPU to the local workspace.

Y = uint8(gather(extractdata(Y)));

Show the input image (left) next to the stylized image (right).

figure
m = imtile({im,Y});
imshow(m)

Model Loss Function

The function modelLoss takes as input the loss network netLoss, the image transformer network netTransform, a mini-batch of input images X, a cell array containing the Gram matrices of the style image SGram, the weight associated with the content loss contentWeight, and the weight associated with the style loss styleWeight. The function returns the total loss loss, the content loss lossContent, the style loss lossStyle, the gradients of the total loss with respect to the learnable parameters of the image transformer gradients, the state of the image transformer network state, and the transformed images Y.

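function [loss,lossContent,lossStyle,gradients,state,Y] = ...
    modelLoss(netLoss,netTransform,X,SGram,contentWeight,styleWeight)

% Forward pass through the image transformer, returning the updated state.
[Y,state] = forward(netTransform,X);

% Rescale the output to the range [0 255].
Y = 255*(tanh(Y)+1)/2;

% Compute the weighted content, style, and total losses.
[loss,lossContent,lossStyle] = styleTransferLoss(netLoss,Y,X,SGram,contentWeight,styleWeight);

% Gradients of the total loss with respect to the transformer learnables.
gradients = dlgradient(loss,netTransform.Learnables);

end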

Style Transfer Loss

The function styleTransferLoss takes as input the loss network netLoss, a mini-batch of transformed images Y, a mini-batch of input images X, a cell array containing the Gram matrices of the style image SGram, and the weights associated with the content and style, weightContent and weightStyle, respectively. It returns the total loss loss and its individual weighted components: the content loss lossContent and the style loss lossStyle.

The content loss measures how much the spatial structure differs between the input images X and the output images Y.

On the other hand, the style loss tells you how much difference in the stylistic appearance there is between the style image S and the output image Y.

The diagram below shows the algorithm that styleTransferLoss implements to calculate the total loss.

First, the function passes the input images X, the transformed images Y and the style image S to the pretrained network VGG-16. This pretrained network extracts several features from these images. The algorithm then calculates the content loss by using the spatial features of the input image X and of the output image Y. Moreover, it calculates the style loss by using the stylistic features of the output image Y and of the style image S. Finally, it obtains the total loss by adding the content and style losses.

Content Loss

For each image in the mini-batch, the content loss function compares the activations that the layer relu3_3 outputs for the original image and for the transformed image. In particular, it calculates the mean squared error between the activations and returns the average loss for the mini-batch:

$$\text{lossContent} = \frac{1}{N}\sum_{n=1}^{N} \operatorname{mean}\left(\left[\phi(X_n) - \phi(Y_n)\right]^2\right),$$

where X contains the input images, Y contains the transformed images, N is the mini-batch size, and ϕ(·) represents the activations extracted at layer relu3_3.
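The code in styleTransferLoss evaluates this formula with a single mean over all elements, mean((YActivations{3} - XActivations).^2,"all"). The two forms agree because every image in the mini-batch contributes the same number of elements. The following sketch checks this on random hypothetical activations (phiX and phiY stand in for the relu3_3 activations).

```matlab
rng(0)                                   % reproducible random data

% Hypothetical activations: H-by-W-by-C-by-N with N = 4 images.
phiX = rand(8,8,16,4);
phiY = rand(8,8,16,4);
N = size(phiX,4);

% Formula: per-image mean of squared differences, averaged over the batch.
perImage = zeros(1,N);
for n = 1:N
    d = phiX(:,:,:,n) - phiY(:,:,:,n);
    perImage(n) = mean(d.^2,"all");
end
lossFormula = mean(perImage);

% Form used in styleTransferLoss: one mean over every element.
lossVectorized = mean((phiY - phiX).^2,"all");

fprintf("difference: %g\n",abs(lossFormula - lossVectorized))
```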

Style Loss

To calculate the style loss, for each image in the mini-batch:

  1. Extract the activations at the layers relu1_2, relu2_2, relu3_3, and relu4_3.

  2. For each of the four activations ϕj, compute the Gram matrix G(ϕj).

  3. Calculate the squared differences between these Gram matrices and the corresponding Gram matrices of the style image.

  4. Sum the squared differences over all matrix entries and over the four layers j.

To obtain the style loss for the whole mini-batch, compute the average of the style loss for each image n in the mini-batch:

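$$\text{lossStyle} = \frac{1}{N}\sum_{n=1}^{N}\sum_{j=1}^{4}\left\lVert G(\phi_j(Y_n)) - G(\phi_j(S))\right\rVert_F^2,$$

where j is the index of the layer, G(·) is the Gram matrix, and ‖·‖_F² denotes the sum of the squared entries of a matrix (the squared Frobenius norm).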
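The following sketch evaluates the style loss formula directly with loops, using random activations for two hypothetical layers instead of four. The anonymous function gramOf is a hypothetical helper that follows the Gram matrix definition in the Gram Matrix section. As a sanity check, a mini-batch whose activations all equal the style activations gives a style loss of exactly zero.

```matlab
rng(1)                                   % reproducible random data

% Hypothetical Gram helper for one H-by-W-by-C array, normalized by C*H*W.
gramOf = @(phi) (reshape(phi,[],size(phi,3)).' * reshape(phi,[],size(phi,3))) / numel(phi);

% Hypothetical activations at two layers: style image S (H-by-W-by-C)
% and a mini-batch of N transformed images (H-by-W-by-C-by-N).
N = 3;
phiS = {rand(8,8,4), rand(4,4,6)};
phiRandom = {rand(8,8,4,N), rand(4,4,6,N)};
% A mini-batch whose activations all equal the style activations.
phiMatch = {repmat(phiS{1},1,1,1,N), repmat(phiS{2},1,1,1,N)};

batches = {phiRandom, phiMatch};
losses = zeros(1,numel(batches));
for b = 1:numel(batches)
    phiY = batches{b};
    total = 0;
    for j = 1:numel(phiY)                % sum over layers j
        GS = gramOf(phiS{j});
        for n = 1:N                      % sum over images n
            GY = gramOf(phiY{j}(:,:,:,n));
            total = total + sum((GY - GS).^2,"all");
        end
    end
    losses(b) = total/N;                 % average over the mini-batch
end

fprintf("random batch: %g, matching batch: %g\n",losses(1),losses(2))
```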

Total Loss
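In formula form, and writing the weights weightContent and weightStyle (set earlier to 1e-4 and 3e-8) as w_content and w_style, styleTransferLoss computes the total loss as shown below, where lossContent and lossStyle are the unweighted losses defined above. Note that the lossContent and lossStyle values that the function returns already include their weights.

$$\text{loss} = w_{\text{content}}\cdot\text{lossContent} + w_{\text{style}}\cdot\text{lossStyle}$$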

function [loss,lossContent,lossStyle] = styleTransferLoss(netLoss,Y,X, ...
    SGram,weightContent,weightStyle)

% Extract activations.
YActivations = cell(1,4);
[YActivations{1},YActivations{2},YActivations{3},YActivations{4}] = ...
    forward(netLoss,Y,Outputs=["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

XActivations = forward(netLoss,X,Outputs="relu3_3");

% Calculate the mean square error between activations.
lossContent = mean((YActivations{3} - XActivations).^2,"all");

% Add up the losses for all the four activations.
lossStyle = 0;
for j = 1:4
    G = createGramMatrix(YActivations{j});
    lossStyle = lossStyle + sum((G - SGram{j}).^2,"all");
end

% Average the loss over the mini-batch.
miniBatchSize = size(X,4);
lossStyle = lossStyle/miniBatchSize;

% Apply weights.
lossContent = weightContent * lossContent;
lossStyle = weightStyle * lossStyle;

% Calculate the total loss.
loss = lossContent + lossStyle;

end

Residual Block

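The residualBlock function returns an array of six layers: two convolution layers, two instance normalization layers, a ReLU layer, and an addition layer. The second input of the addition layer is the skip connection, which you connect using connectLayers after adding the layers to the network. Note that groupNormalizationLayer("channel-wise") is an instance normalization layer.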

function layers = residualBlock(name)

layers = [    
    convolution2dLayer([3 3], 128,Padding="same",Name="convRes_"+name+"_1")
    groupNormalizationLayer("channel-wise",Name="normRes_"+name+"_1")
    reluLayer(Name="reluRes_"+name+"_1")
    convolution2dLayer([3 3],128,Padding="same",Name="convRes_"+name+"_2")
    groupNormalizationLayer("channel-wise",Name="normRes_"+name+"_2")
    additionLayer(2,Name="add_"+name)];

end
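To illustrate why groupNormalizationLayer("channel-wise") acts as instance normalization, the following sketch normalizes each channel of each image over its spatial dimensions only, so that statistics are never shared across channels or across images. It omits the learnable scale and offset (assumed to start at 1 and 0) and assumes an epsilon of 1e-5; check the groupNormalizationLayer documentation for the exact defaults. The array name acts is hypothetical.

```matlab
rng(2)                                   % reproducible random data

% Hypothetical activations: H-by-W-by-C-by-N, with nonzero mean and scale.
acts = randn(6,6,3,2)*5 + 2;

% Instance normalization: statistics per channel and per image, computed
% over the two spatial dimensions only.
mu = mean(acts,[1 2]);                   % 1-by-1-by-C-by-N
sigma2 = var(acts,1,[1 2]);              % population variance, same size
epsilon = 1e-5;                          % assumed small constant
actsNorm = (acts - mu)./sqrt(sigma2 + epsilon);

% Each channel of each image now has (approximately) zero mean and unit variance.
channelMeans = mean(actsNorm,[1 2]);
channelVars = var(actsNorm,1,[1 2]);
disp(squeeze(channelVars))
```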

Gram Matrix

The function createGramMatrix takes as input the activations of a single layer and returns a stylistic representation for each image in a mini-batch. The input is a feature map of size [H, W, C, N], where H is the height, W is the width, C is the number of channels, and N is the mini-batch size. The function outputs an array G of size [C,C,N]. Each subarray G(:,:,k) is the Gram matrix corresponding to the kth image in the mini-batch. Each entry G(i,j,k) of the Gram matrix represents the correlation between channels ci and cj, because each entry in channel ci multiplies the entry in the corresponding position in channel cj:

$$G(i,j,k) = \frac{1}{C \times H \times W}\sum_{h=1}^{H}\sum_{w=1}^{W}\phi_k(h,w,c_i)\,\phi_k(h,w,c_j),$$

where ϕk are the activations for the kth image in the mini-batch.

The Gram matrix contains information about which features activate together but has no information about where the features occur in the image. This is because the summation over height and width loses the information about the spatial structure. The loss function uses this matrix as a stylistic representation of the image.

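function G = createGramMatrix(activations)

% Size of the feature map: height, width, and number of channels.
[h,w,numChannels] = size(activations,1:3);

% Flatten the spatial dimensions: (h*w)-by-C-by-N.
features = reshape(activations,h*w,numChannels,[]);
% Transpose each page: C-by-(h*w)-by-N.
featuresT = permute(features,[2 1 3]);

% Batched matrix multiplication gives one C-by-C Gram matrix per image,
% normalized by the number of elements in each feature map.
G = dlmtimes(featuresT,features) / (h*w*numChannels);

end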
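To check that the reshape-and-multiply approach in createGramMatrix matches the double-sum definition of G(i,j,k), the following sketch computes the Gram matrix of one random hypothetical image in both ways. It uses plain numeric arrays and a standard matrix product instead of dlarray and dlmtimes, so it runs in base MATLAB.

```matlab
rng(3)                                   % reproducible random data

% Hypothetical activations for one image: H-by-W-by-C.
phi = rand(5,4,3);
[H,W,C] = size(phi);

% Direct evaluation of the double-sum definition.
Gloop = zeros(C,C);
for ci = 1:C
    for cj = 1:C
        acc = 0;
        for h = 1:H
            for w = 1:W
                acc = acc + phi(h,w,ci)*phi(h,w,cj);
            end
        end
        Gloop(ci,cj) = acc/(C*H*W);
    end
end

% Vectorized form: reshape to (H*W)-by-C, then multiply the transpose by it.
F = reshape(phi,H*W,C);
Gvec = (F.'*F)/(H*W*C);

fprintf("max difference: %g\n",max(abs(Gloop - Gvec),[],"all"))
```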

References

  1. Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. "Perceptual Losses for Real-Time Style Transfer and Super-Resolution." In European Conference on Computer Vision (ECCV). Cham: Springer, 2016.
