Is there a way to hold out specific data?

Mark
Mark 2024-9-23,10:48
Commented: Udit06 on 2024-9-23, 15:07
I'm building decision trees (both classification and regression) from my dataset, and I want to use one specific set of rows for training and another specific set for testing. Is there a way to do this?
For example, if my dataset consists of 100 rows, is there a way to tell the software to use rows 1-75 as the training set and rows 76-100 as the test set?
Thanks in advance

Answers (1)

Udit06
Udit06 2024-9-23,10:57
Hi Mark,
You can use array indexing to select the training and test sets by row index. The code snippet below does this:
% Define the training and testing indices
trainIndices = 1:75;
testIndices = 76:100;
% Split the data
trainData = data(trainIndices, :);
testData = data(testIndices, :);
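Since the question also mentions regression trees, here is a minimal sketch of the same fixed-row split used with fitrtree. It assumes the carsmall sample data set (100 observations) that ships with Statistics and Machine Learning Toolbox; the variable names trainIdx, testIdx, MdlR and testMSE are only illustrative, so substitute your own data and names.

```matlab
% Sketch: fixed row-based split for a regression tree (fitrtree).
% Assumption: the carsmall sample data (100 rows) stands in for your data.
load carsmall
X = [Weight, Horsepower];   % predictors (100-by-2)
Y = MPG;                    % response (100-by-1, contains some NaNs)

trainIdx = 1:75;            % rows used for training
testIdx  = 76:100;          % rows held out for testing

% Train only on the training rows
MdlR = fitrtree(X(trainIdx, :), Y(trainIdx));

% Predict on the held-out rows and compute the mean squared error,
% skipping any test rows whose true response is missing
YPred   = predict(MdlR, X(testIdx, :));
testMSE = mean((Y(testIdx) - YPred).^2, 'omitnan');
fprintf('Test MSE: %.3f\n', testMSE);
```

The tree never sees rows 76-100 during training, so the error computed on them is a genuine hold-out estimate.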
I hope this helps.
  2 Comments
Mark
Mark 2024-9-23,13:06
Thanks for your help. Apologies if this next question is a bit dumb, I'm still getting to grips with MATLAB, but how do I then add that into the code for computing the tree? For example, I use the code below:
MdlA = fitctree(input,output); % for unvalidated trees
MdlB = fitctree(input,output,'CrossVal','on') % for validated trees.
In the case of MdlB, how do I tell it to use rows 1-75 to fit the initial tree, but rows 76-100 for the cross-validation step?
Udit06
Udit06 2024-9-23,15:07
Hi Mark,
When you use the fitctree function with the 'CrossVal','on' option, MATLAB automatically performs cross-validation (10-fold by default) and assigns the rows to folds at random rather than by position. This is described in the MathWorks documentation.
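To see what the built-in option does, here is a minimal sketch, assuming the ionosphere sample data set from Statistics and Machine Learning Toolbox. The names CVMdl, numFolds and cvError are illustrative only.

```matlab
% Sketch: inspect the result of the built-in cross-validation.
% Assumption: the ionosphere sample data stands in for your data.
load ionosphere   % X: 351-by-34 predictors, Y: 351-by-1 cell array of labels
rng(1)            % fix the random fold assignment for reproducibility

% With 'CrossVal','on' fitctree returns a cross-validated (partitioned) model
CVMdl = fitctree(X, Y, 'CrossVal', 'on');

numFolds = CVMdl.KFold;      % number of folds chosen by MATLAB
cvError  = kfoldLoss(CVMdl); % misclassification rate averaged over the folds
fprintf('%d folds, cross-validated error: %.4f\n', numFolds, cvError);
```

Every row ends up in the test fold exactly once, which is why this option cannot reserve rows 76-100 alone for testing.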
However, if you want to specify the training and test sets manually, you should split the data yourself rather than rely on the built-in cross-validation. The code snippet below trains the model on a manually specified training set and evaluates it on a manually specified test set:
% Clear workspace
clear;
% Load the ionosphere dataset
load ionosphere;
% Define training data ratio and calculate number of training samples
trainRatio = 0.7;
numTrainSamples = round(trainRatio * size(X, 1));
% Split data into training and test sets
X_train = X(1:numTrainSamples, :);
Y_train = Y(1:numTrainSamples);
X_test = X(numTrainSamples+1:end, :);
Y_test = Y(numTrainSamples+1:end);
% Train a decision tree classifier
MdlA = fitctree(X_train, Y_train);
% Visualize the decision tree
view(MdlA, 'Mode', 'graph');
% Predict labels for the test set
Y_pred = predict(MdlA, X_test);
% Compare the true and predicted labels element-wise
% (strcmp works for labels of any length, whereas cell2mat needs equal-length labels)
isCorrect = strcmp(Y_test, Y_pred);
% Calculate and display accuracy
accuracy = mean(isCorrect);
fprintf('Test Set Accuracy: %.2f%%\n', accuracy * 100);
Test Set Accuracy: 89.52%
I hope this helps.
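As an alternative that stays closer to the cross-validation syntax asked about above, a fixed hold-out partition can be passed to fitctree through its 'CVPartition' option. The sketch below assumes R2023b or later, where, as far as I know, cvpartition accepts a 'CustomPartition' argument taking a logical vector that marks the test rows; please check the cvpartition documentation for your release. The names isTest, c, CVMdl and holdoutError are illustrative.

```matlab
% Sketch: fixed hold-out split expressed as a custom cvpartition.
% Assumptions: R2023b or later ('CustomPartition' support), ionosphere data.
load ionosphere
n = size(X, 1);
numTrainSamples = round(0.7 * n);

% Logical vector: true marks the rows reserved for testing
isTest = false(n, 1);
isTest(numTrainSamples+1:end) = true;

% Build a hold-out partition from the chosen rows
c = cvpartition('CustomPartition', isTest);

% Train on the non-test rows and keep the test rows for validation
CVMdl = fitctree(X, Y, 'CVPartition', c);

% Misclassification rate on the held-out rows
holdoutError = kfoldLoss(CVMdl);
fprintf('Hold-out accuracy: %.2f%%\n', (1 - holdoutError) * 100);
```

With your own 100-row data, setting isTest(76:100) = true would reserve rows 76-100 for validation and train on rows 1-75.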

Release

R2023b
