Main Content

本页的翻译已过时。点击此处可查看最新英文版本。

使用贝叶斯优化来优化 SVM 分类器拟合

此示例说明如何使用 fitcsvm 函数和 OptimizeHyperparameters 名称-值对组优化 SVM 分类。该分类基于高斯混合模型中点的位置来工作。有关该模型的描述,请参阅 The Elements of Statistical Learning,作者 Hastie、Tibshirani 和 Friedman (2009),第 17 页。该模型从为“green”类生成 10 个基点开始,这些基点呈二维独立正态分布,均值为 (1,0) 且具有单位方差。它还为“red”类生成 10 个基点,这些基点呈二维独立正态分布,均值为 (0,1) 且具有单位方差。对于每个类(green 和 red),生成 100 个随机点,如下所示:

  1. 随机均匀选择合适颜色的一个基点 m

  2. 生成一个呈二维正态分布的独立随机点,其均值为 m,方差为 I/5,其中 I 是 2×2 单位矩阵。在此示例中,使用方差 I/50 来更清楚地显示优化的优势。

生成点和分类器

为每个类生成 10 个基点。

rng default % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

查看基点。

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

由于一些红色基点靠近绿色基点,因此很难仅基于位置对数据点进行分类。

生成每个类的 100 个数据点。

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

查看数据点。

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

为分类准备数据

将数据放入一个矩阵中,并创建向量 grp,该向量标记每个点的类。

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

准备交叉验证

为交叉验证设置一个分区。此步骤确定优化在每一步所使用的训练集和测试集。

c = cvpartition(200,'KFold',10);

优化拟合

要找到好的拟合,即交叉验证损失低的拟合,请设置选项以使用贝叶斯优化。在所有优化中使用相同的交叉验证分区 c

为了实现可再现性,请使用 'expected-improvement-plus' 采集函数。

opts = struct('Optimizer','bayesopt','ShowPlots',true,'CVPartition',c,...
    'AcquisitionFunctionName','expected-improvement-plus');
svmmod = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |       0.345 |     0.21295 |       0.345 |       0.345 |      0.00474 |       306.44 |
|    2 | Best   |       0.115 |     0.17149 |       0.115 |     0.12678 |       430.31 |       1.4864 |
|    3 | Accept |        0.52 |      0.1333 |       0.115 |      0.1152 |     0.028415 |     0.014369 |
|    4 | Accept |        0.61 |      0.2691 |       0.115 |     0.11504 |       133.94 |    0.0031427 |
|    5 | Accept |        0.34 |     0.35421 |       0.115 |     0.11504 |     0.010993 |       5.7742 |
|    6 | Best   |       0.085 |     0.11744 |       0.085 |    0.085039 |       885.63 |      0.68403 |
|    7 | Accept |       0.105 |     0.11036 |       0.085 |    0.085428 |       0.3057 |      0.58118 |
|    8 | Accept |        0.21 |     0.13858 |       0.085 |     0.09566 |      0.16044 |      0.91824 |
|    9 | Accept |       0.085 |       0.178 |       0.085 |     0.08725 |       972.19 |      0.46259 |
|   10 | Accept |         0.1 |      0.3127 |       0.085 |    0.090952 |       990.29 |        0.491 |
|   11 | Best   |        0.08 |     0.12274 |        0.08 |    0.079362 |       2.5195 |        0.291 |
|   12 | Accept |        0.09 |     0.11783 |        0.08 |     0.08402 |       14.338 |      0.44386 |
|   13 | Accept |         0.1 |     0.10892 |        0.08 |     0.08508 |    0.0022577 |      0.23803 |
|   14 | Accept |        0.11 |     0.10798 |        0.08 |    0.087378 |       0.2115 |      0.32109 |
|   15 | Best   |        0.07 |     0.17553 |        0.07 |    0.081507 |        910.2 |      0.25218 |
|   16 | Best   |       0.065 |     0.17637 |       0.065 |    0.072457 |       953.22 |      0.26253 |
|   17 | Accept |       0.075 |     0.26422 |       0.065 |    0.072554 |       998.74 |      0.23087 |
|   18 | Accept |       0.295 |     0.11691 |       0.065 |    0.072647 |       996.18 |       44.626 |
|   19 | Accept |        0.07 |     0.13689 |       0.065 |     0.06946 |       985.37 |      0.27389 |
|   20 | Accept |       0.165 |     0.11689 |       0.065 |    0.071622 |     0.065103 |      0.13679 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |       0.345 |     0.10901 |       0.065 |    0.071764 |        971.7 |       999.01 |
|   22 | Accept |        0.61 |     0.12691 |       0.065 |    0.071967 |    0.0010168 |    0.0010005 |
|   23 | Accept |       0.345 |     0.10781 |       0.065 |    0.071959 |    0.0010674 |       999.18 |
|   24 | Accept |        0.35 |     0.11365 |       0.065 |    0.071863 |    0.0010003 |       40.628 |
|   25 | Accept |        0.24 |      0.2432 |       0.065 |    0.072124 |       996.55 |       10.423 |
|   26 | Accept |        0.61 |      0.1947 |       0.065 |    0.072068 |       958.64 |    0.0010026 |
|   27 | Accept |        0.47 |      0.1147 |       0.065 |     0.07218 |       993.69 |     0.029723 |
|   28 | Accept |         0.3 |     0.10471 |       0.065 |    0.072291 |       993.15 |       170.01 |
|   29 | Accept |        0.16 |      0.2536 |       0.065 |    0.072104 |       992.81 |       3.8594 |
|   30 | Accept |       0.365 |     0.10891 |       0.065 |    0.072112 |    0.0010017 |     0.044287 |

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 50.7144 seconds
Total objective function evaluation time: 4.9196

Best observed feasible point:
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

Observed objective function value = 0.065
Estimated objective function value = 0.073726
Function evaluation time = 0.17637

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

Estimated objective function value = 0.072112
Estimated function evaluation time = 0.1773
svmmod = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [77x1 double]
                                 Bias: -0.2352
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


  Properties, Methods

计算优化模型的损失。

lossnew = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf',...
    'BoxConstraint',svmmod.HyperparameterOptimizationResults.XAtMinObjective.BoxConstraint,...
    'KernelScale',svmmod.HyperparameterOptimizationResults.XAtMinObjective.KernelScale))
lossnew = 0.0650

该损失与“观测的目标函数值”下优化输出中报告的损失相同。

可视化经过优化的分类器。

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(svmmod,xGrid);
figure;
h = nan(3,1); % Preallocation
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(svmmod.IsSupportVector,1),...
    cdata(svmmod.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

另请参阅

|

相关主题