主要内容

评估回归神经网络性能

使用 fitrnet 创建一个具有全连接层的前馈回归神经网络模型。针对训练过程的早停使用验证数据,以防止模型过拟合。然后,使用该模型的对象函数来评估它在测试数据上的性能。

加载样本数据

加载 carbig 数据集,该数据集包含 20 世纪 70 年代和 80 年代初生产的汽车的测量值。

load carbig

Origin 变量转换为分类变量。然后创建一个包含预测变量 AccelerationDisplacement 等以及响应变量 MPG 的表。每行包含单辆汽车的测量值。删除表中具有缺失值的行。

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);
Tbl = rmmissing(Tbl);

对数据进行分区

将数据分成训练集、验证集和测试集。首先,保留大约三分之一的观测值用于测试集。然后,将其余数据分成两半分别用于创建训练集和验证集。

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);

训练神经网络

使用训练集训练回归神经网络模型。指定 tblTrainMPG 列作为响应变量,并对数值预测变量进行标准化。使用验证集在每次迭代时评估模型。通过使用 Verbose 名称-值参量来指定在每次迭代时显示训练信息。默认情况下,如果验证损失连续六次大于或等于迄今为止计算的最小验证损失,则训练过程会提前结束。要更改允许验证损失大于或等于最小值的次数,请指定 ValidationPatience 名称-值参量。

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|  102.962345|   46.853164|    6.700877|    0.032779|  115.730384|           0|
|           2|   55.403995|   22.171181|    1.811805|    0.020571|   53.086379|           0|
|           3|   37.588848|   11.135231|    0.782861|    0.005298|   38.580002|           0|
|           4|   29.713458|    8.379231|    0.392009|    0.003921|   31.021379|           0|
|           5|   17.523851|    9.958164|    2.137584|    0.003729|   17.594863|           0|
|           6|   12.700624|    2.957771|    0.744551|    0.003962|   14.209019|           0|
|           7|   11.841152|    1.907378|    0.201770|    0.003880|   13.159899|           0|
|           8|   10.162988|    2.542555|    0.576907|    0.003956|   11.352490|           0|
|           9|    8.889095|    2.779980|    0.615716|    0.002668|   10.446334|           0|
|          10|    7.670335|    2.400272|    0.648711|    0.011382|   10.424337|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    7.416274|    0.505111|    0.214707|    0.005407|   10.522517|           1|
|          12|    7.338923|    0.880655|    0.119085|    0.004414|   10.648031|           2|
|          13|    7.149407|    1.784821|    0.277908|    0.002899|   10.800952|           3|
|          14|    6.866385|    1.904480|    0.472190|    0.005637|   10.839202|           4|
|          15|    6.815575|    3.339285|    0.943063|    0.002956|   10.031692|           0|
|          16|    6.428137|    0.684771|    0.133729|    0.003287|    9.867819|           0|
|          17|    6.363299|    0.456606|    0.125363|    0.006535|    9.720076|           0|
|          18|    6.289887|    0.742923|    0.152290|    0.009971|    9.576588|           0|
|          19|    6.215407|    0.964684|    0.183503|    0.002971|    9.422910|           0|
|          20|    6.078333|    2.124971|    0.566948|    0.002843|    9.599573|           1|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    5.947923|    1.217291|    0.583867|    0.003745|    9.618400|           2|
|          22|    5.855505|    0.671774|    0.285123|    0.002729|    9.734680|           3|
|          23|    5.831802|    1.882061|    0.657368|    0.001770|   10.365968|           4|
|          24|    5.713261|    1.004072|    0.134719|    0.001882|   10.314258|           5|
|          25|    5.520766|    0.967032|    0.290156|    0.001891|   10.177322|           6|
|==========================================================================================|

使用 Mdl 对象的 TrainingHistory 属性内部的信息来检查对应于最小验证均方误差 (MSE) 的迭代。最终返回的模型 Mdl 是在该迭代中训练的模型。

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 
19

评估测试集性能

通过使用 losspredict 对象函数,评估经过训练的模型 Mdl 在测试集 testTbl 上的性能。

计算测试集均方误差 (MSE)。MSE 值越小,表示性能越好。

mse = loss(Mdl,testTbl,"MPG")
mse = 
7.4101

比较预测的测试集响应值与真实响应值。沿垂直轴绘制预测的每加仑英里数 (MPG),沿水平轴绘制真实的 MPG。参考线上的点表示正确的预测值。好的模型产生的预测值会散布在参考线附近。

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")

Figure contains an axes object. The axes object with xlabel True Miles Per Gallon (MPG), ylabel Predicted Miles Per Gallon (MPG) contains 2 objects of type line. One or more of the lines displays its values using only markers

使用箱线图按原产国/地区比较预测的 MPG 值和真实 MPG 值的分布。使用 boxchart 函数创建箱线图。每个箱线图会显示中位数、下四分位数和上四分位数、任何离群值(使用四分位差计算),以及不是离群值的最小值和最大值。特别是,每个箱内的线表示样本中位数,圆形标记表示离群值。

对于每个原产国/地区,比较红色箱线图(显示预测的 MPG 值分布)与蓝色箱线图(显示真实的 MPG 值分布)。预测的 MPG 值和真实的 MPG 值的相似分布表示好的预测值。

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")

Figure contains an axes object. The axes object with xlabel Country of Origin, ylabel Miles Per Gallon (MPG) contains 2 objects of type boxchart. These objects represent True MPG, Predicted MPG.

对于大多数国家/地区,预测的 MPG 值和真实的 MPG 值都具有相似的分布值。有些差异可能是由于训练集和测试集中汽车数量较少。

比较训练集和测试集中汽车的 MPG 值范围。

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    "range")
trainSummary=6×3 table
               Origin     GroupCount    range_MPG
               _______    __________    _________

    France     France          2           1.2   
    Germany    Germany        12          23.4   
    Italy      Italy           1             0   
    Japan      Japan          26          26.6   
    Sweden     Sweden          4             8   
    USA        USA            86            27   

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    "range")
testSummary=6×3 table
               Origin     GroupCount    range_MPG
               _______    __________    _________

    France     France          4          19.8   
    Germany    Germany        13          20.3   
    Italy      Italy           4          11.3   
    Japan      Japan          26          25.6   
    Sweden     Sweden          1             0   
    USA        USA            82            29   

对于像法国、意大利和瑞典这样在训练集和测试集中的汽车数量很少的国家,MPG 值的范围在训练集和测试集中的变化很大。

绘制测试集残差值。好的模型通常具有大致对称地散布在 0 周围的残差值。残差中的明显模式意味着您可能需要改进您的模型。

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")

Figure contains an axes object. The axes object with xlabel True Miles Per Gallon (MPG), ylabel MPG Residuals contains 2 objects of type line, constantline. One or more of the lines displays its values using only markers

该图显示残差分布良好。

您可以获取关于具有最大残差值(以绝对值计)的观测值的详细信息。

[~,residualIdx] = sort(residuals,"descend", ...
    "ComparisonMethod","abs");
residuals(residualIdx)
ans = 130×1

   -8.8469
    8.4427
    8.0493
    7.8996
   -6.2220
    5.8589
    5.7007
   -5.6733
   -5.4545
    5.1899
   -4.9175
   -4.8600
    4.5415
   -4.3959
   -4.3915
      ⋮

显示具有最大残差值的三个观测值,即幅值大于 8 的观测值。

testTbl(residualIdx(1:3),:)
ans=3×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.6             91             68            82        Japan      1970       31
        11.4            168            132            80        Japan      2910     32.7
        13.8             91             67            80        Japan      1850     44.6

另请参阅

| | | |