- “softmaxLayer” function: https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.softmaxlayer.html
- “pixelClassificationLayer” function: https://www.mathworks.com/help/vision/ref/nnet.cnn.layer.pixelclassificationlayer.html
How can I do multi-output for a U-Net model?
Hi, I just want to know whether an image-regression model created with U-Net can be trained to predict multiple outputs. If it can, how can I do it?
For my model, I created and trained the network using Train Convolutional Neural Network for Regression as a guideline, and it has only one output. I need to create a model with the same architecture but more outputs. The multiple outputs I need to train are created from the original output, split into groups based on the values in the array, as shown in the picture below.
Note: My old model predicts the original image. My new model should predict output 1 to output 4.
Answers (1)
Aishwarya
2023-11-3
Hello,
As per my understanding, you have created a U-Net model using the documentation example you mentioned as a reference. Now you wish to create a multi-class U-Net model and split the output mask by class.
As that documentation example provides a simple convolutional neural network for regression that outputs a single continuous value, I assume you have added up-sampling layers and skip connections to turn it into a U-Net architecture. To modify the network into a multi-class U-Net model, consider replacing the last few layers after the last “relu” layer, as shown in the example code below:
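numClasses = 4;                                                      % number of output classes (groups)
finalConv = convolution2dLayer(1, numClasses, 'Name', 'finalConv');  % 1x1 conv: one channel per class
softmax = softmaxLayer('Name', 'softmax');                           % per-pixel class probabilities
outputLayer = pixelClassificationLayer('Name', 'labels');            % per-pixel cross-entropy loss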
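Because pixelClassificationLayer trains on per-pixel categorical labels, the original continuous output has to be converted into a label map first. Below is a minimal sketch, assuming the four groups are defined by value ranges; the array Y, the bin edges, and the class names output1 to output4 are hypothetical placeholders for your own data.

```matlab
% Hypothetical example: convert a continuous regression target into a
% 4-class label map for pixelClassificationLayer. The bin edges are
% placeholders; use the value ranges that define your output groups.
Y = [0.05 0.30; 0.55 0.90];      % stand-in for one original output image
edges = [0 0.25 0.5 0.75 1];     % 4 groups: [0,0.25) [0.25,0.5) [0.5,0.75) [0.75,1]
classNames = {'output1', 'output2', 'output3', 'output4'};

idx = discretize(Y, edges);                   % group index 1..4 per pixel (NaN if outside edges)
labels = categorical(idx, 1:4, classNames);   % per-pixel categorical label map
disp(labels)
```

Label maps built this way could then be saved as label images and paired with the training images, for example through a pixelLabelDatastore; check that function's documentation for the exact setup.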
After getting the output mask (“labeled_mask”) from the U-Net model (for example, with semanticseg), each class output can be separated using the example code below.
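% Extract each class into a separate image
% (assumes img is the input image and labeled_mask is the HxW label map
%  returned for it, e.g. a categorical from semanticseg)
num_classes = 4;
class_idx = double(labeled_mask);   % categorical -> numeric class indices 1..num_classes
for i = 1:num_classes
    % Create binary mask for current class
    class_mask = (class_idx == i);
    % Apply binary mask to input image (keeps img's data type; expands over color channels)
    class_img = img .* cast(class_mask, 'like', img);
    % Show output image
    figure;
    imshow(class_img);
    title(sprintf('Class %d', i));
end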
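To check the splitting step without a trained network, and to keep the four outputs for later use instead of only displaying them, here is a minimal sketch with synthetic data. The values of img and labeled_mask and the category names c1 to c4 are made-up stand-ins, and the outputs are stored in a cell array.

```matlab
% Minimal sketch with synthetic data: split an image into one image per
% class and keep the results in a cell array.
img = uint8([10 20; 30 40]);                                    % stand-in input image
labeled_mask = categorical([1 2; 3 4], 1:4, {'c1','c2','c3','c4'});  % stand-in label map

num_classes = 4;
class_idx = double(labeled_mask);     % categorical -> class indices 1..4
class_imgs = cell(1, num_classes);    % one output image per class
for i = 1:num_classes
    % Zero out every pixel that does not belong to class i
    class_imgs{i} = img .* cast(class_idx == i, 'like', img);
end
disp(class_imgs{3})
```

Since every pixel belongs to exactly one class, adding the four outputs back together reproduces the input image.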
For more details about the functions used, please refer to the MathWorks documentation for “softmaxLayer” and “pixelClassificationLayer” listed above.
I hope this helps!