Neural Network cross validation

I am new to MATLAB. I have implemented a character recognition system using neural networks. Now I am trying to apply a 10-fold cross-validation scheme to the neural network. I have written the following code, but I don't know if it is correct. Please help me.
close all
clear all
load inputdata
load targetdata
inputs = input;
targets = target;
% Create a Pattern Recognition Network
hiddenLayerSize = 30;
net = patternnet(hiddenLayerSize);
% Choose Input and Output Pre/Post-Processing Functions
% For a list of all processing functions type: help nnprocess
net.inputs{1}.processFcns = {'removeconstantrows','mapminmax'};
net.outputs{2}.processFcns = {'removeconstantrows','mapminmax'};
k=10;
groups=[1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 6 6 6 6 6 6 6 6 6 6 7 7 7 7 7 7 7 7 7 7 8 8 8 8 8 8 8 8 8 8 9 9 9 9 9 9 9 9 9 9 10 10 10 10 10 10 10 10 10 10 11 11 11 11 11 11 11 11 11 11 12 12 12 12 12 12 12 12 12 12 13 13 13 13 13 13 13 13 13 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13]; %target
cvFolds = crossvalind('Kfold', groups, k); %# get indices of 10-fold CV
test = zeros(1, k); %# test performance of each fold
for i = 1:k %# for each fold
    testIdx = (cvFolds == i); %# get indices of test instances
    trainIdx = ~testIdx; %# get indices of training instances
    trInd = find(trainIdx);
    tstInd = find(testIdx);
    net.trainFcn = 'trainbr';
    net.trainParam.epochs = 100;
    net.divideFcn = 'divideind';
    net.divideParam.trainInd = trInd;
    net.divideParam.testInd = tstInd;
    % Choose a Performance Function
    net.performFcn = 'mse'; % Mean squared error
    % Train the Network
    [net,tr] = train(net,inputs,targets);
    %# test using test instances
    outputs = net(inputs);
    errors = gsubtract(targets,outputs);
    performance = perform(net,targets,outputs);
    trainTargets = targets .* tr.trainMask{1};
    testTargets = targets .* tr.testMask{1};
    trainPerformance = perform(net,trainTargets,outputs);
    testPerformance = perform(net,testTargets,outputs);
    test(i) = testPerformance; %# store this fold's test performance
    save net
    figure, plotconfusion(targets,outputs)
end
meanTestPerformance = mean(test); %# average test MSE over the folds (not an accuracy)
% View the Network
view(net)
Greg Heath 2014-3-13
When I cut and paste your code into the command line it does not run because it is not properly formatted.
I suggest that you reformat your post so that it will run when cut and pasted.
However, I do not have crossvalind or crossperf, so I'm not sure how much help I can be.
Reformat and I will try to do what I can.
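Since crossvalind is not available to everyone, here is a minimal base-MATLAB sketch of a stratified 10-fold assignment that could stand in for cvFolds = crossvalind('Kfold', groups, k). It is not crossvalind's actual algorithm: it shuffles each class's samples and deals them round-robin into folds 1..k, carrying the position over from one class to the next so that fold sizes differ by at most one. The example labels (13 classes with 22 samples each) are an assumption for illustration; substitute the real groups vector.

```matlab
% Hypothetical base-MATLAB stand-in for crossvalind('Kfold', groups, k): stratified folds.
rng(0);                                   % fixed seed so the fold assignment is repeatable
k = 10;
groups = repmat(1:13, 1, 22);             % illustrative labels: 13 classes x 22 samples (assumption)
cvFolds = zeros(size(groups));            % cvFolds(n) = fold (1..k) that sample n belongs to
offset = 0;                               % running position, carried across classes
for c = unique(groups)                    % unique of a row vector is a row, so c is a scalar
    idx = find(groups == c);              % samples of this class
    idx = idx(randperm(numel(idx)));      % shuffle them within the class
    cvFolds(idx) = mod(offset + (0:numel(idx) - 1), k) + 1;  % deal round-robin into folds
    offset = offset + numel(idx);
end
```

Each fold can then drive the loop exactly as in the question: testIdx = (cvFolds == i); trainIdx = ~testIdx. Carrying the offset across classes is what keeps the overall fold sizes balanced when the class sizes are not multiples of k.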

Accepted Answer

Greg Heath 2014-3-14
Edited: Greg Heath 2014-3-14
Formatting is not perfect. Did you cut and paste this version?
Results?
size(inputs) = ?
size(targets) = ?
Did you try to minimize the hidden layer size?
Take net.divideFcn and net.trainFcn out of the loop
Trainbr uses regularization, not ordinary mse
Did you try default values of the remaining net.* specifications before overwriting them?
Initialize the RNG just before the loop so you can repeat your run if needed.
You have to configure the net at the top of the loop; otherwise, weight initialization will only occur for the 1st net.
The train and test performances are already in tr. No need to recalculate.
You may have to use the Masks on both targets and outputs. Check to make sure.
If you save each net, they have to have different names. A 10-element cell array should work.
Why not save trainPerformance also?
Then calculate min,median,mean,std and max of both train and test performances.
Why not run the iris_dataset before trying your own data?
If you use crossval and cvpartition, we could compare results. HOWEVER, although I have them, I have never used them. It might just be easier if you used my crossval code in the NEWSGROUP.
Also, index your confusion plots; otherwise each one will overwrite the previous one.
Hope this helps.
Thank you for formally accepting my answer
Greg
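Greg's bookkeeping points (seed the RNG once before the loop, configure the net at the top of each fold, keep each fold's net in a 10-element cell array, record train and test performance, then report min, median, mean, std and max) can be sketched in MATLAB as below. This is a hedged sketch, not Greg's code: the toolbox calls are left as comments because they need the Neural Network Toolbox (now Deep Learning Toolbox) and the poster's data; the names nets, trainPerf, testPerf and summarize are hypothetical; tr.best_perf and tr.best_tperf are assumed to be the training-record fields holding the final train and test performance; and the performance numbers are synthetic placeholders so the summary part runs on its own.

```matlab
% Sketch of the per-fold bookkeeping; the toolbox part is commented out.
rng(0)                       % seed once, just before the CV loop, so the run is repeatable
k = 10;
nets = cell(1, k);           % 10-element cell: one trained net per fold, not one net.mat
trainPerf = zeros(1, k);     % per-fold training performance
testPerf  = zeros(1, k);     % per-fold test performance
% Real loop (needs the toolbox and the poster's inputs, targets, cvFolds;
% net is created and net.trainFcn / net.divideFcn are set once, before the loop):
% for i = 1:k
%     net = configure(net, inputs, targets);     % at the top of the loop, as Greg advises
%     net.divideParam.trainInd = find(cvFolds ~= i);
%     net.divideParam.testInd  = find(cvFolds == i);
%     [nets{i}, tr] = train(net, inputs, targets);
%     trainPerf(i) = tr.best_perf;               % assumed training-record field
%     testPerf(i)  = tr.best_tperf;              % assumed training-record field
%     figure(i), plotconfusion(targets, nets{i}(inputs))  % one numbered figure per fold
% end
% Synthetic placeholder values (NOT results) so the summary below runs on its own:
trainPerf = 0.01 * (1:k);
testPerf  = 0.02 * (1:k);
% Rows: train, test. Columns: min, median, mean, std, max.
summarize = @(p) [min(p), median(p), mean(p), std(p), max(p)];
summary = [summarize(trainPerf); summarize(testPerf)];
fprintf('%-6s %8s %8s %8s %8s %8s\n', 'set', 'min', 'median', 'mean', 'std', 'max');
fprintf('%-6s %8.4f %8.4f %8.4f %8.4f %8.4f\n', 'train', summary(1, :));
fprintf('%-6s %8.4f %8.4f %8.4f %8.4f %8.4f\n', 'test', summary(2, :));
```

Keeping every fold's network in nets{i} avoids the problem of save net overwriting the same file on each pass; the whole cell array can be saved once after the loop.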
