短篇故事:
我有一个由一些特征提取器和一个 LDA 分类器组成的分类管道。在交叉验证中评估管道时,我得到了 94% 的不错的测试准确率(对于 19 个类)。然而,在评估不同数量的训练数据的测试准确度时,我得到了奇怪的结果:测试准确度总体上随着训练数据的增加而增加(这是预期的),但对于一定数量的训练数据,准确度会崩溃(见图):
该图显示了管道的测试准确度(y 轴)与用于训练的每个类的样本数(x 轴)。对于 19 个类中的每一类,有 50 个训练样本。对所有未用于训练的样本进行了测试。
10-11 的负峰值对我来说完全没有意义。谁能解释一下?
长篇大论:
为了确保这不是随机效应,我针对训练样本的不同随机排列运行了 40 次测试。这是图中显示的方差。此外,我仅使用特征的子集重复了测试。结果看起来一样,只是负峰的位置不同,见下图:
我注意到峰值的位置随着特征的数量而变化,所以我再次对 11 个不同数量的特征进行了测试。结果如下图: (x轴与上面相同。y轴是所选特征集的大小/16,因此例如5.0表示80个特征颜色是测试精度)
这表明峰的位置与特征数量之间存在直接关系。
进一步的几点:
- 我只有 19 个课程有很多功能(176 个)。当然这不是最佳的,但我认为它并不能解释那个峰值。
- 一些变量是共线的(实际上很多是共线的)。同样,不完美但不能解释峰值。
编辑2:
这是一些数据。它使用以下设置计算:
- K = 19(班级数)
- p = 192(特征数)
- n = 2 到 19(每类的训练样本数,训练/测试拆分是分层随机拆分)
输出是:
- n:每类样本数
- 测试数据:输入 lda 的测试数据的形状(样本数 x 特征数)
- 训练数据:输入 lda 的训练数据的形状(样本数 x 特征数)
- 测试和训练的准确性
- rank:通过在https://github.com/scikit-learn/scikit-learn/blob/a95203b/sklearn/lda.py#L369
rank
中定义的打印获得
->
n test data train data test acc train acc rank
2 (912, 192) (38, 192) 0.626 0.947 19
3 (893, 192) (57, 192) 0.775 1.000 38
4 (874, 192) (76, 192) 0.783 1.000 57
5 (855, 192) (95, 192) 0.752 1.000 76
6 (836, 192) (114, 192) 0.811 1.000 95
7 (817, 192) (133, 192) 0.760 1.000 114
8 (798, 192) (152, 192) 0.786 1.000 133
9 (779, 192) (171, 192) 0.730 1.000 152
10 (760, 192) (190, 192) 0.532 1.000 171
11 (741, 192) (209, 192) 0.702 1.000 176
12 (722, 192) (228, 192) 0.727 1.000 176
13 (703, 192) (247, 192) 0.856 1.000 176
14 (684, 192) (266, 192) 0.857 1.000 176
15 (665, 192) (285, 192) 0.887 1.000 176
16 (646, 192) (304, 192) 0.881 1.000 176
17 (627, 192) (323, 192) 0.896 1.000 176
18 (608, 192) (342, 192) 0.913 1.000 176
19 (589, 192) (361, 192) 0.900 1.000 176
20 (570, 192) (380, 192) 0.916 1.000 176
21 (551, 192) (399, 192) 0.907 1.000 176
22 (532, 192) (418, 192) 0.929 0.995 176
23 (513, 192) (437, 192) 0.916 0.995 176
24 (494, 192) (456, 192) 0.909 0.991 176
25 (475, 192) (475, 192) 0.947 0.992 176
26 (456, 192) (494, 192) 0.928 0.992 176
27 (437, 192) (513, 192) 0.927 0.992 176
28 (418, 192) (532, 192) 0.940 0.992 176
29 (399, 192) (551, 192) 0.952 0.991 176
30 (380, 192) (570, 192) 0.934 0.989 176
31 (361, 192) (589, 192) 0.922 0.992 176
32 (342, 192) (608, 192) 0.930 0.990 176
33 (323, 192) (627, 192) 0.929 0.989 176
34 (304, 192) (646, 192) 0.947 0.986 176
35 (285, 192) (665, 192) 0.940 0.986 176
36 (266, 192) (684, 192) 0.940 0.993 176
37 (247, 192) (703, 192) 0.935 0.989 176
38 (228, 192) (722, 192) 0.939 0.985 176
39 (209, 192) (741, 192) 0.923 0.988 176