使用贝叶斯优化来优化分类器拟合
此示例说明如何使用 fitcsvm
函数和 OptimizeHyperparameters
名称-值参数优化 SVM 分类。
生成数据
该分类基于高斯混合模型中点的位置来工作。有关该模型的描述,请参阅 The Elements of Statistical Learning,作者 Hastie、Tibshirani 和 Friedman (2009),第 17 页。该模型从为“green”类生成 10 个基点开始,这些基点呈二维独立正态分布,均值为 (1,0) 且具有单位方差。它还为“red”类生成 10 个基点,这些基点呈二维独立正态分布,均值为 (0,1) 且具有单位方差。对于每个类(green 和 red),生成 100 个随机点,如下所示:
随机均匀选择合适颜色的一个基点 m。
生成一个呈二维正态分布的独立随机点,其均值为 m,方差为 I/5,其中 I 是 2×2 单位矩阵。在此示例中,使用方差 I/50 来更清楚地显示优化的优势。
为每个类生成 10 个基点。
rng('default') % For reproducibility grnpop = mvnrnd([1,0],eye(2),10); redpop = mvnrnd([0,1],eye(2),10);
查看基点。
plot(grnpop(:,1),grnpop(:,2),'go') hold on plot(redpop(:,1),redpop(:,2),'ro') hold off
由于一些红色基点靠近绿色基点,因此很难仅基于位置对数据点进行分类。
生成每个类的 100 个数据点。
redpts = zeros(100,2); grnpts = redpts; for i = 1:100 grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02); redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02); end
查看数据点。
figure plot(grnpts(:,1),grnpts(:,2),'go') hold on plot(redpts(:,1),redpts(:,2),'ro') hold off
为分类准备数据
将数据放入一个矩阵中,并创建向量 grp
,该向量标记每个点的类。1 表示绿色类,-1 表示红色类。
cdata = [grnpts;redpts]; grp = ones(200,1); grp(101:200) = -1;
准备交叉验证
为交叉验证设置一个分区。
c = cvpartition(200,'KFold',10);
此步骤是可选的。如果您为优化指定一个分区,则您可以为返回的模型计算实际交叉验证损失。
优化拟合
要找到好的拟合,即具有使交叉验证损失最小化的最佳超参数的拟合,请使用贝叶斯优化。使用 OptimizeHyperparameters
名称-值参数指定要优化的超参数列表,并使用 HyperparameterOptimizationOptions
名称-值参数指定优化选项。
将 'OptimizeHyperparameters'
指定为 'auto'
。'auto'
选项包括一组典型的要优化的超参数。fitcsvm
查找 BoxConstraint
和 KernelScale
的最佳值。设置超参数优化选项,以使用交叉验证分区 c
并选择 'expected-improvement-plus'
采集函数以实现可再现性。默认采集函数取决于运行时间,因此可以给出不同结果。
opts = struct('CVPartition',c,'AcquisitionFunctionName','expected-improvement-plus'); Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | | | result | | runtime | (observed) | (estim.) | | | |=====================================================================================================| | 1 | Best | 0.345 | 0.26612 | 0.345 | 0.345 | 0.00474 | 306.44 | | 2 | Best | 0.115 | 0.16757 | 0.115 | 0.12678 | 430.31 | 1.4864 | | 3 | Accept | 0.52 | 0.21336 | 0.115 | 0.1152 | 0.028415 | 0.014369 | | 4 | Accept | 0.61 | 0.41833 | 0.115 | 0.11504 | 133.94 | 0.0031427 | | 5 | Accept | 0.34 | 0.46056 | 0.115 | 0.11504 | 0.010993 | 5.7742 | | 6 | Best | 0.085 | 0.25465 | 0.085 | 0.085039 | 885.63 | 0.68403 | | 7 | Accept | 0.105 | 0.25751 | 0.085 | 0.085428 | 0.3057 | 0.58118 | | 8 | Accept | 0.21 | 0.28915 | 0.085 | 0.09566 | 0.16044 | 0.91824 | | 9 | Accept | 0.085 | 0.30816 | 0.085 | 0.08725 | 972.19 | 0.46259 | | 10 | Accept | 0.1 | 0.34457 | 0.085 | 0.090952 | 990.29 | 0.491 | | 11 | Best | 0.08 | 0.21805 | 0.08 | 0.079362 | 2.5195 | 0.291 | | 12 | Accept | 0.09 | 0.24212 | 0.08 | 0.08402 | 14.338 | 0.44386 | | 13 | Accept | 0.1 | 0.23766 | 0.08 | 0.08508 | 0.0022577 | 0.23803 | | 14 | Accept | 0.11 | 0.24347 | 0.08 | 0.087378 | 0.2115 | 0.32109 | | 15 | Best | 0.07 | 0.30411 | 0.07 | 0.081507 | 910.2 | 0.25218 | | 16 | Best | 0.065 | 0.24431 | 0.065 | 0.072457 | 953.22 | 0.26253 | | 17 | Accept | 0.075 | 0.33287 | 0.065 | 0.072554 | 998.74 | 0.23087 | | 18 | Accept | 0.295 | 0.21231 | 0.065 | 0.072647 | 996.18 | 44.626 | | 19 | Accept | 0.07 | 0.26876 | 0.065 | 0.06946 | 985.37 | 0.27389 | | 20 | Accept | 0.165 | 0.24669 | 0.065 | 0.071622 | 0.065103 | 0.13679 | |=====================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | | | result | | runtime | (observed) | (estim.) | | | |=====================================================================================================| | 21 | Accept | 0.345 | 0.20097 | 0.065 | 0.071764 | 971.7 | 999.01 | | 22 | Accept | 0.61 | 0.2416 | 0.065 | 0.071967 | 0.0010168 | 0.0010005 | | 23 | Accept | 0.345 | 0.26803 | 0.065 | 0.071959 | 0.0011459 | 995.89 | | 24 | Accept | 0.35 | 0.23608 | 0.065 | 0.071863 | 0.0010003 | 40.628 | | 25 | Accept | 0.24 | 0.39188 | 0.065 | 0.072124 | 996.55 | 10.423 | | 26 | Accept | 0.61 | 0.46697 | 0.065 | 0.072067 | 994.71 | 0.0010063 | | 27 | Accept | 0.47 | 0.28997 | 0.065 | 0.07218 | 993.69 | 0.029723 | | 28 | Accept | 0.3 | 0.24924 | 0.065 | 0.072291 | 993.15 | 170.01 | | 29 | Accept | 0.16 | 0.37085 | 0.065 | 0.072103 | 992.81 | 3.8594 | | 30 | Accept | 0.365 | 0.19017 | 0.065 | 0.072112 | 0.0010017 | 0.044287 |
__________________________________________________________ Optimization completed. MaxObjectiveEvaluations of 30 reached. Total function evaluations: 30 Total elapsed time: 42.2011 seconds Total objective function evaluation time: 8.4361 Best observed feasible point: BoxConstraint KernelScale _____________ ___________ 953.22 0.26253 Observed objective function value = 0.065 Estimated objective function value = 0.073726 Function evaluation time = 0.24431 Best estimated feasible point (according to models): BoxConstraint KernelScale _____________ ___________ 985.37 0.27389 Estimated objective function value = 0.072112 Estimated function evaluation time = 0.28413
Mdl = ClassificationSVM ResponseName: 'Y' CategoricalPredictors: [] ClassNames: [-1 1] ScoreTransform: 'none' NumObservations: 200 HyperparameterOptimizationResults: [1x1 BayesianOptimization] Alpha: [77x1 double] Bias: -0.2352 KernelParameters: [1x1 struct] BoxConstraints: [200x1 double] ConvergenceInfo: [1x1 struct] IsSupportVector: [200x1 logical] Solver: 'SMO' Properties, Methods
fitcsvm
返回使用最佳估计可行点的 ClassificationSVM
模型对象。最佳估计可行点是基于贝叶斯优化过程的基础高斯过程模型最小化交叉验证损失的置信边界上限的超参数集。
贝叶斯优化过程在内部维护目标函数的高斯过程模型。目标函数是分类的交叉验证误分类率。对于每次迭代,优化过程都会更新高斯过程模型并使用该模型找到一组新的超参数。迭代输出的每行显示新的超参数集和这些列值:
Objective
- 基于新的超参数集计算的目标函数值。Objective runtime
- 目标函数计算时间。Eval result
- 结果报告,指定为Accept
、Best
或Error
。Accept
表示目标函数返回有限值,Error
表示目标函数返回非有限实数标量值。Best
表示目标函数返回的有限值低于先前计算的目标函数值。BestSoFar(observed)
- 迄今为止计算的最小目标函数值。此值或者是当前迭代的目标函数值(如果当前迭代的Eval result
值是Best
),或者是前一个Best
迭代的值。BestSoFar(estim.)
- 在每次迭代中,软件使用更新后的高斯过程模型,基于迄今为止尝试的所有超参数集估计目标函数值的置信边界上限。然后,软件选择具有最小置信边界上限的点。BestSoFar(estim.)
值是predictObjective
函数在最小值点处返回的目标函数值。
迭代输出下方的图分别以蓝色和绿色显示 BestSoFar(observed)
和 BestSoFar(estim.)
值。
返回的对象 Mdl
使用最佳估计可行点,即基于最终高斯过程模型在最终迭代中产生 BestSoFar(estim.)
值的超参数集。
您可以从 HyperparameterOptimizationResults
属性或使用 bestPoint
函数获得最佳点。
Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×2 table
BoxConstraint KernelScale
_____________ ___________
985.37 0.27389
[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×2 table
BoxConstraint KernelScale
_____________ ___________
985.37 0.27389
CriterionValue = 0.0888
iteration = 19
默认情况下,bestPoint
函数使用 'min-visited-upper-confidence-interval'
条件。此条件选择从第 19 次迭代获得的超参数作为最佳点。CriterionValue
是最终高斯过程模型计算的交叉验证损失的上界。使用分区 c
计算实际交叉验证损失。
L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf', ... 'BoxConstraint',x.BoxConstraint,'KernelScale',x.KernelScale))
L_MinEstimated = 0.0700
实际交叉验证损失接近估计值。Estimated objective function value
显示在优化结果图的下方。
您也可以从 HyperparameterOptimizationResults
属性或通过将 Criterion
指定为 'min-observed'
来提取最佳观测可行点(即迭代输出中的最后一个 Best
点)。
Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×2 table
BoxConstraint KernelScale
_____________ ___________
953.22 0.26253
[x_observed,CriterionValue_observed,iteration_observed] = bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×2 table
BoxConstraint KernelScale
_____________ ___________
953.22 0.26253
CriterionValue_observed = 0.0650
iteration_observed = 16
'min-observed'
条件选择从第 16 次迭代获得的超参数作为最佳点。CriterionValue_observed
是使用所选超参数计算的实际交叉验证损失。有关详细信息,请参阅 bestPoint
的 Criterion 名称-值参数。
可视化经过优化的分类器。
d = 0.02; [x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ... min(cdata(:,2)):d:max(cdata(:,2))); xGrid = [x1Grid(:),x2Grid(:)]; [~,scores] = predict(Mdl,xGrid); figure h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3) = plot(cdata(Mdl.IsSupportVector,1), ... cdata(Mdl.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
基于新数据计算准确度
生成并分类新的测试数据点。
grnobj = gmdistribution(grnpop,.2*eye(2)); redobj = gmdistribution(redpop,.2*eye(2)); newData = random(grnobj,10); newData = [newData;random(redobj,10)]; grpData = ones(20,1); % green = 1 grpData(11:20) = -1; % red = -1 v = predict(Mdl,newData);
基于测试数据集计算误分类率。
L_Test = loss(Mdl,newData,grpData)
L_Test = 0.3500
确定哪些新数据点是分类正确的。将正确分类的点格式化为红色方块,将不正确分类的点格式化为黑色方块。
h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); mydiff = (v == grpData); % Classified correctly for ii = mydiff % Plot red squares around correct pts h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12); end for ii = not(mydiff) % Plot black squares around incorrect pts h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12); end legend(h,{'-1 (training)','+1 (training)','Support Vectors', ... '-1 (classified)','+1 (classified)', ... 'Correctly Classified','Misclassified'}, ... 'Location','Southeast'); hold off