How do I use k-fold cross-validation in Deep Network Designer?
Hello, I'm working on a project in which I use Deep Network Designer to create a U-Net architecture adapted for image regression. I need to do k-fold cross-validation because my training dataset is small: just 30 pairs of 2-D images. Can anybody tell me how to use k-fold cross-validation to validate the model?
Answers (1)
Rahul
2023-1-4
"deepNetworkDesigner" does not provide k-fold cross-validation directly. As a workaround, export the network from the app to the workspace and run the cross-validation loop in a script, as follows.
Assume your input images and your labels are each stored as a 4-D array of 1000 single-channel 28x28 images (height x width x channels x number of images):
imgs: 28 x 28 x 1 x 1000
labels: 28 x 28 x 1 x 1000
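trainNetwork expects image data as a 4-D array of size height x width x channels x number of images, so a stack of grayscale images stored as H x W x N needs a singleton channel dimension inserted first. A minimal sketch, where the random 28 x 28 x 30 stack is only a stand-in for your own images:

```matlab
% Sketch: add a singleton channel dimension to a stack of grayscale images.
stack = rand(28, 28, 30);             % stand-in: 30 grayscale 28x28 images along dim 3
[h, w, n] = size(stack);
imgs4d = reshape(stack, h, w, 1, n);  % now 28 x 28 x 1 x 30

% permute gives the same array by moving dimension 3 to position 4
imgs4d_alt = permute(stack, [1 2 4 3]);
disp(size(imgs4d))                    % displays the size vector 28 28 1 30
```

Because the moved dimension is a singleton, reshape and permute produce identical arrays here; reshape is simply the more direct expression of "insert a channel dimension".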
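```matlab
% Synthetic stand-in data: replace with your own image pairs
imgs   = randi([0, 255], 28, 28, 1, 1000);  % 1000 input images, 28x28, single channel
labels = randi([0, 255], 28, 28, 1, 1000);  % 1000 target images, 28x28, single channel

params.LR        = 0.001;
params.Maxepochs = 100;
params.num_batch = 16;

% Network exported from Deep Network Designer to the workspace.
% "lgraph_1" is a placeholder: use the variable name the app actually created.
% For image-to-image regression with trainNetwork, the last layer must be a regressionLayer.
layers = lgraph_1;

options = trainingOptions('adam', ...
    'InitialLearnRate', params.LR, ...
    'MaxEpochs', params.Maxepochs, ...
    'MiniBatchSize', params.num_batch);   % change these as per your requirements

%% k-fold cross-validation
kfold_val   = 10;                % number of folds (with only 30 pairs, fewer folds may be more practical)
num_samples = size(labels, 4);   % number of image pairs
fold = cvpartition(num_samples, 'KFold', kfold_val);  % random k-fold partition

foldRMSE = zeros(kfold_val, 1);  % validation RMSE of each fold
for ii = 1:kfold_val
    train_idx      = fold.training(ii);  % logical mask of training samples
    validation_idx = fold.test(ii);      % logical mask of validation samples

    % Extract training images and labels
    xtrain = imgs(:, :, :, train_idx);
    ytrain = labels(:, :, :, train_idx);
    % Extract validation images and labels
    xvalid = imgs(:, :, :, validation_idx);
    yvalid = labels(:, :, :, validation_idx);

    % Train the network on this fold, starting from the same layers each time
    trained_net = trainNetwork(xtrain, ytrain, layers, options);

    % Predict on the held-out images and compute the root-mean-square error
    Pred = predict(trained_net, xvalid);
    foldRMSE(ii) = sqrt(mean((Pred - yvalid).^2, 'all'));
end

fprintf('Mean validation RMSE over %d folds: %.4f (std %.4f)\n', ...
    kfold_val, mean(foldRMSE), std(foldRMSE));
```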
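cvpartition belongs to the Statistics and Machine Learning Toolbox. If that toolbox is not available, a random k-fold split can be built with base MATLAB alone. The sketch below is an illustrative alternative, not part of the original answer; numSamples = 30 and k = 5 are assumed values matching the 30 image pairs in the question. Each sample lands in exactly one validation fold, and the logical masks play the same roles as fold.training(ii) and fold.test(ii) in the loop above.

```matlab
% Toolbox-free random k-fold split (sketch; values below are illustrative)
numSamples = 30;   % e.g. the 30 image pairs from the question
k = 5;             % number of folds (assumption)
rng(0);            % fix the random seed so the split is reproducible

% Shuffle the sample indices, then deal them out to folds 1..k in turn
perm = randperm(numSamples);
foldId = zeros(numSamples, 1);
foldId(perm) = mod(0:numSamples-1, k) + 1;

for ii = 1:k
    validation_idx = (foldId == ii);   % logical mask, same role as fold.test(ii)
    train_idx      = ~validation_idx;  % logical mask, same role as fold.training(ii)
    fprintf('Fold %d: %d training, %d validation samples\n', ...
        ii, nnz(train_idx), nnz(validation_idx));
end
```

With 30 samples and 5 folds, every fold reports 24 training and 6 validation samples. Dealing the shuffled indices round-robin keeps the fold sizes within one sample of each other even when numSamples is not a multiple of k.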