Cross-validation of single binary learners in multiclass classification (fitcecoc)

I am training a multiclass SVM classification model with the function fitcecoc and the coding design 'allpairs', which trains one binary model for every pair of classes.
You can cross-validate this multiclass (ECOC) classifier and estimate its generalization error, for example, like this:
Mdl = fitcecoc(X,Y,'Learners',t,...
'ClassNames',{'setosa','versicolor','virginica'});
CVMdl = crossval(Mdl);
oosLoss = kfoldLoss(CVMdl)
In addition to this, would it also be possible to cross-validate and estimate the generalization error for the single binary models?

Accepted Answer

Shubham 2024-9-5
Hi Alessandro,
Yes, you can cross-validate and estimate the generalization error of each binary learner in an ECOC (error-correcting output codes) multiclass model. However, MATLAB has no built-in function that does this directly for the learners of a model trained with fitcecoc and the 'allpairs' coding design. The learners stored in Mdl.BinaryLearners are compact models, so they cannot be passed to crossval, and their class labels are -1 and 1 rather than your original class names.
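To see this for yourself, you can inspect one of the binary learners. The short check below assumes the Fisher iris data and a linear SVM template, as in the example in this answer; it prints the class of the first learner (a compact SVM model) and its internal class labels.

```matlab
% Train an all-pairs ECOC model of linear SVMs on the Fisher iris data
load fisheriris
t = templateSVM('KernelFunction', 'linear');
Mdl = fitcecoc(meas, species, 'Learners', t, 'Coding', 'allpairs');

% Inspect the first binary learner
learner = Mdl.BinaryLearners{1};
disp(class(learner))      % a compact SVM model, which has no crossval method
disp(learner.ClassNames)  % internal binary labels, not the iris species names
```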
Approach:
To achieve this, identify the two classes that each binary learner separates and cross-validate a binary SVM with the same settings on the observations of those two classes:
  1. Train the ECOC Model: Use fitcecoc with the 'allpairs' coding design to train your multiclass model.
  2. Identify the Class Pairs: Read the positive and negative class of each binary learner from Mdl.CodingMatrix.
  3. Cross-Validate Each Binary Problem: Subset the data to those two classes and cross-validate an SVM trained with the same options, for example with fitcsvm and 'CrossVal','on'.
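Step 2 relies on the coding matrix. The snippet below uses designecoc to build the all-pairs coding matrix for three classes and prints which pair each column (binary learner) handles; for a fitted model, Mdl.CodingMatrix holds the same kind of matrix, with rows in Mdl.ClassNames order. The class list here is assumed to match the iris example.

```matlab
% Coding matrix for three classes under the all-pairs (one-versus-one) design.
% Rows are classes, columns are binary learners:
% +1 = positive class, -1 = negative class, 0 = class not used by that learner.
M = designecoc(3, 'allpairs');
disp(M)

% Recover the class pair handled by each binary learner
classNames = {'setosa'; 'versicolor'; 'virginica'};
for i = 1:size(M, 2)
    fprintf('Learner %d: %s (+1) vs %s (-1)\n', i, ...
        classNames{M(:, i) == 1}, classNames{M(:, i) == -1});
end
```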
Here is a step-by-step example:
% Load example data
load fisheriris
X = meas;
Y = species;
% Train the ECOC model with the all-pairs coding design
t = templateSVM('KernelFunction', 'linear');
Mdl = fitcecoc(X, Y, 'Learners', t, 'ClassNames', {'setosa', 'versicolor', 'virginica'}, 'Coding', 'allpairs');
% Each column of the coding matrix describes one binary learner:
% +1 = positive class, -1 = negative class, 0 = class ignored by that learner
codingMat = Mdl.CodingMatrix;
numLearners = size(codingMat, 2);
% Initialize a vector to store the cross-validation loss of each binary learner
binaryLosses = zeros(numLearners, 1);
% Cross-validate each binary problem
for i = 1:numLearners
    % Identify the two classes used by the i-th binary learner
    posClass = Mdl.ClassNames(codingMat(:, i) == 1);
    negClass = Mdl.ClassNames(codingMat(:, i) == -1);
    % Keep only the observations that belong to those two classes
    isPair = ismember(Y, [posClass; negClass]);
    XBinary = X(isPair, :);
    YBinary = Y(isPair);
    % Mdl.BinaryLearners holds compact models that cannot be cross-validated,
    % so train an SVM with the same settings as the template and
    % cross-validate it ('CrossVal','on' uses 10-fold cross-validation)
    CVBinaryMdl = fitcsvm(XBinary, YBinary, 'KernelFunction', 'linear', 'CrossVal', 'on');
    binaryLosses(i) = kfoldLoss(CVBinaryMdl);
    % Display the cross-validation loss for the current binary learner
    fprintf('Binary Model %d (%s vs %s) Cross-Validation Loss: %.4f\n', ...
        i, posClass{1}, negClass{1}, binaryLosses(i));
end
% Display the average cross-validation loss across all binary learners
averageBinaryLoss = mean(binaryLosses);
fprintf('Average Cross-Validation Loss for Binary Models: %.4f\n', averageBinaryLoss);
Explanation:
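  • Training the ECOC Model: We train the ECOC model using fitcecoc with the 'allpairs' coding design, which creates one binary SVM learner for each pair of classes.
  • Identifying the Class Pairs: Column i of Mdl.CodingMatrix describes binary learner i: the class marked +1 is its positive class, the class marked -1 is its negative class, and classes marked 0 are not used. The learners in Mdl.BinaryLearners use the internal labels -1 and 1, so the class names have to be read from the coding matrix.
  • Cross-Validation: For each binary learner, extract the observations of its two classes and cross-validate an SVM with the same settings as the template (fitcsvm with 'CrossVal','on', which uses 10-fold cross-validation by default). Because an all-pairs binary learner is trained only on the observations of its two classes, this sets up the same binary problem. The compact learners themselves cannot be passed to crossval.
  • Binary Loss Calculation: Calculate and print the cross-validation loss for each binary learner, as well as the average loss across all binary learners.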
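An alternative, sketched here under the same assumptions (iris data, linear SVM template), reuses the folds of the multiclass cross-validation instead of retraining separate SVMs. It relies on the Trained and Partition properties of the cross-validated ECOC model: for each fold it scores every binary learner of that fold's model on the held-out observations of its two classes, so the binary losses come from the same partition as kfoldLoss(CVMdl).

```matlab
% Reuse the ECOC model's own cross-validation folds to score each binary learner
load fisheriris
X = meas;
Y = species;
t = templateSVM('KernelFunction', 'linear');
Mdl = fitcecoc(X, Y, 'Learners', t, 'Coding', 'allpairs');

rng(1)                          % reproducible fold assignment
CVMdl = crossval(Mdl);          % 10-fold cross-validation by default
codingMat = Mdl.CodingMatrix;   % rows follow Mdl.ClassNames
numLearners = size(codingMat, 2);
numErrors = zeros(1, numLearners);
numTested = zeros(1, numLearners);

for k = 1:CVMdl.KFold
    isTest = test(CVMdl.Partition, k);   % observations held out in fold k
    foldMdl = CVMdl.Trained{k};          % compact ECOC model trained without fold k
    for i = 1:numLearners
        isPos = ismember(Y, Mdl.ClassNames(codingMat(:, i) == 1));
        isNeg = ismember(Y, Mdl.ClassNames(codingMat(:, i) == -1));
        idx = isTest & (isPos | isNeg);  % held-out rows of this class pair only
        % Binary learners predict the internal labels +1 (positive) and -1 (negative)
        yTrue = ones(nnz(idx), 1);
        yTrue(isNeg(idx)) = -1;
        yPred = predict(foldMdl.BinaryLearners{i}, X(idx, :));
        numErrors(i) = numErrors(i) + sum(yPred ~= yTrue);
        numTested(i) = numTested(i) + nnz(idx);
    end
end

% Out-of-fold misclassification rate of each binary learner
binaryLosses = numErrors ./ numTested;
disp(binaryLosses)
```

Because every observation is held out in exactly one fold, each binary learner is scored once on every observation of its two classes.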
