Conflict between SVM classifier and perfcurve()

I want to do binary classification using an SVM and evaluate its performance with an ROC curve computed by the function perfcurve(). I also want to find the optimal threshold that separates the two classes, and whether observations with feature > threshold or with feature < threshold are classified into class 1.
Mdl = fitcsvm(X,Y); % train a model
[Label,Score] = predict(Mdl, X_new); % predict new data
[X,Y,T,AUC,OPTROCPT] = perfcurve(labels,Score(:, 2), posclass); % ROC; labels = true classes of X_new, posclass = Mdl.ClassNames(2), the class that Score(:, 2) refers to
My data X is a 100-by-1 vector, i.e. 100 observations and each observation has only one feature.
The predict() function assigns a label based on the score, but its scores look unusual compared with those of other classifiers. Usually a binary classifier returns a single score per observation and classifies by thresholding that score. predict() does not: in my binary classification case, the score is an n-by-2 matrix, and according to the MATLAB documentation, "For each observation in X, the predicted class label corresponds to the maximum score among all classes."
In my case, I found that
Score(:, 2) = -Score(:, 1)
So the label is ClassNames{2} when
Score(:, 2) > Score(:, 1)   % i.e. Score(:, 2) > 0
and ClassNames{1} when
Score(:, 2) < Score(:, 1)   % i.e. Score(:, 2) < 0
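A quick way to confirm this observation is sketched below. It assumes the default fitcsvm settings and reuses the versicolor/virginica sepal-length data from the reproduction code; the variable names scoresMirror and labelsFollowSign are illustrative only.

```matlab
% Check how predict() scores relate to labels for a binary fitcsvm model
% (default settings; versicolor vs. virginica, sepal length only).
load fisheriris
inds = ~strcmp(species, 'setosa');
X = meas(inds, 1);
y = species(inds);

Mdl = fitcsvm(X, y);
[Label, Score] = predict(Mdl, X);

% The two score columns should be exact negatives of each other
scoresMirror = all(abs(Score(:, 1) + Score(:, 2)) < 1e-12);

% The label should be ClassNames{2} exactly when Score(:, 2) > 0,
% i.e. predict() behaves like a fixed threshold of 0 on Score(:, 2)
isPosLabel = strcmp(Label, Mdl.ClassNames{2});
labelsFollowSign = isequal(isPosLabel, Score(:, 2) > 0);

fprintf('Score(:,1) == -Score(:,2): %d\n', scoresMirror);
fprintf('Label follows sign of Score(:,2): %d\n', labelsFollowSign);
```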
Here is the problem: I cannot get the optimal threshold that separates the two classes. You might say that 0 is the threshold: score > 0 is classified into one class, and score < 0 into the other. That seems plausible, but it contradicts the result from perfcurve().
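With a single feature, the score threshold of 0 can be mapped back to a threshold on the feature itself, which also shows which side goes to which class. The sketch below assumes the default linear kernel, for which the ClassificationSVM documentation gives the positive-class (ClassNames{2}) score as f(x) = (x/s)'*Beta + Bias with s = KernelParameters.Scale (x is standardized with Mu and Sigma first when Standardize is true). Solving f(x) = 0 gives the feature threshold; the names zStar, xStar and agreement are hypothetical, and the cross-check against predict() is only there for illustration.

```matlab
% Sketch: turn the SVM's score threshold (0) into a threshold on the single
% feature, assuming a linear kernel (the fitcsvm default for two classes).
load fisheriris
inds = ~strcmp(species, 'setosa');
X = meas(inds, 1);
y = species(inds);
train_inds = true([100, 1]);
train_inds([1:25, 51:75]) = false;            % same split as the reproduction code

SVMModel = fitcsvm(X(train_inds), y(train_inds));

% Linear-kernel score for ClassNames{2}: f = (z / s) * Beta + Bias,
% where z is the (possibly standardized) feature value
s = SVMModel.KernelParameters.Scale;
zStar = -SVMModel.Bias * s / SVMModel.Beta;   % solves f = 0
if isempty(SVMModel.Mu)
    xStar = zStar;                            % no standardization (default)
else
    xStar = zStar * SVMModel.Sigma + SVMModel.Mu;  % undo standardization
end

if SVMModel.Beta > 0
    fprintf('feature > %.4f  ->  %s\n', xStar, SVMModel.ClassNames{2});
    ruleIsPos = X > xStar;
else
    fprintf('feature < %.4f  ->  %s\n', xStar, SVMModel.ClassNames{2});
    ruleIsPos = X < xStar;
end

% Cross-check the derived rule against predict() on all 100 observations
predIsPos = strcmp(predict(SVMModel, X), SVMModel.ClassNames{2});
agreement = mean(ruleIsPos == predIsPos);
```

If Beta is positive, larger feature values push the score toward ClassNames{2}, so feature > xStar is assigned to that class; a negative Beta flips the direction.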
To find the optimal threshold, I use perfcurve(). One of its outputs, OPTROCPT, is the "Optimal operating point of the ROC curve, returned as a 1-by-2 array with false positive rate (FPR) and true positive rate (TPR) values for the optimal ROC operating point."
So I calculate the optimal threshold as follows:
optIndex = find(X==OPTROCPT(1) & Y==OPTROCPT(2));
optThresh = T(optIndex);
This optThresh is different from the threshold used by the SVM (0).
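As we read the perfcurve documentation, OPTROCPT is chosen by a cost-based geometric rule rather than derived from any particular score value: a straight line with slope S is moved from the upper-left corner of the ROC plot (FPR = 0, TPR = 1) down and to the right until it touches the ROC curve. Here cost(I|J) is the cost of assigning an instance of class J to class I, and P and N are the numbers of positive and negative instances. The first point touched is the one that maximizes TPR − S·FPR:

$$
S = \frac{\mathrm{cost}(P \mid N) - \mathrm{cost}(N \mid N)}{\mathrm{cost}(N \mid P) - \mathrm{cost}(P \mid P)} \cdot \frac{N}{P},
\qquad
\mathrm{OPTROCPT} = \arg\max_{(\mathrm{FPR},\,\mathrm{TPR})} \left( \mathrm{TPR} - S \cdot \mathrm{FPR} \right)
$$

With the default cost matrix (0 on the diagonal, 1 elsewhere), S = N/P. The validation set in the reproduction code has 25 observations of each class, so S = 1 and the rule reduces to maximizing TPR − FPR. fitcsvm, by contrast, places the boundary at score 0 by solving its soft-margin training problem, so nothing forces the two operating points to agree.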
The code below reproduces the problem:
% conflict between SVM classifier and perfcurve()
%% load data
load fisheriris
inds = ~strcmp(species,'setosa'); % use two species to do binary classification
X = meas(inds, 1); % use one feature only
y = species(inds);
%% get train data and validation data
train_inds = true([100, 1]);
train_inds(1:25) = false;
train_inds(51:75) = false;
val_inds = ~train_inds ;
X_train = X(train_inds);
y_train = y(train_inds);
X_val = X(val_inds);
y_val = y(val_inds);
%% SVM
SVMModel = fitcsvm(X_train,y_train);
[Label,Score] = predict(SVMModel, X_val); % predict on new data
[X,Y,T,AUC,OPTROCPT] = perfcurve(y_val,Score(:, 2), SVMModel.ClassNames{2}); % ROC
optIndex = find(X==OPTROCPT(1) & Y==OPTROCPT(2));
optThresh = T(optIndex);
isequal(optThresh, 0)
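Because T is built from score values that actually occur in the validation data (as we understand perfcurve's default thresholds), isequal(optThresh, 0) is unlikely to be true even when the two rules lead to similar classifications. A fairer comparison, sketched below with the same split as the reproduction code, is between OPTROCPT and the (FPR, TPR) point that the SVM's own rule Score(:, 2) > 0 achieves on the validation set. The names fpr, tpr, svmFPR and svmTPR are our own.

```matlab
% Sketch: compare perfcurve's optimal ROC point with the operating point of
% the SVM's own rule (predict ClassNames{2} when Score(:, 2) > 0).
load fisheriris
inds = ~strcmp(species, 'setosa');     % versicolor vs. virginica
X = meas(inds, 1);                     % one feature only
y = species(inds);
train_inds = true([100, 1]);
train_inds([1:25, 51:75]) = false;     % same split as the reproduction code
val_inds = ~train_inds;

SVMModel = fitcsvm(X(train_inds), y(train_inds));
[~, Score] = predict(SVMModel, X(val_inds));
y_val = y(val_inds);
posclass = SVMModel.ClassNames{2};

% Distinct output names avoid overwriting the feature vector X and labels y
[fpr, tpr, T, AUC, OPTROCPT] = perfcurve(y_val, Score(:, 2), posclass);

% Threshold(s) in T that produce the optimal point
optIndex = find(fpr == OPTROCPT(1) & tpr == OPTROCPT(2));
optThresh = T(optIndex);

% Operating point of the SVM rule on the same validation data
isPos = strcmp(y_val, posclass);
predPos = Score(:, 2) > 0;
svmFPR = sum(predPos & ~isPos) / sum(~isPos);
svmTPR = sum(predPos & isPos) / sum(isPos);

fprintf('perfcurve optimum : FPR = %.3f, TPR = %.3f\n', OPTROCPT(1), OPTROCPT(2));
fprintf('SVM rule (score>0): FPR = %.3f, TPR = %.3f\n', svmFPR, svmTPR);
```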

Answers (1)

Song Gao 2021-5-27
I think the problem here is that you used the validation dataset to determine the optimal point.
