How to make a ResNet network with more than 3 channels?

This is in relation to the move from LayerGraphs, trainNetwork, and resnetLayers to dlnetworks, trainnet, and resnetNetwork.
I have a dataset consisting of 125x125 images, with 8 pixel channels instead of the usual 3 for RGB. With the older resnetLayers (which creates a ResNet LayerGraph) and trainNetwork, I was able to store this data as cells in a table and train with that. However, resnetLayers and trainNetwork are no longer recommended by MATLAB, and they recommend using resnetNetwork (which creates a ResNet dlnetwork) and trainnet instead. I want to keep up with the more modern and supported implementations, so I tried rewriting my code to use these, but I run into a problem where it seems trainnet won't accept this format of data anymore, and I haven't yet found a way to make it work. How should I go about passing this data to trainnet?
%Load the training data
load("dataset.mat", "traindata");
traindata
traindata = 7x2 table

         Features           Response
    __________________    __________

    {125x125x8 double}    Signal
    {125x125x8 double}    Background
    {125x125x8 double}    Background
    {125x125x8 double}    Background
    {125x125x8 double}    Signal
    {125x125x8 double}    Background
    {125x125x8 double}    Signal
%Train a ResNet LayerGraph using the old functions
imageSize = [125 125 8];
numClasses = 2;
layers = resnetLayers(imageSize,numClasses);
opts = trainingOptions("adam",...
"ExecutionEnvironment","cpu",...
"InitialLearnRate",0.0001,...
"MaxEpochs",1);
trainedNetwork = trainNetwork(traindata,layers,opts)
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:00 |       28.57% |       0.8097 |      1.0000e-04 |
|========================================================================================|
Training finished: Max epochs completed.
trainedNetwork =
  DAGNetwork with properties:

         Layers: [177x1 nnet.cnn.layer.Layer]
    Connections: [192x2 table]
     InputNames: {'input'}
    OutputNames: {'output'}
%Try to train a ResNet dlnetwork with the new functions, get an error
dlresnet = resnetNetwork(imageSize,numClasses);
dlresnet = initialize(dlresnet);
trainedDlnetwork = trainnet(traindata,dlresnet,"crossentropy",opts)
Error using trainnet (line 46)
For table input, the data must be feature data and the network must have a single input and a single output.

Answers (1)

Aditya on 2024-6-26, 5:51 (edited by Aditya on 2024-6-26, 5:52)
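To move from 'resnetLayers' and 'trainNetwork' to 'resnetNetwork' and 'trainnet', you need to pass your data in a format that trainnet accepts. As the error message says, trainnet accepts table input only for feature data with a single-input, single-output network, so a table whose Features column holds 125x125x8 images is rejected. For in-memory image data, pass the images as a 4-D numeric array and the labels as a separate categorical vector instead.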
Steps to Follow
  1. Prepare the Data: Concatenate the images stored in the table into a single 4-D array (Height x Width x Channels x Observations, here 125x125x8x7) and convert the labels into an N-by-1 categorical vector.
  2. Optionally Convert to 'dlarray': trainnet also accepts the plain numeric array, so wrapping it in a formatted 'dlarray' ('SSCB') is optional rather than required.
  3. Define the Network: Use 'resnetNetwork' to define your ResNet network.
  4. Train the Network: Use 'trainnet' to train the network with the formatted data.
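The data-preparation step can be checked on its own, before any network or training is involved. The sketch below builds a hypothetical stand-in for traindata (random 125x125x8 images with the same Signal/Background labels, not the real dataset.mat) and turns it into the 4-D array and categorical vector described above. It uses only base MATLAB functions.

```matlab
% Hypothetical stand-in for the 7x2 table loaded from dataset.mat:
% random 125x125x8 images with Signal/Background labels.
labels = ["Signal" "Background" "Background" "Background" ...
          "Signal" "Background" "Signal"]';
imgs = arrayfun(@(k) rand(125, 125, 8), (1:numel(labels))', ...
    "UniformOutput", false);
traindata = table(imgs, categorical(labels), ...
    'VariableNames', {'Features', 'Response'});

% Stack the cell contents along dimension 4: Height x Width x Channels x Observations
features = cat(4, traindata.Features{:});

% Labels as an N-by-1 categorical vector (no change if already categorical)
responses = categorical(traindata.Response);

disp(size(features))        % expected: 125 125 8 7
disp(categories(responses)) % class names, sorted alphabetically
```

If the size check does not show the channel count in the third position and the number of images in the fourth, trainnet will not interpret the array as a batch of 8-channel images.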
Sample code:
% Load the training data
load("dataset.mat", "traindata");
% Extract features and responses from the table
features = cat(4, traindata.Features{:}); % Concatenate along the 4th dimension to form a 4D array
responses = categorical(traindata.Response); % Convert responses to categorical (no change if they already are)
% Convert to dlarray
dlFeatures = dlarray(features, 'SSCB'); % 'SSCB' -> Spatial, Spatial, Channel, Batch (optional; trainnet also accepts the numeric array features directly)
% Define the network
imageSize = [125 125 8];
numClasses = 2;
dlresnet = resnetNetwork(imageSize, numClasses);
dlresnet = initialize(dlresnet);
% Specify training options
opts = trainingOptions("adam", ...
"ExecutionEnvironment", "cpu", ...
"InitialLearnRate", 0.0001, ...
"MaxEpochs", 1);
% Train the network
trainedDlnetwork = trainnet(dlFeatures, responses, dlresnet, "crossentropy", opts);
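trainnet also accepts a datastore whose reads return {predictor, target} pairs, which is the same layout a file-based datastore would use if the images did not fit in memory. The sketch below wraps in-memory arrays with arrayDatastore and combine. Assumptions: the data is a random stand-in for traindata, and the mini-batch size of 4 is an arbitrary choice for a 7-image toy set; this is one possible setup, not the only way to feed trainnet.

```matlab
% Hypothetical stand-in data: 7 random 125x125x8 images and their labels
classNames = ["Background" "Signal"];
features = rand(125, 125, 8, 7);                          % H x W x C x N
responses = categorical(["Signal" "Background" "Background" "Background" ...
    "Signal" "Background" "Signal"]', classNames);        % N x 1 categorical

% IterationDimension 4 makes each read return one 125x125x8 image
dsX = arrayDatastore(features, "IterationDimension", 4);
dsT = arrayDatastore(responses);
dsTrain = combine(dsX, dsT);   % each read returns {image, label}

% Same network setup as in the code above
net = resnetNetwork([125 125 8], numel(classNames));
net = initialize(net);

opts = trainingOptions("adam", ...
    "ExecutionEnvironment", "cpu", ...
    "InitialLearnRate", 1e-4, ...
    "MaxEpochs", 1, ...
    "MiniBatchSize", 4);       % arbitrary small value for the toy data

trainedNet = trainnet(dsTrain, net, "crossentropy", opts);
```

With this layout the predictors and targets travel together through the datastore, so trainnet is called with the datastore, the network, the loss, and the options only.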
Note:
  • Data Normalization: If the 8 channels have very different value ranges, consider standardizing each channel before training.
  • Batch Size: If you run out of memory, reduce the 'MiniBatchSize' option in 'trainingOptions' (the default is 128).
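For the normalization note, one common option (an assumption here, not something the answer above prescribes) is per-channel standardization: compute a mean and standard deviation for each of the 8 channels over all pixels and all images, then subtract and divide. The sketch uses a hypothetical random array in place of the stacked training images; the same mu and sigma computed on the training set should be reused for any validation or test data.

```matlab
% Hypothetical 4-D array standing in for the stacked training images
features = rand(125, 125, 8, 7) * 255;   % H x W x C x N

% Per-channel statistics over height, width, and observations (dims 1, 2, 4).
% The results are 1x1x8 arrays that broadcast against features.
mu = mean(features, [1 2 4]);
sigma = std(features, 0, [1 2 4]);
featuresNorm = (features - mu) ./ sigma;

% Lowering the mini-batch size if memory is tight (value is illustrative):
% opts = trainingOptions("adam", "MiniBatchSize", 16);
```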
To read more about the 'trainnet' function, refer to its documentation:

Category: Image Data Workflows

Release: R2024a
