How to get R-squared, mean absolute error, and mean squared error after training a neural network?

Hi all, I train a neural network with the following commands:
net.divideFcn = 'dividerand';
net.divideParam.trainRatio= 0.6;
net.divideParam.testRatio= 0.2;
net.divideParam.valRatio= 0.2;
[net,tr]=train(net,input,target);
I want to get the R-squared, mean absolute error, and mean squared error for the training, test, and validation data.
Could you please advise?

Answers (1)

Paras Gupta, 2024-7-18, 9:07
Hi Ninlawat,
I understand that you want to compute different network performance metrics on the train, test, and validation data after training a neural network object in MATLAB.
The following code illustrates one way to achieve the same:
% dummy data
input = rand(1, 100); % 1 feature, 100 samples
target = 2 * input + 1 + 0.1 * randn(1, 100); % Linear relation with some noise
% Define the feedforward network
net = feedforwardnet(10); % 10 hidden neurons
% Set up the data division
net.divideFcn = 'dividerand';
net.divideParam.trainRatio = 0.6;
net.divideParam.valRatio = 0.2;
net.divideParam.testRatio = 0.2;
% Train the network
[net, tr] = train(net, input, target);
% Get the network outputs
outputs = net(input);
% Separate the outputs for training, validation, and testing
trainOutputs = outputs(tr.trainInd);
valOutputs = outputs(tr.valInd);
testOutputs = outputs(tr.testInd);
% Separate the targets for training, validation, and testing
trainTargets = target(tr.trainInd);
valTargets = target(tr.valInd);
testTargets = target(tr.testInd);
% Calculate and display R-square, MAE, and MSE for each dataset
datasets = {'train', 'val', 'test'};
outputsList = {trainOutputs, valOutputs, testOutputs};
targetsList = {trainTargets, valTargets, testTargets};
for i = 1:numel(datasets)
    dataset = datasets{i};
    y = outputsList{i}; % network outputs for this subset
    t = targetsList{i}; % matching targets for this subset
    % R-square
    SS_res = sum((t - y).^2);
    SS_tot = sum((t - mean(t)).^2);
    R_square = 1 - SS_res / SS_tot;
    % Mean Absolute Error (MAE)
    MAE = mae(t - y);
    % Mean Square Error (MSE)
    MSE = mse(net, t, y);
    % Display the results
    fprintf('%s R-square: %.4f\n', dataset, R_square);
    fprintf('%s MAE: %.4f\n', dataset, MAE);
    fprintf('%s MSE: %.4f\n', dataset, MSE);
    fprintf('\n');
end
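As a cross-check on the toolbox functions, the three metrics can also be written with base MATLAB only. The sketch below is not part of the original code: the handle names rsq, maeFcn, and mseFcn are hypothetical, and it assumes a single-output network, so targets and outputs are row vectors of equal length. The small example uses values that are easy to verify by hand.

```matlab
% Hypothetical helper handles (illustrative names); base MATLAB only.
% Assumes t (targets) and y (outputs) are row vectors of equal length.
rsq    = @(t, y) 1 - sum((t - y).^2) / sum((t - mean(t)).^2); % R-square
maeFcn = @(t, y) mean(abs(t - y));                            % mean absolute error
mseFcn = @(t, y) mean((t - y).^2);                            % mean square error

% Hand-checkable example: residuals are -0.5, 0, 0.5, 0
t = [1 2 3 4];
y = [1.5 2 2.5 4];

fprintf('R-square: %.4f\n', rsq(t, y));    % prints 0.9000
fprintf('MAE: %.4f\n',      maeFcn(t, y)); % prints 0.2500
fprintf('MSE: %.4f\n',      mseFcn(t, y)); % prints 0.1250
```

The same handles can be applied to each subset, for example rsq(trainTargets, trainOutputs), and the results compared against the values printed by the loop.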
For more information on the properties and functions used in the code above, refer to the MATLAB documentation for feedforwardnet, train, mae, and mse.
Hope this helps.
