Brain MRI Segmentation Using Pretrained 3-D U-Net Network

This example shows how to segment brain MRI using a deep neural network.

Segmentation of brain scans enables the visualization of individual brain structures. Brain segmentation is also commonly used for quantitative volumetric and shape analyses to characterize healthy and diseased populations. Manual segmentation by clinical experts is considered the highest standard in segmentation. However, the process is extremely time-consuming and not practical for labeling large data sets. Additionally, labeling requires expertise in neuroanatomy and is prone to errors and limitations in interrater and intrarater reproducibility. Trained segmentation algorithms, such as convolutional neural networks, have the potential to automate the labeling of large clinical data sets.

In this example, you use the pretrained SynthSeg neural network [1], a 3-D U-Net for brain MRI segmentation. SynthSeg can be used to segment brain scans of any contrast and resolution without retraining or fine-tuning. SynthSeg is also robust to a wide array of subject populations, from young and healthy to aging and diseased subjects, including subjects with white matter lesions. It works on scans with or without preprocessing, such as bias field correction, skull stripping, intensity normalization, and template registration.

Download Brain MRI and Label Data

This example uses a subset of the CANDI data set [2] [3]. The subset consists of a brain MRI volume and the corresponding ground truth label volume for one healthy patient. Both files are in the NIfTI file format. The total size of the data files is ~2.5 MB.

Create a folder in which to store the data set. This example uses a folder named brainSegData, created within the tempdir directory, and stores its path in dataDir.

dataDir = fullfile(tempdir,"brainSegData");
if ~exist(dataDir,"dir")   
    mkdir(dataDir);
end

To download the data, go to the MR Session repository for the patient on the NeuroImaging Tools & Resources Collaboratory (NITRC) website. Under Actions, select Download > Download Images. In the window that opens, click Download. Unzip the downloaded folder, and navigate to the first subdirectory, HC_001_MR, within the unzipped folder. Save a copy of the HC_001_MR subfolder in the folder specified by dataDir.
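Because the download is a manual step, it can help to confirm that the files are where the rest of the example expects them. The following is a minimal sketch, not part of the original workflow. The two paths mirror the ones used later in this example, and the variable names hasImage and hasLabels are introduced here for illustration only.

```matlab
% Sketch: check that the manually downloaded data is in place.
dataDir = fullfile(tempdir,"brainSegData");

% Expected locations, copied from the paths used later in this example.
imFile = fullfile(dataDir,"HC_001_MR","SCANS","anat","NIfTI","anat.nii.gz");
labelDir = fullfile(dataDir,"HC_001_MR","ASSESSORS","HC_001_MR_seg","NIfTI");

hasImage = isfile(imFile);       % true if the MRI volume exists
hasLabels = isfolder(labelDir);  % true if the label folder exists

if ~(hasImage && hasLabels)
    warning("Data not found under %s. Check the copied HC_001_MR folder.",dataDir);
end
```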

Load Pretrained Network

This example uses a pretrained TensorFlow-Keras convolutional neural network. Download the pretrained network from the MathWorks® website by using the helper function downloadTrainedNetwork. The helper function is attached to this example as a supporting file. The size of the pretrained network is approximately 51 MB.

trainedBrainCANDINetwork_url = "https://www.mathworks.com/supportfiles/image/data/trainedBrainSynthSegNetwork.h5";
downloadTrainedNetwork(trainedBrainCANDINetwork_url,dataDir);

Load Test Data

Read the metadata from the brain MRI volume by using the niftiinfo function. Read the brain MRI volume by using the niftiread function.

imFile = fullfile(dataDir,"HC_001_MR","SCANS","anat","NIfTI","anat.nii.gz");
metaData = niftiinfo(imFile);
X = niftiread(metaData);
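The metadata structure is also a convenient place to check the voxel spacing before preprocessing. The sketch below uses a stand-in structure with hypothetical values so that it runs without the data set. The field names PixelDimensions and ImageSize are the ones niftiinfo returns; the tolerance and the assumption that the spacing is in millimeters are illustrative choices.

```matlab
% Stand-in for the structure returned by niftiinfo (values are hypothetical).
metaDemo = struct("PixelDimensions",[0.9375 0.9375 1.5], ...
                  "ImageSize",[256 256 128]);

voxelSize = metaDemo.PixelDimensions(1:3);  % spacing per dimension, assumed mm
tol = 1e-3;                                 % tolerance, chosen arbitrarily

% True only if every dimension has a spacing of 1 mm within the tolerance.
isIsotropic1mm = all(abs(voxelSize - 1) < tol);

% Physical extent of the volume along each dimension.
fieldOfView = voxelSize .* metaDemo.ImageSize(1:3);
```

With real data, replace metaDemo with the metaData structure read above.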

In this example, you segment the brain into 32 classes corresponding to anatomical structures. Read the names and numeric identifiers for each class label by using the getBrainCANDISegmentationLabels helper function. The helper function is attached to this example as a supporting file.

labelDirs = fullfile(dataDir,"HC_001_MR","ASSESSORS","HC_001_MR_seg","NIfTI");
[classNames,labelIDs] = getBrainCANDISegmentationLabels;

Preprocess Test Data

Preprocess the MRI volume by using the preProcessBrainCANDIData helper function. The helper function is attached to this example as a supporting file. The helper function performs these steps:

  • Resampling — If resample is true, resample the data to the isotropic voxel size 1-by-1-by-1 mm. By default, resample is false and the function does not perform resampling. To test the pretrained network on images with a different voxel size, set resample to true if the input is not isotropic.

  • Alignment — Rotate the volume to a standardized RAS orientation.

  • Cropping — Crop the volume to a maximum size of 192 voxels in each dimension.

  • Normalization — Normalize the intensity values of the volume to values in the range [0, 1], which improves the contrast.

resample = false;
cropSize = 192;
[X1,cropIdx,imSize] = preProcessBrainCANDIData(X,metaData,cropSize,resample);
inputSize = size(X1);
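The cropping and normalization steps can be sketched on a synthetic volume as follows. This is an illustration under stated assumptions, not the code of the attached helper function: it assumes a centered crop window and min-max normalization, and it uses a small crop size (48 instead of 192) to keep the demo light. All variable names are hypothetical.

```matlab
% Sketch only: centered crop and min-max normalization on synthetic data.
rng(0);
vol = 100*rand(50,40,64);     % synthetic volume, not real MRI data
cropSizeDemo = 48;            % the example uses 192

sz = size(vol);
outSz = min(sz,cropSizeDemo);          % crop only dimensions that exceed the limit
startIdx = floor((sz - outSz)/2) + 1;  % first index of the centered window
stopIdx = startIdx + outSz - 1;        % last index of the centered window

volCrop = vol(startIdx(1):stopIdx(1), ...
              startIdx(2):stopIdx(2), ...
              startIdx(3):stopIdx(3));

% Min-max normalization to the range [0, 1].
lo = min(volCrop(:));
hi = max(volCrop(:));
volNorm = (volCrop - lo) ./ (hi - lo);
```

Keeping the start and stop indices is what later allows a label map to be placed back at the right location in the full-size volume.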

Convert the preprocessed MRI volume into a formatted deep learning array with the SSSCB (spatial, spatial, spatial, channel, batch) format by using dlarray (Deep Learning Toolbox).

X2 = dlarray(X1,"SSSCB");

Define Network Architecture

Import the network layers from the downloaded model file of the pretrained network using the importKerasLayers (Deep Learning Toolbox) function. The importKerasLayers function requires the Deep Learning Toolbox™ Converter for TensorFlow Models support package. If this support package is not installed, then importKerasLayers provides a download link. Specify ImportWeights as true to import the layers using the weights from the same HDF5 file. The function returns a layerGraph (Deep Learning Toolbox) object.

The Keras network contains some layers that the Deep Learning Toolbox™ does not support. The importKerasLayers function displays a warning and replaces the unsupported layers with placeholder layers.

modelFile = fullfile(dataDir,"trainedBrainSynthSegNetwork.h5");
lgraph = importKerasLayers(modelFile,ImportWeights=true,ImageInputSize=inputSize);
Warning: Imported layers have no output layer because the model does not specify a loss function. Add an output layer or use the 'OutputLayerType' name-value argument when you call importKerasLayers.
Warning: Unable to import some Keras layers, because they are not supported by the Deep Learning Toolbox. They have been replaced by placeholder layers. To find these layers, call the function findPlaceholderLayers on the returned object.

To replace the placeholder layers in the imported network, first identify the names of the layers to replace. Find the placeholder layers using findPlaceholderLayers (Deep Learning Toolbox).

placeholderLayers = findPlaceholderLayers(lgraph)
placeholderLayers = 
  PlaceholderLayer with properties:

                  Name: 'unet_prediction'
    KerasConfiguration: [1×1 struct]
               Weights: []

   Learnable Parameters
    No properties.

   State Parameters
    No properties.

  Show all properties

Define a layer with the same configuration as the imported Keras layer. The only placeholder, unet_prediction, is replaced with a softmax layer.

sf = softmaxLayer;

Replace the placeholder layers with existing layers using replaceLayer (Deep Learning Toolbox).

lgraph = replaceLayer(lgraph,"unet_prediction",sf);

Convert the network to a dlnetwork (Deep Learning Toolbox) object.

net = dlnetwork(lgraph);

Display the updated layer graph information.

layerGraph(net)
ans = 
  LayerGraph with properties:

         Layers: [60×1 nnet.cnn.layer.Layer]
    Connections: [63×2 table]
     InputNames: {'unet_input'}
    OutputNames: {1×0 cell}

Predict Using Test Data

Predict Network Output

Predict the segmentation output for the preprocessed MRI volume. The segmentation output predictIm contains 32 channels corresponding to the segmentation label classes, such as "background", "leftCerebralCortex", and "rightThalamus". For each voxel, predictIm contains one confidence score per class. Each score reflects the likelihood that the voxel belongs to the corresponding class. This prediction is different from the final semantic segmentation output, which assigns each voxel to exactly one class.

predictIm = predict(net,X2);
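The difference between per-class confidence scores and a hard segmentation can be illustrated on a small synthetic array. This sketch assumes a spatial-by-spatial-by-spatial-by-channel layout and uses random numbers in place of network output; numClasses, scores, and classIdx are names made up for the illustration.

```matlab
% Synthetic scores: three spatial dimensions followed by a channel dimension.
rng(0);
numClasses = 4;                    % the example uses 32
raw = rand(8,8,8,numClasses);
scores = raw ./ sum(raw,4);        % normalize so the scores sum to 1 per voxel

% Hard assignment: index of the highest-scoring channel for each voxel.
[maxScore,classIdx] = max(scores,[],4);
```

Here scores plays the role of predictIm, while classIdx corresponds to a label map in which every voxel belongs to exactly one class.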

Test Time Augmentation

This example uses test time augmentation to improve segmentation accuracy. In general, augmentation applies transformations, often random ones, to an image to increase the variability of a data set. You can use augmentation before network training to increase the size of the training data set. Test time augmentation applies transformations to a test image to create multiple versions of it. You then pass each version to the network for prediction and calculate the overall segmentation result as the average of the predictions for all versions. Test time augmentation improves segmentation accuracy by averaging out random errors in the individual network predictions.

By default, this example flips the MRI volume in the left-right direction, resulting in a flipped volume flippedData. The network output for the flipped volume is flipPredictIm. Set flipVal to false to skip the test time augmentation and speed up prediction.

flipVal = true;
if flipVal
    % Flip along dimension 1, the left-right axis after RAS alignment.
    flippedData = flip(X1,1);
    flippedData = dlarray(flippedData,"SSSCB");
    flipPredictIm = predict(net,flippedData);
else
    flipPredictIm = [];  
end
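The averaging idea behind test time augmentation can be sketched with a toy predictor in place of the network. Everything here is hypothetical: toyPredict is a stand-in function with a position-dependent bias along dimension 1, not a model of SynthSeg. The sketch predicts on the original and flipped volumes, flips the second prediction back so that the voxels line up, and averages the two. It ignores any relabeling of left and right class channels that a real left-right flip would presumably require.

```matlab
% Toy illustration of flip-based test time augmentation.
rng(0);
vol = rand(6,5,4);                                   % synthetic volume
ramp = reshape(linspace(0.5,1.5,size(vol,1)),[],1);  % bias along dimension 1
toyPredict = @(v) v .* ramp;                         % hypothetical predictor

p1 = toyPredict(vol);            % prediction on the original volume
p2 = toyPredict(flip(vol,1));    % prediction on the flipped volume
p2 = flip(p2,1);                 % undo the flip so the voxels line up again
pAvg = (p1 + p2)/2;              % average the aligned predictions

maxErrSingle = max(abs(p1(:) - vol(:)));    % error of a single prediction
maxErrAvg = max(abs(pAvg(:) - vol(:)));     % error after averaging
```

In this contrived case the bias is antisymmetric about the center, so it cancels in the average. Real network errors cancel only partially.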

Postprocess Segmentation Prediction

To get the final segmentation maps, postprocess the network output by using the postProcessBrainCANDIData helper function. The helper function is attached to this example as a supporting file. The postProcessBrainCANDIData function performs these steps:

  • Smoothing — Apply a 3-D Gaussian smoothing filter to reduce noise in the predicted segmentation masks.

  • Morphological Filtering — Keep only the largest connected component of predicted segmentation masks to remove additional noise.

  • Segmentation — Assign each voxel to the label class with the greatest confidence score for that voxel.

  • Resizing — Resize the segmentation map to the original input volume size. Resizing the label image allows you to visualize the labels as an overlay on the grayscale MRI volume.

  • Alignment — Rotate the segmentation map back to the orientation of the original input MRI volume.

The final segmentation result, predictedSegMaps, is a 3-D categorical array the same size as the original input volume. Each element corresponds to one voxel and has one categorical label.

predictedSegMaps = postProcessBrainCANDIData(predictIm,flipPredictIm,imSize,...
    cropIdx,metaData,classNames,labelIDs);
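A few of these steps can be sketched with base MATLAB functions on synthetic data. This is an assumption-laden illustration rather than the attached helper function: it smooths with a small binomial kernel as a rough stand-in for a Gaussian filter, skips the morphological filtering and alignment steps, and uses made-up label IDs, sizes, and crop offsets.

```matlab
% Sketch only: smoothing, per-voxel class assignment, and pasting back.
rng(0);
numClasses = 3;                       % the example uses 32
labelIDsDemo = [0 2 41];              % hypothetical numeric label IDs
scores = rand(10,10,10,numClasses);   % synthetic confidence scores
fullSize = [14 12 10];                % hypothetical original volume size
cropStart = [3 2 1];                  % hypothetical first index of the crop

% Smoothing: 3-by-3-by-3 binomial kernel applied to each channel.
k = [1 2 1]/4;
k3 = reshape(k,3,1,1) .* reshape(k,1,3,1) .* reshape(k,1,1,3);
smoothed = zeros(size(scores));
for c = 1:numClasses
    smoothed(:,:,:,c) = convn(scores(:,:,:,c),k3,"same");
end

% Segmentation: pick the class with the greatest score for each voxel.
[~,classIdx] = max(smoothed,[],4);
labelsCrop = labelIDsDemo(classIdx);  % map channel index to numeric ID

% Resizing: paste the cropped label map into a full-size volume of zeros.
cropStop = cropStart + size(labelsCrop) - 1;
labelsFull = zeros(fullSize);
labelsFull(cropStart(1):cropStop(1), ...
           cropStart(2):cropStop(2), ...
           cropStart(3):cropStop(3)) = labelsCrop;
```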

Overlay a slice from the predicted segmentation map on a corresponding slice from the input volume using the labeloverlay function. Include all the brain structure labels except the background label.

B = labeloverlay(rescale(X(:,:,80)),predictedSegMaps(:,:,80),"IncludedLabels",2:32);
figure
montage({rescale(X(:,:,80)), B})

Quantify Segmentation Accuracy

Measure the segmentation accuracy by comparing the predicted segmentation labels with the ground truth labels drawn by clinical experts.

Create a pixelLabelDatastore (Computer Vision Toolbox) to store the labels. Because the NIfTI file format is a nonstandard image format, you must use a NIfTI file reader to read the pixel label data. You can use the helper NIfTI file reader, niftiReader, defined at the bottom of this example.

pxds = pixelLabelDatastore(labelDirs,classNames,labelIDs,FileExtensions=".gz",...
    ReadFcn=@niftiReader);

Read the ground truth labels from the pixel label datastore.

groundTruthLabel = read(pxds);
groundTruthLabel = groundTruthLabel{1};

Measure the segmentation accuracy using the dice function. This function computes the Dice index between the predicted and ground truth segmentations.

diceResult = zeros(length(classNames),1);
for j = 1:length(classNames)
    diceResult(j) = dice(groundTruthLabel==classNames(j),...
        predictedSegMaps==classNames(j));
end
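For intuition, the Dice index of two binary masks A and B is 2|A∩B|/(|A|+|B|). The following sketch computes it by hand on two small synthetic masks. The masks and variable names are made up for the illustration.

```matlab
% Two synthetic binary masks with partial overlap.
A = false(4,4,2);
A(1:2,1:3,1) = true;     % 6 voxels
B = false(4,4,2);
B(1:2,2:4,1) = true;     % 6 voxels, 4 of them shared with A

overlap = nnz(A & B);                       % size of the intersection
diceManual = 2*overlap/(nnz(A) + nnz(B));   % 2*4/(6+6) = 0.6667
```

With this hand-written formula, a class that is absent from both masks yields 0/0, which is NaN and would propagate into any mean taken over the classes.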

Calculate the average Dice index across all labels for the MRI volume.

meanDiceScore = mean(diceResult);
disp("Average Dice score across all labels = " + num2str(meanDiceScore))
Average Dice score across all labels = 0.80793

A boxplot visualizes statistics about the Dice indices across all the label classes. The red line in the plot shows the median Dice index. The lower and upper bounds of the blue box indicate the 25th and 75th percentiles, respectively. Black whiskers extend to the most extreme data points that are not outliers.

If you have a Statistics and Machine Learning Toolbox™ license, then you can use the boxplot (Statistics and Machine Learning Toolbox) function to visualize statistics about the Dice indices. To create a boxplot, set createBoxplot to true.

createBoxplot = false;
if createBoxplot
    figure
    boxplot(diceResult)
    title("Dice Accuracy")
    xticklabels("All Label Classes")
    ylabel("Dice Coefficient")
end

Supporting Functions

The niftiReader helper function is a custom read function for reading a NIfTI file in a datastore.

function data = niftiReader(filename)
% Read NIfTI file and convert to data type uint8.

% Copyright 2022 The MathWorks, Inc.

    data = niftiread(filename);
    data = uint8(data);
end

References

[1] Billot, Benjamin, Douglas N. Greve, Oula Puonti, Axel Thielscher, Koen Van Leemput, Bruce Fischl, Adrian V. Dalca, and Juan Eugenio Iglesias. “SynthSeg: Domain Randomisation for Segmentation of Brain Scans of Any Contrast and Resolution.” ArXiv:2107.09559 [Cs, Eess], December 21, 2021. http://arxiv.org/abs/2107.09559.

[2] “NITRC: CANDI Share: Schizophrenia Bulletin 2008: Tool/Resource Info.” Accessed October 17, 2022. https://www.nitrc.org/projects/cs_schizbull08/.

[3] Frazier, J. A., S. M. Hodge, J. L. Breeze, A. J. Giuliano, J. E. Terry, C. M. Moore, D. N. Kennedy, et al. “Diagnostic and Sex Effects on Limbic Volumes in Early-Onset Bipolar Disorder and Schizophrenia.” Schizophrenia Bulletin 34, no. 1 (October 27, 2007): 37–46. https://doi.org/10.1093/schbul/sbm120.
