评估回归神经网络性能
使用 fitrnet 创建一个具有全连接层的前馈回归神经网络模型。针对训练过程的早停使用验证数据,以防止模型过拟合。然后,使用该模型的对象函数来评估它在测试数据上的性能。
加载样本数据
加载 carbig 数据集,该数据集包含 20 世纪 70 年代和 80 年代初生产的汽车的测量值。
load carbig将 Origin 变量转换为分类变量。然后创建一个包含预测变量 Acceleration、Displacement 等以及响应变量 MPG 的表。每行包含单辆汽车的测量值。删除表中具有缺失值的行。
Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
Model_Year,Origin,Weight,MPG);
Tbl = rmmissing(Tbl);对数据进行分区
将数据分成训练集、验证集和测试集。首先,保留大约三分之一的观测值用于测试集。然后,将其余数据分成两半分别用于创建训练集和验证集。
rng("default") % For reproducibility of the data partitions cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3); testTbl = Tbl(test(cvp1),:); remainingTbl = Tbl(training(cvp1),:); cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2); validationTbl = remainingTbl(test(cvp2),:); trainTbl = remainingTbl(training(cvp2),:);
训练神经网络
使用训练集训练回归神经网络模型。指定 tblTrain 的 MPG 列作为响应变量,并对数值预测变量进行标准化。使用验证集在每次迭代时评估模型。通过使用 Verbose 名称-值参量来指定在每次迭代时显示训练信息。默认情况下,如果验证损失连续六次大于或等于迄今为止计算的最小验证损失,则训练过程会提前结束。要更改允许验证损失大于或等于最小值的次数,请指定 ValidationPatience 名称-值参量。
Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ... "ValidationData",validationTbl, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 102.962345| 46.853164| 6.700877| 0.032779| 115.730384| 0| | 2| 55.403995| 22.171181| 1.811805| 0.020571| 53.086379| 0| | 3| 37.588848| 11.135231| 0.782861| 0.005298| 38.580002| 0| | 4| 29.713458| 8.379231| 0.392009| 0.003921| 31.021379| 0| | 5| 17.523851| 9.958164| 2.137584| 0.003729| 17.594863| 0| | 6| 12.700624| 2.957771| 0.744551| 0.003962| 14.209019| 0| | 7| 11.841152| 1.907378| 0.201770| 0.003880| 13.159899| 0| | 8| 10.162988| 2.542555| 0.576907| 0.003956| 11.352490| 0| | 9| 8.889095| 2.779980| 0.615716| 0.002668| 10.446334| 0| | 10| 7.670335| 2.400272| 0.648711| 0.011382| 10.424337| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 7.416274| 0.505111| 0.214707| 0.005407| 10.522517| 1| | 12| 7.338923| 0.880655| 0.119085| 0.004414| 10.648031| 2| | 13| 7.149407| 1.784821| 0.277908| 0.002899| 10.800952| 3| | 14| 6.866385| 1.904480| 0.472190| 0.005637| 10.839202| 4| | 15| 6.815575| 3.339285| 0.943063| 0.002956| 10.031692| 0| | 16| 6.428137| 0.684771| 0.133729| 0.003287| 9.867819| 0| | 17| 6.363299| 0.456606| 0.125363| 0.006535| 9.720076| 0| | 18| 6.289887| 0.742923| 0.152290| 0.009971| 9.576588| 0| | 19| 6.215407| 0.964684| 0.183503| 0.002971| 9.422910| 0| | 20| 6.078333| 2.124971| 0.566948| 0.002843| 9.599573| 1| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 5.947923| 1.217291| 0.583867| 0.003745| 9.618400| 2| | 22| 5.855505| 0.671774| 0.285123| 0.002729| 9.734680| 3| | 23| 5.831802| 1.882061| 0.657368| 0.001770| 10.365968| 4| | 24| 5.713261| 1.004072| 0.134719| 0.001882| 10.314258| 5| | 25| 5.520766| 0.967032| 0.290156| 0.001891| 10.177322| 6| |==========================================================================================|
使用 Mdl 对象的 TrainingHistory 属性内部的信息来检查对应于最小验证均方误差 (MSE) 的迭代。最终返回的模型 Mdl 是在该迭代中训练的模型。
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 19
评估测试集性能
通过使用 loss 和 predict 对象函数,评估经过训练的模型 Mdl 在测试集 testTbl 上的性能。
计算测试集均方误差 (MSE)。MSE 值越小,表示性能越好。
mse = loss(Mdl,testTbl,"MPG")mse = 7.4101
比较预测的测试集响应值与真实响应值。沿垂直轴绘制预测的每加仑英里数 (MPG),沿水平轴绘制真实的 MPG。参考线上的点表示正确的预测值。好的模型产生的预测值会散布在参考线附近。
predictedY = predict(Mdl,testTbl); plot(testTbl.MPG,predictedY,".") hold on plot(testTbl.MPG,testTbl.MPG) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("Predicted Miles Per Gallon (MPG)")

使用箱线图按原产国/地区比较预测的 MPG 值和真实 MPG 值的分布。使用 boxchart 函数创建箱线图。每个箱线图会显示中位数、下四分位数和上四分位数、任何离群值(使用四分位差计算),以及不是离群值的最小值和最大值。特别是,每个箱内的线表示样本中位数,圆形标记表示离群值。
对于每个原产国/地区,比较红色箱线图(显示预测的 MPG 值分布)与蓝色箱线图(显示真实的 MPG 值分布)。预测的 MPG 值和真实的 MPG 值的相似分布表示好的预测值。
boxchart(testTbl.Origin,testTbl.MPG) hold on boxchart(testTbl.Origin,predictedY) hold off legend(["True MPG","Predicted MPG"]) xlabel("Country of Origin") ylabel("Miles Per Gallon (MPG)")

对于大多数国家/地区,预测的 MPG 值和真实的 MPG 值都具有相似的分布值。有些差异可能是由于训练集和测试集中汽车数量较少。
比较训练集和测试集中汽车的 MPG 值范围。
trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ... "range")
trainSummary=6×3 table
Origin GroupCount range_MPG
_______ __________ _________
France France 2 1.2
Germany Germany 12 23.4
Italy Italy 1 0
Japan Japan 26 26.6
Sweden Sweden 4 8
USA USA 86 27
testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ... "range")
testSummary=6×3 table
Origin GroupCount range_MPG
_______ __________ _________
France France 4 19.8
Germany Germany 13 20.3
Italy Italy 4 11.3
Japan Japan 26 25.6
Sweden Sweden 1 0
USA USA 82 29
对于像法国、意大利和瑞典这样在训练集和测试集中的汽车数量很少的国家,MPG 值的范围在训练集和测试集中的变化很大。
绘制测试集残差值。好的模型通常具有大致对称地散布在 0 周围的残差值。残差中的明显模式意味着您可能需要改进您的模型。
residuals = testTbl.MPG - predictedY; plot(testTbl.MPG,residuals,".") hold on yline(0) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("MPG Residuals")

该图显示残差分布良好。
您可以获取关于具有最大残差值(以绝对值计)的观测值的详细信息。
[~,residualIdx] = sort(residuals,"descend", ... "ComparisonMethod","abs"); residuals(residualIdx)
ans = 130×1
-8.8469
8.4427
8.0493
7.8996
-6.2220
5.8589
5.7007
-5.6733
-5.4545
5.1899
-4.9175
-4.8600
4.5415
-4.3959
-4.3915
⋮
显示具有最大残差值的三个观测值,即幅值大于 8 的观测值。
testTbl(residualIdx(1:3),:)
ans=3×7 table
Acceleration Displacement Horsepower Model_Year Origin Weight MPG
____________ ____________ __________ __________ ______ ______ ____
17.6 91 68 82 Japan 1970 31
11.4 168 132 80 Japan 2910 32.7
13.8 91 67 80 Japan 1850 44.6
另请参阅
fitrnet | loss | predict | RegressionNeuralNetwork | boxchart