CNNへの交差検定(​Cross-Vali​dation)の導入​の仕方

6 次查看(过去 30 天)
ssk
ssk 2019-2-7
编辑: ssk 2019-2-11
プログラミング初心者です。
現在、チュートリアルのコードを微修正して動かしており、以下のコードに交差検定の追加を検討しております。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.5,'randomize');
help crossvarで検索すると、以下のようにでてきました。
TESTVAL = FUN(XTRAIN,XTEST)
こちらを、TESTVAL = FUN(imdsTrain, imdsValidation)とすると交差検定を導入できるという認識で
コンパイルしたのですが動きませんでした。
Undefined function or variable 'FUN'.
というエラーが出てしまいます。
交差検定の正しいやり方につきましてご教示いただけますと幸いです。
どうぞよろしくお願いいたします。

采纳的回答

Tohru Kikawada
Tohru Kikawada 2019-2-9
crossvalのドキュメントに記載のある下記は指定する関数の戻り値と引数の一例です。
TESTVAL = FUN(XTRAIN,XTEST)
ドキュメントにあるいくつかの例題は試してみましたでしょうか。crossvalは様々な機械学習のアルゴリズムで使えるように汎用性のある関数ハンドルの受け渡しで実行されます。CNNで交差検定を実行する場合も下記のようにCNNのクラス分類結果を返すような関数を関数ハンドルとして渡してあげる必要があります。
%% データセットの読み込み
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
%% ダミーのトレーニングインデックスを生成
X = (1:imds.numpartitions)';
y = imds.Labels;
%% 交差検定にCNNの予測ラベル関数のポインタを渡す
mcr = crossval('mcr',X,y,'Predfun',@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds))
%% CNNを学習し、予測ラベルを出力する関数
function ypred = myCNNPredict(xtrain,ytrain,xtest,imds)
% 結果が一意になるように乱数シードをデフォルト値に設定
rng('default');
% ダミーの変数ベクトルを受けてimageDatastoreを学習用とテスト用に分割
imdsTrain = imageDatastore(imds.Files(xtrain));
imdsTrain.Labels = ytrain;
imdsValidation = imageDatastore(imds.Files(xtest));
% レイヤーの設定
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',4, ...
'Shuffle','every-epoch', ...
'Verbose',false);
net = trainNetwork(imdsTrain,layers,options);
ypred = classify(net,imdsValidation);
end
  4 个评论
ssk
ssk 2019-2-11
ご回答ありがとうございます。おかげさまでチュートリアルのコードを無事、コンパイルすることができました。ありがとうございます。
DICOMファイルでも交差検定が使えるかどうか試したところ、以下のようなエラーが出てしまいます。
The function '@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds)' generated
the following error:
Input folders or files contain non-standard file extensions.
拡張子が違うのが原因かもしれません。
currentdirectory = pwd;
% set categories of subdirectory
categories = {'a', 'b', 'c','d'};
imds = imageDatastore(fullfile(currentdirectory, categories),'IncludeSubfolders',true,'FileExtensions','.dcm','LabelSource', 'foldernames','ReadFcn',@dicomread);
mcr = crossval('mcr',X,y,'Predfun',(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds))
作成したコードは上記の通りですが、DICOMファイルでの交差検定の仕方につきまして、ご教示頂けますと幸いです。
どうぞよろしくお願いいたします。
ssk
ssk 2019-2-11
编辑:ssk 2019-2-11
五月雨式のコメント失礼いたします。
頂いた回答につきまして以下の質問がございます。
%% ダミーのトレーニングインデックスを生成
X = (1:imds.numpartitions)';
(1)なぜ、ダミーのトレーニングインデックスを生成しているのか、
(2)なぜ、numpartitions(おそらくnumber of partition)を使っているのか、
(3)(1:imds.numpatition)の意味につきましてもご教示いただけますと幸いです。
@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds)
また、mcrの意味につきましては、 misclassification rateの略語という意味でお間違えないでしょうか。
どうぞよろしくお願いいたします。

请先登录,再进行评论。

更多回答(0 个)

类别

Help CenterFile Exchange 中查找有关 Deep Learning Toolbox 的更多信息

标签

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!