为什么我的 LDA 性能是训练数据量的非单调函数?

机器算法验证 主成分分析 scikit-学习 过拟合 判别分析 svd
2022-03-29 00:37:13

短篇故事:

我有一个由一些特征提取器和一个 LDA 分类器组成的分类管道。在交叉验证中评估管道时,我得到了 94% 的不错的测试准确率(对于 19 个类)。然而,在评估不同数量的训练数据的测试准确度时,我得到了奇怪的结果:测试准确度总体上随着训练数据的增加而增加(这是预期的),但对于一定数量的训练数据,准确度会崩溃(见图):

在此处输入图像描述

该图显示了管道的测试准确度(y 轴)与用于训练的每个类的样本数(x 轴)。对于 19 个类中的每一类,有 50 个训练样本。对所有未用于训练的样本进行了测试。

10-11 的负峰值对我来说完全没有意义。谁能解释一下?

长篇大论:

为了确保这不是随机效应,我针对训练样本的不同随机排列运行了 40 次测试。这是图中显示的方差。此外,我仅使用特征的子集重复了测试。结果看起来一样,只是负峰的位置不同,见下图:

在此处输入图像描述 (仅使用 176 个特征中的 110 个)

在此处输入图像描述 (仅使用其他 66 个功能)

我注意到峰值的位置随着特征的数量而变化,所以我再次对 11 个不同数量的特征进行了测试。结果如下图: 在此处输入图像描述 (x轴与上面相同。y轴是所选特征集的大小/16,因此例如5.0表示80个特征颜色是测试精度)

这表明峰的位置与特征数量之间存在直接关系。

进一步的几点:

  • 我只有 19 个课程有很多功能(176 个)。当然这不是最佳的,但我认为它并不能解释那个峰值。
  • 一些变量是共线的(实际上很多是共线的)。同样,不完美但不能解释峰值。

编辑: 这是训练精度图: 在此处输入图像描述

编辑2:

这是一些数据。它使用以下设置计算:

  • K = 19(班级数)
  • p = 192(特征数)
  • n = 2 到 19(每类的训练样本数,训练/测试拆分是分层随机拆分)

输出是:

->

n   test data   train data  test acc    train acc   rank
 2  (912, 192)  (38, 192)        0.626       0.947  19
 3  (893, 192)  (57, 192)        0.775       1.000  38
 4  (874, 192)  (76, 192)        0.783       1.000  57
 5  (855, 192)  (95, 192)        0.752       1.000  76
 6  (836, 192)  (114, 192)       0.811       1.000  95
 7  (817, 192)  (133, 192)       0.760       1.000  114
 8  (798, 192)  (152, 192)       0.786       1.000  133
 9  (779, 192)  (171, 192)       0.730       1.000  152
10  (760, 192)  (190, 192)       0.532       1.000  171
11  (741, 192)  (209, 192)       0.702       1.000  176
12  (722, 192)  (228, 192)       0.727       1.000  176
13  (703, 192)  (247, 192)       0.856       1.000  176
14  (684, 192)  (266, 192)       0.857       1.000  176
15  (665, 192)  (285, 192)       0.887       1.000  176
16  (646, 192)  (304, 192)       0.881       1.000  176
17  (627, 192)  (323, 192)       0.896       1.000  176
18  (608, 192)  (342, 192)       0.913       1.000  176
19  (589, 192)  (361, 192)       0.900       1.000  176
20  (570, 192)  (380, 192)       0.916       1.000  176
21  (551, 192)  (399, 192)       0.907       1.000  176
22  (532, 192)  (418, 192)       0.929       0.995  176
23  (513, 192)  (437, 192)       0.916       0.995  176
24  (494, 192)  (456, 192)       0.909       0.991  176
25  (475, 192)  (475, 192)       0.947       0.992  176
26  (456, 192)  (494, 192)       0.928       0.992  176
27  (437, 192)  (513, 192)       0.927       0.992  176
28  (418, 192)  (532, 192)       0.940       0.992  176
29  (399, 192)  (551, 192)       0.952       0.991  176
30  (380, 192)  (570, 192)       0.934       0.989  176
31  (361, 192)  (589, 192)       0.922       0.992  176
32  (342, 192)  (608, 192)       0.930       0.990  176
33  (323, 192)  (627, 192)       0.929       0.989  176
34  (304, 192)  (646, 192)       0.947       0.986  176
35  (285, 192)  (665, 192)       0.940       0.986  176
36  (266, 192)  (684, 192)       0.940       0.993  176
37  (247, 192)  (703, 192)       0.935       0.989  176
38  (228, 192)  (722, 192)       0.939       0.985  176
39  (209, 192)  (741, 192)       0.923       0.988  176
1个回答

你发现了一个有趣的现象。

LDA 计算依赖于反转类内散布矩阵SW. 通常 LDA 解表示为特征值分解SW1SB,但从scikit-learn不显式计算散布矩阵,而是使用数据矩阵的 SVD 来计算相同的东西这类似于通常通过 SVD 直接计算PCAX无需计算协方差矩阵。什么scikit-learn是计算类间数据的 SVDXB转化为SW1/2(“关于类内协方差变白”)。并计算SW1/2,他们对类内数据进行 SVDXW.

在里面n<p情况的协方差矩阵XW不是满秩,有一些零特征值,不能倒置。在这种情况下发生的scikit情况是,它们仅使用非零奇异值进行反转(github链接)。

换句话说,他们隐式地对类内数据进行 PCA,只保留非零 PC,然后对其进行 LDA。

现在的问题是,我们应该如何期望这会影响过度拟合?让我们考虑与您的问题相同的设置(但向后),当总样本量N从开始减少Np一直到Np. n是每类的样本量,所以N=nK在哪里K是类的数量。

  • 在大样本量的限制下,PCA 步骤没有影响(所有 PC 都使用),过拟合减少到零,样本外(例如交叉验证)性能应该是最好的。

  • 为了Np,协方差矩阵已经满秩(因此 PCA 步骤没有效果),但最小的特征值非常嘈杂,LDA 会严重过拟合。性能几乎可以降到零。

  • 为了N<p, PCA 步骤变得至关重要。仅有的NKPC 非零;所以维度减少到NK. 现在会发生什么取决于一些领先的 ​​PC 是否具有良好的区分能力。他们不必这样做,但他们经常这样做

    • 如果是这样,那么只使用几台领先的 PC 应该可以很好地工作。PCA用作正则化步骤并提高性能。

      当然,这里 PCA 不执行降维,因为保留了所有可用的组件。因此,目前尚不清楚它是否会提高性能,但正如我们所见,至少在这种情况下确实如此。

    • 然而,如果n太小了,一些重要的PC无法估计并被遗漏,那么性能应该再次下降。

我认为我在文献中的任何地方都没有看到过这个讨论,但这是我对这条奇怪曲线的理解:

在此处输入图像描述

请注意,这有点像scikit-learn(使用svd求解器)如何处理N<p情况。

更新:我们可以预测最小值的位置如下。每个类中的类内协方差矩阵最多具有秩n1,因此汇集的类内协方差最多具有秩(n1)K. 最小值应该在达到满秩时出现(并且 PCA 停止产生任何影响),即对于最小的n这样(n1)K>p

nmin=p+KK+1.
这似乎非常适合您的所有数字。