Image Processing Operator Approximation Using Deep Learning
This example shows how to approximate an image filtering operation using a multiscale context aggregation network (CAN).
Operator approximation finds alternative ways to process images such that the result resembles the output from a conventional image processing operation or pipeline. The goal of operator approximation is often to reduce the time required to process an image.
Several classical and deep learning techniques have been proposed to perform operator approximation. Some classical techniques improve the efficiency of a single algorithm but cannot be generalized to other operations. Another common technique approximates a wide range of operations by applying the operator to a low resolution copy of an image, but the loss of high-frequency content limits the accuracy of the approximation.
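For instance, a minimal sketch of that downsample-process-upsample approach, assuming an input image I and using bilateral filtering as the operator, looks like this:

% Sketch of the classical low-resolution approximation technique.
Ilow = imresize(I,0.25);                              % low-resolution copy
IlowFiltered = imbilatfilt(Ilow);                     % apply the operator cheaply
IapproxLowRes = imresize(IlowFiltered,size(I,[1 2])); % upsample; fine detail is lost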
Deep learning solutions enable the approximation of more general and complex operations. For example, the multiscale context aggregation network (CAN) presented by Chen et al. [1] can approximate multiscale tone mapping, photographic style transfer, nonlocal dehazing, and pencil drawing. The multiscale CAN trains on full-resolution images for greater accuracy in processing high-frequency details. After training, the network can bypass the conventional processing operation and process images directly.
This example explores how to train a multiscale CAN to approximate a bilateral image filtering operation, which reduces image noise while preserving edge sharpness. The example presents the complete training and inference workflow, which includes the process of creating a training datastore, selecting training options, training the network, and using the network to process test images.
The Operator Approximation Network
The multiscale CAN is trained to minimize the loss between the conventional output of an image processing operation and the network response after processing the input image using multiscale context aggregation. Multiscale context aggregation looks for information about each pixel from across the entire image, rather than limiting the search to a small neighborhood surrounding the pixel.
To help the network learn global image properties, the multiscale CAN architecture has a large receptive field. The first and last layers have the same size because the operator should not change the size of the image. Successive intermediate layers are dilated by exponentially increasing scale factors (hence the "multiscale" nature of the CAN). Dilation enables the network to look for spatially separated features at various spatial frequencies, without reducing the resolution of the image. After each convolution layer, the network uses adaptive normalization to balance the impact of batch normalization and the identity mapping on the approximated operator.
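Because each 3-by-3 convolution with dilation factor d widens the receptive field by 2*d pixels, the receptive field grows geometrically with depth. This quick calculation (a sketch based on the dilation schedule this example defines later: a factor of 1 for the first convolution, 2, 4, ..., 128 for the middle convolutions, and 1 for the second-to-last convolution) confirms that the stack covers more than a 256-by-256 patch:

% Sketch: receptive field growth of the stacked 3-by-3 dilated convolutions.
% Each convolution with dilation factor d widens the receptive field by 2*d.
dilations = [1 2.^(1:7) 1];
receptiveField = 1 + cumsum(2*dilations);
disp(receptiveField(end))  % 513, larger than the 256-by-256 training patch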
Download Training Data
Download the IAPR TC-12 Benchmark, which consists of 20,000 still natural images [2]. The data set includes photos of people, animals, cities, and more. The size of the data file is ~1.8 GB. If you do not want to download the training data set needed to train the network, then you can load the pretrained CAN by typing this command at the command line:

load("trainedBilateralFilterNet_v2.mat");

Then, go directly to the Perform Bilateral Filtering Approximation Using Multiscale CAN section in this example.
Use the helper function downloadIAPRTC12Data to download the data. This function is attached to the example as a supporting file. Specify dataDir as the desired location of the data.
dataDir = tempdir;
downloadIAPRTC12Data(dataDir);
Downloading IAPR TC-12 dataset... This will take several minutes to download and unzip... Done.
This example trains the network with a small subset of the IAPR TC-12 Benchmark data.
trainImagesDir = fullfile(dataDir,"iaprtc12","images","39");
exts = [".jpg",".bmp",".png"];
pristineImages = imageDatastore(trainImagesDir,FileExtensions=exts);
List the number of training images.
numel(pristineImages.Files)
ans = 916
Prepare Training Data
To create a training data set, read in pristine images and write out images that have been bilateral filtered. The filtered images are stored on disk in the directory specified by preprocessDataDir.
preprocessDataDir = fullfile(trainImagesDir,"preprocessedDataset");
Use the helper function bilateralFilterDataset to preprocess the training data. This function is attached to the example as a supporting file. The helper function performs these operations for each pristine image in inputImages:

- Calculate the degree of smoothing for bilateral filtering. Smoothing the filtered image reduces image noise.
- Perform bilateral filtering using imbilatfilt.
- Save the filtered image to disk using imwrite.

A sketch of these operations appears after the call below.
bilateralFilterDataset(pristineImages,preprocessDataDir);
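For reference, the preprocessing loop inside the helper might look something like this sketch. The function name and details such as how the output file names are formed are illustrative assumptions; the shipped supporting file may differ.

% Hypothetical sketch of the bilateral filtering preprocessing loop.
function bilateralFilterDatasetSketch(imds,outputDir)
if ~exist(outputDir,"dir")
    mkdir(outputDir)
end
reset(imds)
while hasdata(imds)
    [I,info] = read(imds);
    % Degree of smoothing based on the variance of the pixel values
    degreeOfSmoothing = var(double(I(:)));
    Ifiltered = imbilatfilt(I,degreeOfSmoothing);
    [~,name,ext] = fileparts(info.Filename);
    imwrite(Ifiltered,fullfile(outputDir,strcat(name,ext)))
end
end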
Define Random Patch Extraction Datastore for Training
Use a random patch extraction datastore to feed the training data to the network. This datastore extracts random corresponding patches from two image datastores that contain the network inputs and desired network responses.
In this example, the network inputs are the pristine images in pristineImages. The desired network responses are the processed images after bilateral filtering. Create an image datastore called bilatFilteredImages from the collection of bilateral filtered image files.
bilatFilteredImages = imageDatastore(preprocessDataDir,FileExtensions=exts);
Create a randomPatchExtractionDatastore from the two image datastores. Specify a patch size of 256-by-256 pixels. Specify the PatchesPerImage name-value argument to extract one randomly positioned patch from each pair of images during training. Specify a mini-batch size of one.
miniBatchSize = 1;
patchSize = [256 256];
dsTrain = randomPatchExtractionDatastore(pristineImages,bilatFilteredImages, ...
patchSize,PatchesPerImage=1);
dsTrain.MiniBatchSize = miniBatchSize;
The randomPatchExtractionDatastore provides mini-batches of data to the network at each iteration of the epoch. Perform a read operation on the datastore to explore the data.
inputBatch = read(dsTrain);
disp(inputBatch)
       InputImage          ResponseImage   
    _________________    _________________
    {256×256×3 uint8}    {256×256×3 uint8}
Set Up Multiscale CAN Layers
This example defines the multiscale CAN using layers from Deep Learning Toolbox™, including:

- imageInputLayer (Deep Learning Toolbox) — Image input layer
- convolution2dLayer (Deep Learning Toolbox) — 2-D convolution layer for convolutional neural networks
- batchNormalizationLayer (Deep Learning Toolbox) — Batch normalization layer
- leakyReluLayer (Deep Learning Toolbox) — Leaky rectified linear unit layer
Two custom scale layers are added to implement an adaptive batch normalization layer. These layers are attached as supporting files to this example.

- adaptiveNormalizationMu — Scale layer that adjusts the strengths of the batch-normalization branch
- adaptiveNormalizationLambda — Scale layer that adjusts the strengths of the identity branch
The first layer, imageInputLayer, operates on image patches. The patch size is based on the network receptive field, which is the spatial image region that affects the response of the top-most layer in the network. Ideally, the network receptive field is the same as the image size so that the network can see all the high-level features in the image. For this bilateral filter approximation, the image patch size is fixed to 256-by-256.
networkDepth = 10;
numberOfFilters = 32;
firstLayer = imageInputLayer([256 256 3],Name="InputLayer",Normalization="none");
The image input layer is followed by a 2-D convolution layer that contains 32 filters of size 3-by-3. Zero-pad the inputs to each convolution layer so that feature maps remain the same size as the input after each convolution. Initialize the weights to the identity matrix.
Wgts = zeros(3,3,3,numberOfFilters);
for ii = 1:3
    Wgts(2,2,ii,ii) = 1;
end
convolutionLayer = convolution2dLayer(3,numberOfFilters,Padding=1, ...
    Weights=Wgts,Name="Conv1");
Each convolution layer is followed by a batch normalization layer and an adaptive normalization scale layer that adjusts the strengths of the batch-normalization branch. Later, this example creates the corresponding adaptive normalization scale layer that adjusts the strength of the identity branch. For now, follow the adaptiveNormalizationMu layer with an addition layer. Finally, specify a leaky ReLU layer with a scalar multiplier of 0.2 for negative inputs.
batchNorm = batchNormalizationLayer(Name="BN1");
adaptiveMu = adaptiveNormalizationMu(numberOfFilters,"Mu1");
addLayer = additionLayer(2,Name="add1");
leakyrelLayer = leakyReluLayer(0.2,Name="Leaky1");
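The custom scale layers are provided as supporting files, but conceptually each one multiplies its input by a learnable scale so that the network can balance the two normalization branches during training. Here is a minimal sketch of such a layer; the class name is hypothetical, and the shipped adaptiveNormalizationMu implementation may differ in how the scale is shaped or initialized.

% Hypothetical sketch of a learnable scale layer. The supporting file
% attached to this example may differ in detail.
classdef adaptiveNormalizationMuSketch < nnet.layer.Layer
    properties (Learnable)
        Mu % Learnable per-channel scale applied to the batch-normalization branch
    end
    methods
        function layer = adaptiveNormalizationMuSketch(numChannels,name)
            layer.Name = name;
            layer.Mu = ones(1,1,numChannels); % initialize to identity scaling
        end
        function Z = predict(layer,X)
            Z = layer.Mu .* X; % scale each channel of the input
        end
    end
end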
Specify the middle layers of the network following the same pattern. Successive convolution layers have a dilation factor that scales exponentially with the network depth.
midLayers = [convolutionLayer batchNorm adaptiveMu addLayer leakyrelLayer];

Wgts = zeros(3,3,numberOfFilters,numberOfFilters);
for ii = 1:numberOfFilters
    Wgts(2,2,ii,ii) = 1;
end

for layerNumber = 2:networkDepth-2
    dilationFactor = 2^(layerNumber-1);
    padding = dilationFactor;
    conv2dLayer = convolution2dLayer(3,numberOfFilters, ...
        Padding=padding,DilationFactor=dilationFactor, ...
        Weights=Wgts,Name="Conv"+num2str(layerNumber));
    batchNorm = batchNormalizationLayer(Name="BN"+num2str(layerNumber));
    adaptiveMu = adaptiveNormalizationMu(numberOfFilters,"Mu"+num2str(layerNumber));
    addLayer = additionLayer(2,Name="add"+num2str(layerNumber));
    leakyrelLayer = leakyReluLayer(0.2,Name="Leaky"+num2str(layerNumber));
    midLayers = [midLayers conv2dLayer batchNorm adaptiveMu addLayer leakyrelLayer];
end
Do not apply a dilation factor to the second-to-last convolution layer.
conv2dLayer = convolution2dLayer(3,numberOfFilters, ...
    Padding=1,Weights=Wgts,Name="Conv9");
batchNorm = batchNormalizationLayer(Name="AN9");
adaptiveMu = adaptiveNormalizationMu(numberOfFilters,"Mu9");
addLayer = additionLayer(2,Name="add9");
leakyrelLayer = leakyReluLayer(0.2,Name="Leaky9");
midLayers = [midLayers conv2dLayer batchNorm adaptiveMu addLayer leakyrelLayer];
The last convolution layer applies 1-by-1 convolutions, with weights of size 1-by-1-by-32-by-3, to reconstruct the 3-channel image.
Wgts = sqrt(2/(9*numberOfFilters))*randn(1,1,numberOfFilters,3);
finalLayer = convolution2dLayer(1,3,NumChannels=numberOfFilters, ...
    Weights=Wgts,Name="Conv10");
Concatenate all the layers.
layers = [firstLayer midLayers finalLayer];
Create an empty dlnetwork object and add the layers to the object.

opNet = dlnetwork;
opNet = addLayers(opNet,layers);
Create skip connections, which act as the identity branch for the adaptive normalization equation. Connect the skip connections to the addition layers.
% Add one adaptiveNormalizationLambda scale layer per convolution layer, and
% connect it between the convolution output and the corresponding addition layer.
for layerNumber = 1:networkDepth-1
    skipConv = adaptiveNormalizationLambda(numberOfFilters, ...
        "Lambda"+num2str(layerNumber));
    opNet = addLayers(opNet,skipConv);
    opNet = connectLayers(opNet,"Conv"+num2str(layerNumber), ...
        "Lambda"+num2str(layerNumber));
    opNet = connectLayers(opNet,"Lambda"+num2str(layerNumber), ...
        "add"+num2str(layerNumber)+"/in2");
end
Initialize the dlnetwork object using a sample input image.
opNet = initialize(opNet,dlarray(single(inputBatch.InputImage{1}),"SSCB"));
Visualize the network using the Deep Network Designer (Deep Learning Toolbox) app.
deepNetworkDesigner(opNet)
Specify Training Options
Train the network using the Adam optimizer. Specify the hyperparameter settings by using the trainingOptions (Deep Learning Toolbox) function. Use the default values of 0.9 for "GradientDecayFactor" and 0.0001 for "L2Regularization" (weight decay). Specify a constant learning rate of 0.0001. Train for 181 epochs.
maxEpochs = 181;
initLearningRate = 0.0001;
miniBatchSize = 1;

options = trainingOptions("adam", ...
    InitialLearnRate=initLearningRate, ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    Plots="training-progress", ...
    Verbose=false);
Train the Network
By default, the example loads a pretrained multiscale CAN that approximates a bilateral filter. The pretrained network enables you to perform an approximation of bilateral filtering without waiting for training to complete.
To train the network, set the doTraining variable in the following code to true. Train the neural network using the trainnet (Deep Learning Toolbox) function. For regression, use mean squared error loss. By default, the trainnet function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option. Training takes about 15 hours on an NVIDIA™ Titan X GPU.
doTraining = false;
if doTraining
    net = trainnet(dsTrain,opNet,"mean-squared-error",options);
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save("trainedBilateralFilterNet_v2-"+modelDateTime+".mat","net");
else
    load("trainedBilateralFilterNet_v2.mat");
end
Perform Bilateral Filtering Approximation Using Multiscale CAN
To process an image using a trained multiscale CAN that approximates a bilateral filter, follow the remaining steps of this example, which show how to:
- Create a sample noisy input image from a reference image.
- Perform conventional bilateral filtering of the noisy image using the imbilatfilt function.
- Perform an approximation to bilateral filtering on the noisy image using the CAN.
- Visually compare the denoised images from operator approximation and conventional bilateral filtering.
- Evaluate the quality of the denoised images by quantifying the similarity of the images to the pristine reference image.
Create Sample Noisy Image
Create a sample noisy image that will be used to compare the results of operator approximation to conventional bilateral filtering.
The test data set, testImages, contains 20 undistorted images shipped in Image Processing Toolbox™. Load the images into an imageDatastore.
fileNames = ["sherlock.jpg","peacock.jpg","fabric.png","greens.jpg", ...
    "hands1.jpg","kobi.png","lighthouse.png","office_4.jpg", ...
    "onion.png","pears.png","yellowlily.jpg","indiancorn.jpg", ...
    "flamingos.jpg","sevilla.jpg","llama.jpg","parkavenue.jpg", ...
    "strawberries.jpg","trailer.jpg","wagon.jpg","football.jpg"];
filePath = fullfile(matlabroot,"toolbox","images","imdata")+filesep;
filePathNames = strcat(filePath,fileNames);
testImages = imageDatastore(filePathNames);
Display the test images as a montage.
montage(testImages)
Select one of the images to use as the reference image for bilateral filtering. Convert the image to data type uint8.
testImage = "fabric.png";
Ireference = imread(testImage);
Ireference = im2uint8(Ireference);
You can optionally use your own image as the reference image. Note that the size of the test image must be at least 256-by-256. If the test image is smaller than 256-by-256, then increase the image size by using the imresize function. The network also requires an RGB test image. If the test image is grayscale, then convert the image to RGB by using the cat function to concatenate three copies of the original image along the third dimension.
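For example, assuming your image is already loaded into the variable Ireference, the preprocessing might look like this sketch (the resize target and channel replication follow the requirements stated above):

% Optional preprocessing for your own reference image (sketch).
if size(Ireference,1) < 256 || size(Ireference,2) < 256
    Ireference = imresize(Ireference,[256 256]); % enlarge small images
end
if size(Ireference,3) == 1
    Ireference = cat(3,Ireference,Ireference,Ireference); % grayscale to RGB
end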
Display the reference image.
imshow(Ireference)
title("Pristine Reference Image")
Use the imnoise function to add zero-mean Gaussian white noise with a variance of 0.00001 to the reference image.
Inoisy = imnoise(Ireference,"gaussian",0.00001);
imshow(Inoisy)
title("Noisy Image")
Filter Image Using Bilateral Filtering
Conventional bilateral filtering is a standard way to reduce image noise while preserving edge sharpness. Use the imbilatfilt function to apply a bilateral filter to the noisy image. Specify a degree of smoothing equal to the variance of pixel values.
degreeOfSmoothing = var(double(Inoisy(:)));
Ibilat = imbilatfilt(Inoisy,degreeOfSmoothing);
imshow(Ibilat)
title("Denoised Image Obtained Using Bilateral Filtering")
Process Image Using Trained Network
Pass the noisy input image through the trained network and observe the output. The output of the network is the desired denoised image, returned as a dlarray object. Extract the image data from the dlarray object by using the extractdata function.
InoisyDL = dlarray(single(Inoisy),"SSCB");
IapproxDL = predict(net,InoisyDL);
Iapprox = extractdata(IapproxDL);
Image Processing Toolbox™ requires floating-point images to have pixel values in the range [0, 1]. Use the rescale function to scale the pixel values to this range, and then convert the image to uint8.
Iapprox = rescale(Iapprox);
Iapprox = im2uint8(Iapprox);
imshow(Iapprox)
title("Denoised Image Obtained Using Multiscale CAN")
Visual Comparison
To get a better visual understanding of the denoised images, examine a small region inside each image. Specify a region of interest (ROI) using the vector roi in the format [x y width height]. The elements define the x- and y-coordinates of the top-left corner, and the width and height of the ROI.
roi = [300 30 50 50];
Crop the images to this ROI, and display the result as a montage. The CAN removes more noise than conventional bilateral filtering. Both techniques preserve edge sharpness.
montage({imcrop(Ireference,roi),imcrop(Inoisy,roi), ...
    imcrop(Ibilat,roi),imcrop(Iapprox,roi)}, ...
    Size=[1 4]);
title("Reference Image | Noisy Image | Bilateral-Filtered Image | CAN Prediction");
Quantitative Comparison
Use image quality metrics to quantitatively compare the noisy input image, the bilateral-filtered image, and the operator-approximated image. The reference for each comparison is the original pristine image, Ireference, before adding noise.
Measure the peak signal-to-noise ratio (PSNR) of each image against the reference image. Larger PSNR values generally indicate better image quality. See psnr for more information about this metric.
noisyPSNR = psnr(Inoisy,Ireference);
bilatPSNR = psnr(Ibilat,Ireference);
approxPSNR = psnr(Iapprox,Ireference);
PSNR_Score = [noisyPSNR bilatPSNR approxPSNR]';
Measure the structural similarity index (SSIM) of each image. SSIM assesses the visual impact of three characteristics of an image against a reference image: luminance, contrast, and structure. The closer the SSIM value is to 1, the better the test image agrees with the reference image. See ssim for more information about this metric.
noisySSIM = ssim(Inoisy,Ireference);
bilatSSIM = ssim(Ibilat,Ireference);
approxSSIM = ssim(Iapprox,Ireference);
SSIM_Score = [noisySSIM bilatSSIM approxSSIM]';
Measure perceptual image quality using the Naturalness Image Quality Evaluator (NIQE). Smaller NIQE scores indicate better perceptual quality. See niqe for more information about this metric.
noisyNIQE = niqe(Inoisy);
bilatNIQE = niqe(Ibilat);
approxNIQE = niqe(Iapprox);
NIQE_Score = [noisyNIQE bilatNIQE approxNIQE]';
Display the metrics in a table.
table(PSNR_Score,SSIM_Score,NIQE_Score, ...
    RowNames=["Noisy Image","Bilateral Filtering","Operator Approximation"])
ans=3×3 table
PSNR_Score SSIM_Score NIQE_Score
__________ __________ __________
Noisy Image 20.276 0.76228 11.574
Bilateral Filtering 25.753 0.91534 6.924
Operator Approximation 26.173 0.92976 5.8165
References
[1] Chen, Q., J. Xu, and V. Koltun. "Fast Image Processing with Fully-Convolutional Networks." In Proceedings of the 2017 IEEE International Conference on Computer Vision (ICCV). Venice, Italy, Oct. 2017, pp. 2516-2525.
[2] Grubinger, M., P. Clough, H. Müller, and T. Deselaers. "The IAPR TC-12 Benchmark: A New Evaluation Resource for Visual Information Systems." Proceedings of the OntoImage 2006 Language Resources For Content-Based Image Retrieval. Genoa, Italy. Vol. 5, May 2006, p. 10.
See Also
randomPatchExtractionDatastore | dlnetwork (Deep Learning Toolbox) | trainingOptions (Deep Learning Toolbox) | trainnet (Deep Learning Toolbox) | predict (Deep Learning Toolbox) | imbilatfilt
Related Topics
- Preprocess Images for Deep Learning (Deep Learning Toolbox)
- Datastores for Deep Learning (Deep Learning Toolbox)
- List of Deep Learning Layers (Deep Learning Toolbox)