Segment Lungs from CT Scan Using Pretrained Neural Network
This example shows how to import a pretrained ONNX™ (Open Neural Network Exchange) U-Net [1] and use it to perform semantic segmentation of the left and right lungs from a 3-D chest CT scan. Semantic segmentation associates each voxel in a 3-D image with a class label. In this example, you classify each voxel in a test data set as belonging to the left lung, the right lung, or the background. For more information about semantic segmentation, see Semantic Segmentation (Computer Vision Toolbox).
A challenge of applying pretrained networks is the possibility of differences between the intensity and spatial details of a new data set and the data set used to train the network. Preprocessing is typically required to format the data to match the expected network input and achieve accurate segmentation results. In this example, you standardize the spatial orientation and normalize the intensity range of a test data set before applying the pretrained network.
Download Pretrained Network
Specify the desired location of the pretrained network.
dataDir = fullfile(tempdir,"lungmask");
if ~exist(dataDir,"dir")
    mkdir(dataDir);
end
Download the pretrained network from the MathWorks® website by using the downloadTrainedNetwork helper function. The helper function is attached to this example as a supporting file. The network on the MathWorks website is equivalent to the R231 model, available in the LungMask GitHub repository [2], converted to the ONNX format. The size of the pretrained network is approximately 11 MB.
lungmask_url = "https://www.mathworks.com/supportfiles/medical/pretrainedLungmaskR231Net.onnx";
downloadTrainedNetwork(lungmask_url,dataDir);
Downloading pretrained network. This can take several minutes to download... Done.
Import Pretrained Network
Import the ONNX network as a function by using the importONNXFunction (Deep Learning Toolbox) function. You can use this function to import a network with layers that the importONNXNetwork (Deep Learning Toolbox) function does not support. The importONNXFunction function requires the Deep Learning Toolbox™ Converter for ONNX Model Format support package. If this support package is not installed, then importONNXFunction provides a download link.
The importONNXFunction function imports the network and returns an ONNXParameters object that contains the network parameters. When you import the pretrained lung segmentation network, the function displays a warning that the LogSoftmax operator is not supported.
modelfileONNX = fullfile(dataDir,"pretrainedLungmaskR231Net.onnx");
modelfileM = "importedLungmaskFcn_R231.m";
params = importONNXFunction(modelfileONNX,modelfileM);
Function containing the imported ONNX network architecture was saved to the file importedLungmaskFcn_R231.m. To learn how to use this function, type: help importedLungmaskFcn_R231.
Warning: Unable to import some ONNX operators or attributes. They may have been replaced by 'PLACEHOLDER' functions in the imported model function. 1 operator(s) : Operator 'LogSoftmax' is not supported with its current settings or in this context.
Open the generated function, importedLungmaskFcn_R231.m, which is saved as an M file in the current directory. The function contains these lines of code, which indicate that the unsupported LogSoftmax operator is replaced with a placeholder:
% PLACEHOLDER FUNCTION FOR UNSUPPORTED OPERATOR (LogSoftmax):
[Vars.x460, NumDims.x460] = PLACEHOLDER(Vars.x459);
In the function definition, replace the placeholder code with this code. Save the updated function as lungmaskFcn_R231. A copy of lungmaskFcn_R231 with the correct code is also attached to this example as a supporting file.
% Replacement for PLACEHOLDER FUNCTION FOR UNSUPPORTED OPERATOR (LogSoftmax):
Vars.x460 = log(softmax(Vars.x459,'DataFormat','CSSB'));
NumDims.x460 = NumDims.x459;
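The LogSoftmax operator computes the logarithm of the softmax over the class (channel) dimension. As a minimal sketch in base MATLAB, without dlarray, this code compares a numerically stable log-softmax with the direct log-of-softmax form. It assumes hypothetical random scores stored as an H-by-W-by-C-by-B array with the classes along the third dimension, which is not the 'CSSB' layout of the imported function.

```matlab
% Sketch only: hypothetical scores with classes along dimension 3
% (this is NOT the 'CSSB' layout used by the imported function).
rng(0);                                  % reproducible random scores
scores = randn(4,4,3,2);                 % H-by-W-by-C-by-B, C = 3 classes, B = 2

% Numerically stable log-softmax over the class dimension:
% log p_c = x_c - max(x) - log(sum(exp(x - max(x))))
m = max(scores,[],3);
logP = scores - m - log(sum(exp(scores - m),3));

% Direct form, matching log(softmax(x)) for moderate score values
logPDirect = log(exp(scores) ./ sum(exp(scores),3));

maxDiff = max(abs(logP - logPDirect),[],"all");   % agreement of the two forms
probSum = sum(exp(logP),3);                        % class probabilities sum to 1
```

Subtracting the per-pixel maximum before exponentiating keeps exp from overflowing and avoids -Inf values from log when one class score is much larger than the others.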
Save the network parameters, which are stored in the ONNXParameters object params, to a new MAT file.
save("lungmaskParams_R231","params");
Load Data
Test the pretrained lung segmentation network on a test data set. The test data is a CT chest volume from the Medical Segmentation Decathlon data set [3]. Download the MedicalVolumeNIfTIData ZIP archive from the MathWorks website, then unzip the file. The ZIP file contains two CT chest volumes and corresponding label images, stored in the NIfTI file format. The total size of the data set is approximately 76 MB.
zipFile = matlab.internal.examples.downloadSupportFile("medical","MedicalVolumeNIfTIData.zip");
filePath = fileparts(zipFile);
unzip(zipFile,filePath)
dataFolder = fullfile(filePath,"MedicalVolumeNIfTIData");
Specify the file name of the first CT volume.
fileName = fullfile(dataFolder,"lung_027.nii.gz");
Create a medicalVolume object for the CT volume and display it.
medVol = medicalVolume(fileName);
volshow(medVol,RenderingStyle="GradientOpacity");
Preprocess Test Data
Preprocess the test data to match the expected orientation and intensity range of the pretrained network. First, extract the voxel data from the medicalVolume object.
V = medVol.Voxels;
Rotate the test image volume in the transverse plane to match the expected input orientation for the pretrained network. The network was trained using data oriented with the patient bed at the bottom of the image, so the test data must be oriented in the same direction. If you use different test data, you might need a different spatial transformation to match the orientation that the network expects.
rotationAxis = [0 0 1];
volAligned = imrotate3(V,90,rotationAxis);
Display a slice of the rotated volume to check the updated orientation.
imshow(volAligned(:,:,150),[])
Use intensity normalization to rescale the voxel intensities in the region of interest to the range [0, 1], which is the range that the pretrained network expects. The first step in intensity normalization is to determine the range of intensity values within the region of interest. The values are in Hounsfield units. To choose the thresholds for the intensity range, plot a histogram of the voxel intensity values. Set the x-limits of the histogram plot to the minimum and maximum intensity values, and mark the candidate thresholds with vertical red lines. The histogram has two large peaks. The first peak corresponds to background voxels outside the body of the patient and to air in the lungs. The second peak corresponds to soft tissue such as the heart and stomach.
figure
histogram(V)
xlim([min(V,[],"all") max(V,[],"all")])
ylim([0 2e6])
xlabel("Intensity [Hounsfield Units]")
ylabel("Number of Voxels")
xline([-1024 500],"red",LineWidth=1)
To limit the intensities to the range that contains most of the tissue in the region of interest, set the lower and upper thresholds of the intensity range to –1024 and 500, respectively.
th = [-1024 500];
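To check how much of the volume the selected window keeps, you can count the voxels whose intensities fall within th. This sketch uses a small synthetic stand-in for V (hypothetical values, not the Decathlon data) so that it runs on its own; with the real data, replace Vdemo with V.

```matlab
% Hypothetical stand-in for V: synthetic intensities in Hounsfield units
% (air/background, soft tissue, and a few very bright voxels).
rng(0);
Vdemo = int16([-1024*ones(1000,1); round(40 + 20*randn(800,1)); 1200*ones(50,1)]);
th = [-1024 500];

% Voxels inside the intensity window [th(1), th(2)]
inWindow = Vdemo >= th(1) & Vdemo <= th(2);
fracInWindow = nnz(inWindow)/numel(Vdemo);
fprintf("%.1f%% of voxels lie within [%d, %d] HU\n",100*fracInWindow,th(1),th(2));
```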
Apply the preprocessLungCT helper function to further preprocess the test image volume. The helper function is attached to this example as a supporting file. The preprocessLungCT function performs these steps:
- Resize each 2-D transverse slice to the target size, imSize. Decreasing the number of voxels can improve prediction speed. Set the target size to 256-by-256 voxels.
- Clip the voxel intensities to the range specified by the thresholds in th.
- Normalize the updated voxel intensities to the range [0, 1].
imSize = [256 256];
volInp = preprocessLungCT(volAligned,imSize,th);
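The code of the preprocessLungCT helper is not reproduced here. As a rough sketch of the three listed steps, applied to a small synthetic volume, the preprocessing could look like this. The single-precision output and the default imresize interpolation are assumptions and may differ from the actual helper.

```matlab
% Hypothetical re-implementation of the listed steps; NOT the shipped
% preprocessLungCT helper, whose interpolation and output class may differ.
rng(0);
volDemo = int16(randi([-2000 2000],64,64,10));   % stand-in for volAligned
imSize = [256 256];
th = [-1024 500];

numSlices = size(volDemo,3);
volOut = zeros([imSize numSlices],"single");
for k = 1:numSlices
    % Step 1: resize each transverse slice to the target size
    volOut(:,:,k) = imresize(single(volDemo(:,:,k)),imSize);
end

% Step 2: clip intensities to the threshold range
volOut = min(max(volOut,th(1)),th(2));

% Step 3: normalize the clipped range linearly to [0, 1]
volOut = (volOut - th(1)) / (th(2) - th(1));
```

Clipping after resizing also removes any overshoot that bicubic interpolation can introduce beyond the original intensity range.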
Segment Test Data and Postprocess Predicted Labels
Segment the test CT volume by using the lungSeg helper function. The helper function is attached to this example as a supporting file. The lungSeg function performs inference using the pretrained network and then postprocesses the network output to obtain the segmentation mask.
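A common way to turn per-class network outputs into a label mask is to pick the class with the largest score at each pixel. This sketch assumes a hypothetical score array for one slice with three channels ordered as background, right lung, and left lung; that channel order is an assumption, not taken from this example. Because log-softmax subtracts the same constant from every class score at a pixel, taking the maximum of the log-probabilities selects the same class as taking the maximum of the raw scores.

```matlab
% Hypothetical 4-by-4 slice with 3 class channels (order assumed:
% 1 = background, 2 = right lung, 3 = left lung).
scores = zeros(4,4,3);
scores(:,:,1) = 1;          % background wins by default
scores(1:2,:,2) = 2;        % rows 1-2: channel 2 wins
scores(:,4,3) = 3;          % column 4: channel 3 wins

% Winning channel per pixel, converted to labels 0, 1, 2
[~,classIdx] = max(scores,[],3);
labels = uint8(classIdx - 1);

% Log-softmax does not change which class wins
logP = scores - log(sum(exp(scores),3));
[~,classIdxLog] = max(logP,[],3);
sameClass = isequal(classIdx,classIdxLog);
```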
To decrease the required computation time, the lungSeg function performs inference on the slices of a volume in batches. Specify the batch size as eight slices by using the batchSize name-value argument of lungSeg. Increasing the batch size can speed up inference, but requires more memory. If you run out of memory, try decreasing the batch size.
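Batched inference can be sketched as a loop that steps through the slices batchSize at a time, with a smaller final batch when the number of slices is not a multiple of the batch size. In this sketch, predictBatch is a hypothetical placeholder for network inference (it simply thresholds its input), and the H-by-W-by-1-by-B batch layout is an assumption rather than the layout that lungSeg uses.

```matlab
% Hypothetical placeholder "network": thresholds each slice.
predictBatch = @(x) x > 0.5;

rng(0);
vol = rand(32,32,21,"single");   % stand-in preprocessed volume with 21 slices
batchSize = 8;

numSlices = size(vol,3);
labelsOut = false(size(vol));
numBatches = 0;
for startIdx = 1:batchSize:numSlices
    % Slice indices for this batch; the last batch may be smaller
    idx = startIdx:min(startIdx+batchSize-1,numSlices);
    % Stack the slices along the 4th (batch) dimension: H-by-W-by-1-by-B
    batch = reshape(vol(:,:,idx),size(vol,1),size(vol,2),1,numel(idx));
    out = predictBatch(batch);
    % Write the batch results back to their slice positions
    labelsOut(:,:,idx) = reshape(out,size(vol,1),size(vol,2),numel(idx));
    numBatches = numBatches + 1;
end
```

With 21 slices and a batch size of 8, the loop runs three batches of 8, 8, and 5 slices.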
During postprocessing, the lungSeg helper function uses the modefilt function to apply a mode filter to the network output, which smooths the segmentation labels. You can set the size of the mode filter by using the modeFilt name-value argument of lungSeg. The default filter size is [9 9 9].
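To see the effect of the mode filter, this sketch applies modefilt with the same [9 9 9] filter size to a small synthetic label volume that contains one block of lung labels and one isolated, mislabeled voxel. The volume and its label values are hypothetical.

```matlab
% Synthetic label volume: one labeled block plus one spurious voxel.
labels = zeros(30,30,30,"uint8");
labels(6:20,6:20,6:20) = 1;          % hypothetical "lung" region
labels(26,26,26) = 2;                % isolated mislabeled voxel

% Replace each voxel with the most frequent label in its 9-by-9-by-9 neighborhood
smoothed = modefilt(labels,[9 9 9]);
```

The isolated voxel takes the most common label in its neighborhood (background), while voxels deep inside the block keep their label. Larger filters can remove bigger spurious regions but can also erode thin structures near label boundaries.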
labelOut = lungSeg(volInp,batchSize=8);
Display Predicted Segmentation Labels
Display the segmentation results by using the volshow function. Use the OverlayData name-value argument to plot the predicted segmentation labels. To focus on the label data, use the Alphamap name-value argument to set the opacity of the image volume to 0 and the OverlayAlpha name-value argument to set the opacity of the labels to 0.9.
volshow(volInp, ...
    Alphamap=0, ...
    OverlayData=labelOut, ...
    OverlayAlpha=0.9);
You can also display the preprocessed test volume as slice planes with the predicted segmentation labels as an overlay by setting the RenderingStyle name-value argument to "SlicePlanes". Specify the lung segmentation labels by using the OverlayData name-value argument.
volshow(volInp, ...
    RenderingStyle="SlicePlanes", ...
    OverlayData=labelOut, ...
    OverlayAlpha=0.9);
Click and drag the mouse to rotate the volume. To scroll in a plane, pause on the slice you want to investigate until it becomes highlighted, then click and drag. The left and right lung segmentation masks are visible in the slices for which they are defined.
References
[1] Hofmanninger, Johannes, Florian Prayer, Jeanny Pan, Sebastian Röhrich, Helmut Prosch, and Georg Langs. “Automatic Lung Segmentation in Routine Imaging Is Primarily a Data Diversity Problem, Not a Methodology Problem.” European Radiology Experimental 4, no. 1 (December 2020): 50. https://doi.org/10.1186/s41747-020-00173-2.
[2] GitHub. “Automated Lung Segmentation in CT under Presence of Severe Pathologies.” Accessed July 21, 2022. https://github.com/JoHof/lungmask.
[3] Medical Segmentation Decathlon. “Lung.” Tasks. Accessed May 10, 2018. http://medicaldecathlon.com. The Medical Segmentation Decathlon data set is provided under the CC-BY-SA 4.0 license. All warranties and representations are disclaimed. See the license for details.
See Also
importONNXFunction (Deep Learning Toolbox) | modefilt | volshow