Main Content

使用深度学习进行三维脑肿瘤分割

此示例说明如何基于三维医学图像执行脑肿瘤的语义分割。

语义分割中,图像中的每个像素或三维体的每个体素都被标注为某一类。此示例在磁共振成像 (MRI) 扫描中使用三维 U-Net 深度学习网络执行脑肿瘤二元语义分割。U-Net 是一种快速、高效、简单的网络,在语义分割领域 [1] 非常流行。

医学图像分割面临的挑战之一是存储和处理三维体所需的内存量。由于 GPU 资源的限制,基于整个输入数据体训练网络并执行分割是不切实际的。此示例通过将图像分成更小的补片(或块)以训练和分割来解决问题。

医学图像分割面临的另一挑战是当使用常规的交叉熵损失时,数据中的类不平衡会妨碍训练。此示例通过使用加权多类 Dice 损失函数 [4] 解决此问题。对类进行加权有助于抵消较大区域对 Dice 分数的影响,使网络更容易学习如何分割较小区域。

此示例说明如何使用预训练的三维 U-Net 架构来执行脑肿瘤分割,以及如何使用一组测试图像来评估网络性能。您可以选择基于 BraTS 数据集 [2] 训练三维 U-Net。

加载预训练的三维 U-Net

将预训练的三维 U-Net 下载到名为 trainedNet 的变量中。您可以使用预训练网络来运行示例,而无需等待训练完成。

dataDir = fullfile(tempdir,"BraTS");
if ~exist(dataDir,"dir")
    mkdir(dataDir);
end
trained3DUnetURL = "https://www.mathworks.com/supportfiles/"+ ...
    "vision/data/brainTumorSegmentation3DUnet_v2.zip";
downloadTrainedNetwork(trained3DUnetURL,dataDir);
load(fullfile(dataDir,"brainTumorSegmentation3DUnet_v2.mat"));

执行语义分割

使用预训练网络来预测测试 MRI 数据体的肿瘤标签。

加载 BraTS 样本数据

此示例使用 BraTS 数据集 [2]。该完整数据集包含来自 750 名患者的脑肿瘤的标注 MRI 扫描。要尝试使用预训练网络,请从 MathWorks® 网站下载五个 MRI 扫描及其对应标签的子集。使用 downloadBraTSSampleTestData 辅助函数 [3] 下载样本数据,该函数作为支持文件包含在示例中。

downloadBraTSSampleTestData(dataDir);

加载其中一个 MRI 数据体及其像素标签真实值。每个数据体是一个四维数组,其中前三个维度对应于图像数据的高度、宽度和深度,第四个维度的每个页包含一个不同 MRI 模态。

testDir = fullfile(dataDir,"sampleBraTSTestSetValid");
data = load(fullfile(testDir,"imagesTest","BraTS446.mat"));
labels = load(fullfile(testDir,"labelsTest","BraTS446.mat"));
volTest = data.cropVol;
volTestLabels = labels.cropLabel;

分割分块图像中的脑肿瘤

为了高效处理大量数据,此示例将每个 MRI 扫描作为 blockedImage 对象中的一系列三维块进行处理。使用 apply 函数,该网络会预测每个块的标签,然后将这些块重新组合为一个完整的分段数据体。

为下载的样本数据体创建一个 blockedImage 对象。

bim = blockedImage(volTest);

apply 函数对 blockedImage 中的每个块都执行自定义函数。将 semanticsegBlock 定义为应用于每个块的函数。为了预测 classNames 中类的标签,semanticsegBlock 函数使用 semanticseg (Computer Vision Toolbox) 函数应用预训练网络。

classNames = ["background","tumor"];
semanticsegBlock = @(bstruct)semanticseg(bstruct.Data,trainedNet,Classes=classNames);

当您调用 apply 函数时,您可以指定块大小和块之间的边界大小等选项。指定块大小以匹配网络输出大小。通过将与网络输入大小相同的示例图像传递给预训练网络来获得网络输出大小。

networkInputSize = trainedNet.Layers(1).InputSize;

sampleImage = rand(networkInputSize,"single");
sampleOutput = predict(trainedNet,sampleImage);

networkOutputSize = size(sampleOutput);

blockSize = [networkOutputSize(1:3) networkInputSize(end)];

要创建重叠块,请指定非零的边界大小。此示例使用边界大小,使块加边界的大小等于网络输入大小。

borderSize = (networkInputSize(1:3) - blockSize(1:3))/2;

semanticsegBlock 函数应用于测试图像的所有块。将不完整块的填充指定为 true。默认填充方法 "replicate" 就很适用,因为数据体数据包含多个模态。将批量大小指定为 1,以防止内存资源有限的机器出现内存不足错误。但是,如果您的机器有足够的内存,则可以通过增大块的大小来提高处理速度。

batchSize = 1;
results = apply(bim, ...
    semanticsegBlock, ...
    BlockSize=blockSize, ...
    BorderSize=borderSize, ...
    PadPartialBlocks=true, ...
    BatchSize=batchSize, ...
    UseParallel=canUseGPU);
predictedLabels = gather(results);

使用蒙太奇显示真实值和预测标签沿深度方向的中心切片。

zID = size(volTest,3)/2;
zSliceGT = labeloverlay(volTest(:,:,zID),volTestLabels(:,:,zID));
zSlicePred = labeloverlay(volTest(:,:,zID),predictedLabels(:,:,zID));

figure
montage({zSliceGT,zSlicePred},Size=[1 2],BorderSize=5) 
title("Labeled Ground Truth (Left) vs. Network Prediction (Right)")

此动画显示一个测试数据体的真实值和预测标签的横向切片的并排视图。标注的真实值在左侧,网络预测在右侧。

Animation scrolling through the transverse slices of the ground truth and predicted labels for one test volume

准备要训练的数据

下载完整的 BraTS 数据集

要执行训练,您必须下载完整的 BraTS 数据集。数据集的总大小约为 7 GB,其中包含来自 750 名患者的数据。如果您不想下载训练数据集或训练网络,则可以跳到此示例的评估网络性能部分。

要下载 BraTS 数据,请访问 Medical Segmentation Decathlon 网站,然后点击 Download Data 链接。下载 Task01_BrainTumour.tar 文件 [3]。将该 TAR 文件解压缩到由 imageDir 变量指定的目录中。成功解压缩后,imageDir 包含一个名为 Task01_BrainTumour 的目录,该目录有三个子目录:imagesTrimagesTslabelsTr

与样本数据集类似,每个扫描均为一个四维数组,其中前三个维度对应于三维图像的高度、宽度和深度,第四个维度的每页对应于一个不同模态。该数据集分为 484 个带体素标签的训练数据体和 266 个测试数据体。测试数据体没有标签,因此此示例不使用测试数据。在这种情况下,示例将 484 个训练数据体拆分成三个独立的数据集,分别用于训练、验证和测试。

预处理数据

为了更高效地训练三维 U-Net 网络,通过使用 preprocessBraTSDataset 辅助函数预处理 MRI 数据。此函数作为支持文件包含在本示例中。该辅助函数执行以下操作:

  • 将数据裁剪到包含大脑和肿瘤的区域。裁剪可以减小数据的大小,同时保留每个 MRI 数据体的最关键部分及其对应标签。

  • 通过减去均值并除以裁剪后的大脑区域的标准差,独立地对每个数据体的每个形态进行归一化。

  • 将 484 个训练数据体拆分成 400 个训练集、29 个验证集和 55 个测试集。

数据预处理可能需要大约 30 分钟才能完成。

sourceDataLoc = fullfile(dataDir,"Task01_BrainTumour");
preprocessDataLoc = fullfile(dataDir,"preprocessedDataset");
if ~isfolder(preprocessDataLoc)
    preprocessBraTSDataset(preprocessDataLoc,sourceDataLoc);
end

为训练和验证创建随机补片提取数据存储

要读取和管理三维训练图像数据,请创建一个 imageDatastore 对象。指定自定义读取函数 matRead,以从 MAT 文件中读取图像数据。matRead 辅助函数作为支持文件包含在本示例中。

volLocTrain = fullfile(preprocessDataLoc,"imagesTr");
voldsTrain = imageDatastore(volLocTrain,FileExtensions=".mat",ReadFcn=@matRead);

要读取和管理标签,请创建一个 pixelLabelDatastore (Computer Vision Toolbox) 对象。指定与为 classNames 变量定义的类名称相同的类名称。像素标签 ID 1 映射到 "tumor" 类名称,像素标签 ID 0 映射到 "background" 类名称。

disp(classNames)
    "background"    "tumor"
pixelLabelID = [0 1];
lblLocTrain = fullfile(preprocessDataLoc,"labelsTr");

pxdsTrain = pixelLabelDatastore(lblLocTrain,classNames,pixelLabelID, ...
    FileExtensions=".mat",ReadFcn=@matRead);

创建一个 randomPatchExtractionDatastore 对象,它从真实值图像提取随机补片及其对应的像素标签数据。指定 132×132×132 体素的补片大小。指定 "PatchesPerImage" 名称-值参量以在训练期间从每对数据体和标签中提取 16 个随机定位的补片。指定小批量大小为 4。

patchSize = [132 132 132];
patchPerImage = 16;
miniBatchSize = 4;
patchdsTrain = randomPatchExtractionDatastore(voldsTrain,pxdsTrain,patchSize, ...
    PatchesPerImage=patchPerImage);
patchdsTrain.MiniBatchSize = miniBatchSize;

同样,要管理验证数据集,请创建 imageDatastorepixeLabelDatastorerandomPatchExtractionDatastore 对象。您可以使用验证数据来评估网络在训练过程中是在持续学习、欠拟合还是过拟合。

volLocVal = fullfile(preprocessDataLoc,"imagesVal");
voldsVal = imageDatastore(volLocVal,FileExtensions=".mat", ...
    ReadFcn=@matRead);

lblLocVal = fullfile(preprocessDataLoc,"labelsVal");
pxdsVal = pixelLabelDatastore(lblLocVal,classNames,pixelLabelID, ...
    FileExtensions=".mat",ReadFcn=@matRead);

patchdsVal = randomPatchExtractionDatastore(voldsVal,pxdsVal,patchSize, ...
    PatchesPerImage=patchPerImage);
patchdsVal.MiniBatchSize = miniBatchSize;

定义三维 U-Net 网络架构

此示例使用三维 U-Net 网络 [1]。U-Net 中的初始卷积层序列与最大池化层交叠,从而逐步降低输入图像的分辨率。这些层后跟一系列使用上采样算子散布的卷积层,从而会连续增加输入图像的分辨率。在每个 ReLU 层之前引入一个批量归一化层。U-Net 的名称源于网络可以绘制成形似字母 U 的对称形状。

使用 unet3d (Computer Vision Toolbox) 函数创建一个默认的三维 U-Net 网络。指定二类分割。为了避免边界伪影,还要指定有效的卷积填充。

numChannels = 4;
inputPatchSize = [patchSize numChannels];
numClasses = 2;
[net,outPatchSize] = unet3d(inputPatchSize, ...
    numClasses,ConvolutionPadding="valid");

通过使用 transform 函数和 augmentAndCrop3dPatch 辅助函数指定的自定义预处理操作来增强训练数据和验证数据。此函数作为支持文件包含在本示例中。augmentAndCrop3dPatch 函数执行以下操作:

  1. 随机旋转和翻转训练数据,使训练更加稳健。该函数不旋转或翻转验证数据。

  2. 将响应补片裁剪为网络的输出大小 44×44×44 体素。

dsTrain = transform(patchdsTrain, ...
    @(patchIn)augmentAndCrop3dPatch(patchIn,outPatchSize,"Training"));
dsVal = transform(patchdsVal, ...
    @(patchIn)augmentAndCrop3dPatch(patchIn,outPatchSize,"Validation"));

由于数据已在此示例的准备用于训练的数据部分中进行归一化处理,因此不需要在 image3dInputLayer (Deep Learning Toolbox) 中进行数据归一化处理,请将输入层替换为不进行数据归一化处理的输入层。

inputLayer = image3dInputLayer(inputPatchSize, ...
    Normalization="none",Name="ImageInputLayer");
net = replaceLayer(net,net.Layers(1).Name,inputLayer);

您也可以使用深度网络设计器修改三维 U-Net 网络。

deepNetworkDesigner(net)

训练三维 U-Net

指定训练选项

使用 adam 优化求解器来训练网络。使用 trainingOptions (Deep Learning Toolbox) 函数指定超参数设置。初始学习率设置为 5e-4,并在训练期间逐渐降低。您可以根据您的 GPU 内存情况尝试调整 MiniBatchSize 属性。为了最大限度地利用 GPU 内存,最好使用大输入补片而不是大批量大小。请注意,批量归一化层对于较小的 MiniBatchSize 值不是很有效。根据 MiniBatchSize 调整初始学习率。

options = trainingOptions("adam", ...
    MaxEpochs=50, ...
    InitialLearnRate=5e-4, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropPeriod=5, ...
    LearnRateDropFactor=0.95, ...
    ValidationData=dsVal, ...
    ValidationFrequency=400, ...
    Plots="training-progress", ...
    Verbose=false, ...
    MiniBatchSize=miniBatchSize);

定义损失函数

定义一个自定义损失函数 generalizedDiceLoss,该函数接受预测值 Y 和目标值 T 并返回广义 Dice 损失。广义 Dice 相似性系数用于衡量两个分割图像之间的重叠程度。该系数基于 Sørensen-Dice 相似性,并通过按预期区域的大小的倒数对类进行加权,以控制每个类对相似性所做的贡献。有关详细信息,请参阅 generalizedDice (Computer Vision Toolbox)

function loss = generalizedDiceLoss(Y,T)
% Copyright 2024 The MathWorks, Inc.

% Ignore any NaNs introduced to the training data during augmentation
T(isnan(T)) = 0;

z = generalizedDice(Y,T);

% Compute the mean of the Dice loss across the batch
loss = 1 - mean(z,"all");

end

训练网络

默认情况下,该示例使用下载的预训练三维 U-Net 网络。借助预训练网络,您无需等待训练完成,即可执行语义分割和评估分割结果。

要训练网络,请将以下代码中的 doTraining 变量设置为 true。使用 trainnet (Deep Learning Toolbox) 函数训练网络。默认情况下,trainnet 函数使用 GPU(如果有)。在 GPU 上进行训练需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。有关受支持设备的信息,请参阅GPU Computing Requirements (Parallel Computing Toolbox)。否则,trainnet 函数使用 CPU。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

doTraining = false;
if doTraining
    [trainedNet,info] = trainnet(dsTrain,net,@generalizedDiceLoss,options);
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save("trained3DUNet-"+modelDateTime+".mat","trainedNet");
end

评估三维 U-Net

指定用于测试网络的数据的位置。默认情况下,useFullTestSetfalse,该示例使用在加载样本 BraTS 数据部分中下载的五个样本数据体。如果将 useFullTestSet 值更改为 true,则该示例将使用完整数据集中分配给测试的 55 个扫描。

useFullTestSet = false;
if useFullTestSet
    volLocTest = fullfile(preprocessDataLoc,"imagesTest");
    lblLocTest = fullfile(preprocessDataLoc,"labelsTest");
else
    volLocTest = fullfile(testDir,"imagesTest");
    lblLocTest = fullfile(testDir,"labelsTest");
end

与训练数据集和验证数据集类似,通过分别创建 imageDatastore 对象和 pixelLabelDatastore 对象来管理测试数据集的图像数据和标签。

voldsTest = imageDatastore(volLocTest,FileExtensions=".mat", ...
    ReadFcn=@matRead);

pxdsTest = pixelLabelDatastore(lblLocTest,classNames,pixelLabelID, ...
    FileExtensions=".mat",ReadFcn=@matRead);

对于每个测试数据体,将数据体作为 blockedImage 对象读取,并使用 apply 函数处理每个块。apply 函数执行由 calculateBlockMetrics 辅助函数指定的运算,该辅助函数在此示例的末尾定义。calculateBlockMetrics 函数执行每个块的语义分割,并计算预测标签和真实值标签之间的混淆矩阵。

imageIdx = 1;
datasetConfMat = table;
while hasdata(voldsTest)

    % Read volume and label data
    vol = read(voldsTest);
    volLabels = read(pxdsTest);

    % Create blockedImage for volume and label data
    testVolume = blockedImage(vol);
    testLabels = blockedImage(volLabels{1});

    % Calculate block metrics
    blockConfMatOneImage = apply(testVolume, ...
        @(block,labeledBlock) ...
            calculateBlockMetrics(block,labeledBlock,trainedNet,classNames), ...
        ExtraImages=testLabels, ...
        PadPartialBlocks=true, ...
        BlockSize=blockSize, ...
        BorderSize=borderSize, ...
        UseParallel=false);

    % Read all the block results of an image and update the image number
    blockConfMatOneImageDS = blockedImageDatastore(blockConfMatOneImage);
    blockConfMat = readall(blockConfMatOneImageDS);
    blockConfMat = struct2table([blockConfMat{:}]);
    blockConfMat.ImageNumber = imageIdx.*ones(height(blockConfMat),1);
    datasetConfMat = [datasetConfMat; blockConfMat];

    imageIdx = imageIdx + 1;
end

使用 evaluateSemanticSegmentation (Computer Vision Toolbox) 函数评估分割的数据集度量和数据块度量。

[metrics,blockMetrics] = evaluateSemanticSegmentation( ...
    datasetConfMat,classNames,Metrics="all");
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU.
* Processed 5 images.
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU
    ______________    ____________    _______    ___________

       0.99902          0.97292       0.95915      0.99807  

显示为每个图像计算的杰卡德分数。

metrics.ImageMetrics.MeanIoU
ans = 5×1

    0.9676
    0.9521
    0.9568
    0.9537
    0.9635

支持函数

calculateBlockMetrics 辅助函数执行块的语义分割,并计算预测标签和真实值标签之间的混淆矩阵。该函数返回一个结构体,其字段包含关于该块的混淆矩阵和元数据。您可以将该结构体与 evaluateSemanticSegmentation 函数结合使用来计算度量和聚合基于块的结果。

function blockMetrics = calculateBlockMetrics(bstruct,gtBlockLabels,net,classNames)

% Segment block
predBlockLabels = semanticseg(bstruct.Data,net,Classes=classNames);

% Trim away border region from gtBlockLabels 
blockStart = bstruct.BorderSize + 1;
blockEnd = blockStart + bstruct.BlockSize - 1;
gtBlockLabels = gtBlockLabels( ...
    blockStart(1):blockEnd(1), ...
    blockStart(2):blockEnd(2), ...
    blockStart(3):blockEnd(3));

% Evaluate segmentation results against ground truth
confusionMat = segmentationConfusionMatrix(predBlockLabels,gtBlockLabels);

% blockMetrics is a struct with confusion matrices, image number,
% and block information 
blockMetrics.ConfusionMatrix = confusionMat;
blockMetrics.ImageNumber = bstruct.ImageNumber;
blockInfo.Start = bstruct.Start;
blockInfo.End = bstruct.End;
blockMetrics.BlockInfo = blockInfo;

end

参考资料

[1] Çiçek, Özgün, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, and Olaf Ronneberger.“3D U-Net:Learning Dense Volumetric Segmentation from Sparse Annotation.”In Medical Image Computing and Computer-Assisted Intervention – MICCAI 2016, edited by Sebastien Ourselin, Leo Joskowicz, Mert R. Sabuncu, Gozde Unal, and William Wells, 9901:424–32.Cham:Springer International Publishing, 2016. https://doi.org/10.1007/978-3-319-46723-8_49.

[2] Isensee, Fabian, Philipp Kickingereder, Wolfgang Wick, Martin Bendszus, and Klaus H. Maier-Hein.“Brain Tumor Segmentation and Radiomics Survival Prediction:Contribution to the BRATS 2017 Challenge.”In Brainlesion:Glioma, Multiple Sclerosis, Stroke and Traumatic Brain Injuries, edited by Alessandro Crimi, Spyridon Bakas, Hugo Kuijf, Bjoern Menze, and Mauricio Reyes, 10670:287–97.Cham:Springer International Publishing, 2018. https://doi.org/10.1007/978-3-319-75238-9_25.

[3] "Brain Tumours".Medical Segmentation Decathlon. http://medicaldecathlon.com/

BraTS 数据集由 Medical Segmentation Decathlon 在 CC-BY-SA 4.0 许可证下提供。所有保证和陈述均有相应的免责声明;有关详细信息,请参阅许可文档。MathWorks® 已修改此示例中加载样本 BraTS 数据部分中链接的数据集。修改后的样本数据集已裁剪为主要包含大脑和肿瘤的区域,并且通过减去均值并除以裁剪后的大脑区域的标准差,独立地对每个通道进行归一化。

[4] Sudre, Carole H., Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M. Jorge Cardoso.“Generalised Dice Overlap as a Deep Learning Loss Function for Highly Unbalanced Segmentations.”In Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support, edited by M. Jorge Cardoso, Tal Arbel, Gustavo Carneiro, Tanveer Syeda-Mahmood, João Manuel R.S. Tavares, Mehdi Moradi, Andrew Bradley, et al., 10553:240–48.Cham:Springer International Publishing, 2017. https://doi.org/10.1007/978-3-319-67558-9_28.

[5] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox."U-Net:Convolutional Networks for Biomedical Image Segmentation.”In Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, edited by Nassir Navab, Joachim Hornegger, William M. Wells, and Alejandro F. Frangi, 9351:234–41.Cham:Springer International Publishing, 2015. https://doi.org/10.1007/978-3-319-24574-4_28.

另请参阅

| | | (Computer Vision Toolbox) | (Deep Learning Toolbox) | (Computer Vision Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

相关主题

外部网站