How to Improve Performance of an Over-fit Convolutional Neural Network

Hi There,
I've trained a 2D convolutional neural network using 60 grayscale images, with 5 different target classifications (using the following database, identifying 'happy', 'sad', 'sleepy', 'normal', and 'surprised'). There are 15 remaining images in my sample for validation and/or testing.
I created my CNN using the Deep Network Designer with the following layers (in order): an image input layer, 2 neural layers, a dropout layer, a ReLU layer, a batch normalization layer, a fully connected layer, and a softmax layer.
When I call the function trainingOptions and set the parameters of options to the following values, I'm able to achieve basically 100% prediction accuracy for the training images:
  • MaxEpochs = 100 (this could be reduced to 40)
  • LearnRateSchedule = Constant
  • InitialLearnRate = 0.0005
  • ValidationFrequency = 30
  • ValidationPatience = Inf
  • Objective metric = Loss
  • Output network = Best validation
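Roughly in code form, that setup looks like the following (the sgdm solver is shown just for illustration, and imdsValidation stands in for my validation datastore):

% Training options as listed above (solver and datastore names are placeholders)
options = trainingOptions("sgdm", ...
    MaxEpochs=100, ...                        % could likely be reduced to 40
    LearnRateSchedule="none", ...             % constant learning rate
    InitialLearnRate=5e-4, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=30, ...
    ValidationPatience=Inf, ...               % never stop early
    OutputNetwork="best-validation-loss", ... % keep the best-validation network
    Plots="training-progress");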
However, even when I re-run the training many times, the performance of the CNN at predicting the correct classification for the validation images never seems to rise above 20%. Perplexingly, the prediction accuracy seems to go down with continued training, as you can see here:
Also, the predictions seem to land on certain classifications far more often than others, as you can see here:
I've been told that my CNN may be over-fit, and that I might try to address that in the following two ways:
  • By enforcing a regularization penalty for better generalization. I'm still figuring out which functions or options might allow me to do that.
  • By using early stopping. I tried to do that by increasing the LearnRateDropFactor from 0.2 to 0.8, decreasing the LearnRateDropPeriod from 5 to 1 (and also to 10), and decreasing the ValidationFrequency from 30 to 3. I'm not sure these measures could truly be considered early stopping, though (see the sketch just below this list).
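For comparison, here is a sketch of early stopping done via a finite ValidationPatience, which I believe is closer to what early stopping usually means (again, the solver and the imdsValidation name are just placeholders):

% Stop training when validation loss has not improved for 5 consecutive validations
optsEarlyStop = trainingOptions("sgdm", ...
    MaxEpochs=100, ...
    InitialLearnRate=5e-4, ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=3, ...
    ValidationPatience=5, ...                 % the early-stopping criterion
    OutputNetwork="best-validation-loss");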
Could someone please help me understand what else I might need to try, and give me a clearer picture of what could be going on here?
Thanks.

Answers (1)

Matt J on 9 May 2024
Edited: Matt J on 9 May 2024
By enforcing a regularization penalty for better generalization. I'm still figuring out which functions or options might allow me to do that.
That would probably mean experimenting with the L2Regularization training parameter. You should also use analyzeNetwork() to see how many unknown learnable parameters you have.
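For example, a sketch along those lines (the 0.01 value is only a starting point to experiment with, and layers is whatever layer array you exported from Deep Network Designer):

% Increase weight decay from the default of 1e-4 to penalize large weights
options = trainingOptions("sgdm", ...
    L2Regularization=0.01, ...
    MaxEpochs=100, ...
    InitialLearnRate=5e-4);

% Inspect the architecture and the total number of learnable parameters
analyzeNetwork(layers)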
Aside from that, since you have only 60 training samples, you would need to keep the number of unknowns fairly small. You haven't shown the code for your network architecture, so everything is in the dark, but I am guessing you have at least half a million unknowns in the fully connected layer alone. You probably need to add some pooling layers or increase the stride of your convolutions to decrease this number.
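For illustration only (the 48-by-48 input size and filter counts are assumptions, since the actual architecture wasn't posted), downsampling before the fully connected layer is what keeps the parameter count small:

% Example of a small CNN where pooling shrinks the activations before the FC layer
layers = [
    imageInputLayer([48 48 1])                 % assumed grayscale input size
    convolution2dLayer(3, 8, Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, Stride=2)             % halves the spatial dimensions
    convolution2dLayer(3, 16, Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, Stride=2)             % halves them again
    dropoutLayer(0.5)
    fullyConnectedLayer(5)                     % 5 expression classes
    softmaxLayer
    classificationLayer];
analyzeNetwork(layers)                         % verify the learnable count is modest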
It's also not clear if you are doing any data augmentation; if not, something along the lines of the sketch below is worth trying.
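A typical setup, assuming your training images are in an imageDatastore called imdsTrain and a 48-by-48 input (both assumptions):

% Randomly perturb the 60 training images each epoch
augmenter = imageDataAugmenter( ...
    RandRotation=[-15 15], ...
    RandXTranslation=[-3 3], ...
    RandYTranslation=[-3 3], ...
    RandXReflection=true);
augTrain = augmentedImageDatastore([48 48], imdsTrain, ...
    DataAugmentation=augmenter);
% Pass augTrain (instead of imdsTrain) to trainNetwork so the network never
% sees exactly the same image twice.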
