Perform 3-D U-Net Deep Learning with non-cubic input
Hi,
I have used the tutorial "3-D Brain Tumor Segmentation Using Deep Learning" (https://it.mathworks.com/help/images/segment-3d-brain-tumor-using-deep-learning.html) to create and train a 3-D U-Net neural network that segments kidneys out of 3-D MRI scans. In general, the set-up of my script is the same; I have only changed a few things:
- The U-Net I created has 5 layers (instead of 4 in the original example)
- The training volumes have a size of 64*80*32*1 (instead of patches of size 32*32*32*4). No patches were used; the full .mat volumes of 64*80*32*1 were used to train the 3-D U-Net.
- The training volumes were of datatype 'single' (with values between 0 and 1)
- The label volumes were of datatype 'logical'
The training options were:
options = trainingOptions('sgdm', ...
    'MaxEpochs',40, ...
    'InitialLearnRate',1e-2, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',5, ...
    'LearnRateDropFactor',0.95, ...
    'ValidationData',dsVal, ...
    'ValidationFrequency',4, ...
    'Plots','training-progress', ...
    'Verbose',true, ...
    'MiniBatchSize',1);
dsVal was a pixelLabelImageDatastore (instead of a randomPatchExtractionDatastore since I do not patch the data).
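A minimal sketch of how such a datastore can be assembled for full (unpatched) volumes; the folder names below are placeholders and matRead is the small .mat reading helper from the MathWorks example, so this is illustrative rather than my exact code:
% Illustrative only: validation datastore of full 64*80*32*1 volumes
dataDir = 'myDataFolder';                      % placeholder path
classNames = ["background","kidney"];
labelIDs = [0 1];
voldsVal = imageDatastore(fullfile(dataDir,'imagesVal'), ...
    'FileExtensions','.mat','ReadFcn',@matRead);
pxdsVal = pixelLabelDatastore(fullfile(dataDir,'labelsVal'),classNames,labelIDs, ...
    'FileExtensions','.mat','ReadFcn',@matRead);
dsVal = pixelLabelImageDatastore(voldsVal,pxdsVal);
The test datastores voldsTest and pxdsTest mentioned below are built in the same way from the test folders.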
For testing, the test data also consisted of 64*80*32*1 'single' volumes, stored in an imageDatastore (voldsTest). The ground truth labels were logicals (0 for background, 1 for kidney) stored in a pixelLabelDatastore (pxdsTest).
% For each test volume id, read the data and segment it with the trained net
groundTruthLabels{id} = read(pxdsTest);
vol{id} = read(voldsTest);
vol2 = vol{id};
tempSeg = semanticseg(vol2,net);
predictedLabels{id} = tempSeg;   % stored for the visualisation below
To make sure that the script functioned correctly, I have tested the script with generated dummy data (with circles randomly placed in a 64*80*32*1 matrix and noise added to this volume):
Figure 1. Left: middle slice of the 64*80*32*1 volume fed to the 3-D U-Net. Right: middle slice of the 64*80*32*1 ground truth label.
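Such dummy volumes can be generated roughly as follows (an illustrative sketch with placeholder parameter values, not my exact generation code):
% Illustrative only: one bright sphere at a random position in a 64*80*32
% volume plus Gaussian noise, and the matching logical ground truth label
volSize = [64 80 32];
[X,Y,Z] = ndgrid(1:volSize(1),1:volSize(2),1:volSize(3));
c = [randi(volSize(1)) randi(volSize(2)) randi(volSize(3))]; % random centre
r = 6;                                                       % placeholder radius
label = (X-c(1)).^2 + (Y-c(2)).^2 + (Z-c(3)).^2 <= r^2;      % logical label
vol = 0.8*single(label) + 0.1*single(randn(volSize));        % add noise
vol = mat2gray(vol);                                         % rescale to [0,1], stays 'single'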
However, after training with this data I get peculiar results. Although the accuracy was around 80% and the loss was around 0.3 on the validation data, the Dice coefficients on the test data are low. For visualising the predicted segmentations, I used:
% Visualisation of the first test volume and its predicted label
volId = 1;
vol3d = vol{volId}(:,:,:,1);
zID = 16;                        % middle slice
zSliceGT = labeloverlay(vol3d(:,:,zID),groundTruthLabels{volId}(:,:,zID));
zSlicePred = labeloverlay(vol3d(:,:,zID),predictedLabels{volId}(:,:,zID));
figure
montage({zSliceGT,zSlicePred},'Size',[1 2])
In the visualisation, a typical striped pattern can be seen in the predicted label:
Figure 2. Left: middle slice of the ground truth 64*80*32*1 volume with the ground truth segmentation overlayed in light blue. Right: middle slice of the 64*80*32*1 test volume with predicted segmentation overlayed in light blue. The light blue striped pattern can be seen.
When only visualising the labels for another test volume, this striped pattern can also be seen:
Figure 3.
When I use 32*32*32*1 dummy data for training and I only change the inputLayer sizes of the network, the test segmentations look normal and I do not get this striped segmentation pattern.
Therefore, my question is: does anyone know how this problem arises? It probably has something to do with the non-cubic input (64*80*32 instead of 32*32*32). Does this mean that the MATLAB 3-D U-Net cannot handle non-cubic input, or is there a mistake in my code?
Answers (1)
Divya Gaddipati
2019-8-7
The input size might not be the problem, as you were able to train the network without any errors.
In medical imaging, "accuracy" is not a good indicator of network performance. Since this is a segmentation task, check your network's performance during training by calculating the Dice score.
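For example, after segmenting the test or validation volumes with semanticseg, the per-class Dice score can be computed with the dice function from the Image Processing Toolbox. This sketch re-uses the variable names from your question and assumes the class order ["background","kidney"], so treat it as illustrative:
% Illustrative only: per-volume Dice scores for the kidney class
numTest = numel(predictedLabels);          % assumed number of test volumes
diceKidney = zeros(numTest,1);
for id = 1:numTest
    % both inputs are categorical label volumes, so dice() returns
    % one score per class, here [background kidney]
    d = dice(groundTruthLabels{id},predictedLabels{id});
    diceKidney(id) = d(2);
end
meanDice = mean(diceKidney)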
As you mentioned that "the Dice coefficients on the test data are low", the network is not generalizing well and could be overfitted or underfitted. To resolve this, try increasing the number of epochs and starting with a lower learning rate.
If the network is still not performing well, change the optimizer to "adam" with a small initial learning rate.
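For example, something along these lines could be used as a starting point (the epoch count and learning rate here are only suggested values, not prescribed ones):
% Illustrative only: same options as in the question, but with the adam
% solver and a smaller initial learning rate
options = trainingOptions('adam', ...
    'MaxEpochs',100, ...
    'InitialLearnRate',1e-4, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',5, ...
    'LearnRateDropFactor',0.95, ...
    'ValidationData',dsVal, ...
    'ValidationFrequency',4, ...
    'Plots','training-progress', ...
    'Verbose',true, ...
    'MiniBatchSize',1);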