Why am I unable to get test accuracy higher than 75% from my ANN?

Hello. I am working on developing an ANN to train on a part of the MNIST dataset. I am using the "trainNetwork" command in MATLAB, as it allows me to robustly define my model architecture. I also want to implement hyperparameter tuning using "bayesopt". The hyperparameters I am trying to tune are the number of neurons per layer, the learning rate, and the type of optimizer used. My aim is to maximize the test accuracy. I believe the code is working correctly; however, I am unable to get test accuracy results above 75%. I am heavily regularizing the model because it had a strong tendency to overfit. My code is shown below. Am I going wrong somewhere? If not, why am I unable to squeeze out any accuracy greater than 75%?
clear; clc;
rng(1); % Do not modify this line
% The directory variable is named dataDir rather than dir to avoid
% shadowing MATLAB's built-in dir function
dataDir = "Datadir"; % for example: "C:\Users\username\Desktop\tmnist_bundle_rgb\tmnist_bundle_rgb"
imdir = fullfile(dataDir,"imgs");
% Step 1: Read the index file and load the images it lists
datapath = readtable(fullfile(dataDir,"index.csv"));
nImages = height(datapath);
rawlabels = zeros(1,nImages);
features = zeros(784,nImages); % 28x28 grayscale images, flattened column-wise
for i = 1:nImages
    rawlabels(1,i) = datapath{i,1};                % first column: class label
    img = imread(fullfile(dataDir,datapath{i,2})); % second column: image path
    img = rgb2gray(img);
    features(:,i) = reshape(img,[],1);             % flatten to a 784x1 column
end
labels = categorical(rawlabels);
features = normalize(features,'range'); % rescale each column to [0,1]
% Step 2: Stratified 70/30 hold-out split
cv = cvpartition(labels,'HoldOut',0.3,'Stratify',true);
idx = cv.test;
X_train = double(features(:,~idx)');
y_train = labels(:,~idx)';
X_test = double(features(:,idx)');
y_test = labels(:,idx)';
% Define the objective function for Bayesian optimization
% (bayesopt minimizes, so negate the accuracy; pass the sampled
% x.Optimizer instead of hard-coding 'adam', otherwise the Optimizer
% variable defined below is never actually tuned)
objective = @(x) -train_network(x.Layer1, x.Layer2, x.Layer3, x.Layer4, ...
    char(x.Optimizer), x.LearningRate, X_train, y_train, X_test, y_test);
% Define the search space for hyperparameters
layer = {'16','32','64','128','256'};
vars = [
    optimizableVariable('Layer1', layer, 'Type', 'categorical')
    optimizableVariable('Layer2', layer, 'Type', 'categorical')
    optimizableVariable('Layer3', layer, 'Type', 'categorical')
    optimizableVariable('Layer4', layer, 'Type', 'categorical')
    optimizableVariable('Optimizer', {'adam', 'sgdm', 'rmsprop'}, 'Type', 'categorical')
    optimizableVariable('LearningRate', [1e-5, 1e-3], 'Transform', 'log')
];
% Perform Bayesian optimization
results = bayesopt(objective, vars, ...
    'MaxObjectiveEvaluations', 100, ...
    'MaxTime', 8*60*60, ...
    'IsObjectiveDeterministic', false, ...
    'UseParallel', false);
% Retrieve the best hyperparameters and corresponding objective value
bestHyperparams = results.XAtMinObjective;
bestObjectiveValue = results.MinObjective;
% Display the best hyperparameters and objective value
disp('Best Hyperparameters:');
disp(bestHyperparams);
disp('Best Objective Value:');
disp(bestObjectiveValue);
% Define the objective function for training the network
function accuracy = train_network(layer1, layer2, layer3, layer4, optimizer, learningRate, X_train, y_train, X_test, y_test)
    % Convert the categorical choices (e.g. '64') back to neuron counts.
    % Converting via the category name is safer than indexing a lookup
    % array with double(), which depends on the category ordering.
    layer1 = str2double(string(layer1));
    layer2 = str2double(string(layer2));
    layer3 = str2double(string(layer3));
    layer4 = str2double(string(layer4));
    % Reject architectures whose layers widen with depth. Return an
    % accuracy of 0 rather than Inf: the objective is -accuracy, so
    % returning Inf here would give bayesopt an objective of -Inf and
    % make every invalid configuration look like the best point found.
    if layer2 > layer1 || layer3 > layer2 || layer4 > layer3
        accuracy = 0;
        return;
    end
    layers = [
        featureInputLayer(784)
        fullyConnectedLayer(layer1)
        reluLayer
        batchNormalizationLayer
        dropoutLayer(0.4)
        fullyConnectedLayer(layer2)
        reluLayer
        batchNormalizationLayer
        dropoutLayer(0.4)
        fullyConnectedLayer(layer3)
        reluLayer
        batchNormalizationLayer
        dropoutLayer(0.4)
        fullyConnectedLayer(layer4)
        reluLayer
        batchNormalizationLayer
        dropoutLayer(0.4)
        fullyConnectedLayer(10)
        softmaxLayer
        classificationLayer
    ];
    % Specify the training options ('Verbose' off keeps the command
    % window quiet across the many bayesopt evaluations)
    options = trainingOptions(optimizer, ...
        'MaxEpochs', 100, ...
        'MiniBatchSize', 16, ...
        'InitialLearnRate', learningRate, ...
        'Verbose', false);
    % Train the network
    net = trainNetwork(X_train, y_train, layers, options);
    % Evaluate the network on the held-out test set
    predictions = classify(net, X_test);
    accuracy = sum(predictions == y_test) / numel(y_test);
end

Accepted Answer

Rohit on 16 May 2023
Hi Zeyad,
Based on the code you provided, it seems that you are correctly implementing hyperparameter tuning with Bayesian optimization and training your network using trainNetwork in MATLAB.
However, there are a few potential reasons why you are not getting test accuracy above 75%:
  1. Limited model capacity: The network architecture you defined has a shallow structure with a maximum of four fully connected layers. For complex datasets like MNIST, deeper architectures may be necessary to capture more intricate patterns and achieve higher accuracy.
  2. Insufficient training data: It is important to have enough training data to train a deep learning model effectively. In your code, you are using a 70:30 train-test split, which might result in limited data for training. Consider using more training samples, if possible, to provide the model with a larger and more diverse set of examples.
  3. Hyperparameter values: Although Bayesian optimization helps in finding good hyperparameter values, it is still possible that the optimal values lie outside the search space you defined. Experiment with different ranges and values for the hyperparameters to see if you can achieve better results; one possible widened search space is sketched after this list.
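As a rough illustration of point 3, the search space could be widened, for example by raising the learning-rate ceiling and letting bayesopt tune the L2 regularization strength as well. The ranges and the added L2Reg variable below are illustrative assumptions, not values from the question:
% A sketch of a widened search space; the 1e-2 learning-rate ceiling
% and the L2Reg variable are illustrative assumptions
vars = [
    optimizableVariable('Layer1', {'16','32','64','128','256'}, 'Type', 'categorical')
    optimizableVariable('Layer2', {'16','32','64','128','256'}, 'Type', 'categorical')
    optimizableVariable('Layer3', {'16','32','64','128','256'}, 'Type', 'categorical')
    optimizableVariable('Layer4', {'16','32','64','128','256'}, 'Type', 'categorical')
    optimizableVariable('Optimizer', {'adam','sgdm','rmsprop'}, 'Type', 'categorical')
    optimizableVariable('LearningRate', [1e-5, 1e-2], 'Transform', 'log')
    optimizableVariable('L2Reg', [1e-6, 1e-2], 'Transform', 'log')
];
% x.L2Reg would then be passed through to trainingOptions via its
% 'L2Regularization' name-value argument inside train_network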
It is important to acknowledge that achieving higher accuracy on the MNIST dataset is indeed possible; many models achieve test accuracy well above 90%. So, exploring more advanced network architectures, such as convolutional neural networks (CNNs), can be advantageous for image classification tasks like MNIST, as sketched below.
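For reference, here is a minimal CNN sketch in the same trainNetwork style. The filter counts and layer arrangement are illustrative assumptions, and the flattened 784-element features would need to be reshaped back into 28-by-28-by-1 images first:
% A minimal CNN sketch for 28x28 grayscale digit images; the number of
% filters and the layer arrangement are illustrative assumptions.
% Reshape the flattened features (observations in rows) back to images:
%   XTrain4D = reshape(X_train', 28, 28, 1, []);
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
];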
  1 Comment
Zeyad Elreedy on 17 May 2023
Thank you for your reply, @Rohit! As a follow-up, I understand that the complexity of MNIST and the shallowness of the model may be the reason for the low accuracy. What surprised me, however, is that when I used 'patternnet' instead of 'trainNetwork' with layer sizes of [64 32 16], I was able to achieve an accuracy of 92.4% using patternnet's default settings (default optimizer, no explicit activation functions, etc.). Is there a reason why this is the case, or is my understanding of how patternnet works incorrect?
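For context, a minimal sketch of the patternnet workflow described in this comment, assuming the 784-by-N feature matrix and categorical labels from the question. Note that patternnet's defaults are not activation-free: hidden layers use tansig and training uses scaled conjugate gradient (trainscg), with no batch normalization or dropout, which may explain the different behavior:
% A minimal sketch of the patternnet experiment described above;
% layer sizes [64 32 16] are the ones the comment reports
net = patternnet([64 32 16]);       % defaults: tansig hidden units, trainscg
T = full(ind2vec(double(labels)));  % one-hot targets, 10-by-N
net = train(net, features, T);      % features: 784-by-N, scaled to [0,1]
pred = vec2ind(net(features));      % predicted class indices
acc = mean(pred == double(labels)); % overall accuracy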



Release

R2022a
