Scikit-learn 中的 One-hot 与虚拟编码

机器算法验证 回归 分类数据 数据转换 scikit-学习 数据预处理
2022-01-15 01:27:40

编码分类变量有两种不同的方法。比如说,一个分类变量有n 个值。One-hot 编码将其转换为n 个变量,而虚拟编码将其转换为n-1 个变量。如果我们有k个分类变量,每个变量都有n 个值。一种热编码以kn个变量结束,而虚拟编码以kn-k 个变量结束。

我听说对于 one-hot 编码,截取会导致共线性问题,这使得模型不健全。有人称之为“虚拟变量陷阱”。

我的问题:

  1. Scikit-learn 的线性回归模型允许用户禁用拦截。所以对于 one-hot 编码,我应该总是设置 fit_intercept=False 吗?对于虚拟编码,fit_intercept 应始终设置为 True?我在网站上没有看到任何“警告”。

  2. 既然 one-hot 编码会产生更多的变量,那么它是否比虚拟编码具有更大的自由度?

3个回答

Scikit-learn 的线性回归模型允许用户禁用拦截。所以对于 one-hot 编码,我应该总是设置 fit_intercept=False 吗?对于虚拟编码,fit_intercept 应始终设置为 True?我在网站上没有看到任何“警告”。

对于具有 one-hot 编码的非正则化线性模型,是的,您需要将截距设置为 false,否则会产生完美的共线性。 sklearn还允许脊收缩惩罚,在这种情况下没有必要,事实上你应该包括截距和所有水平。对于虚拟编码,您应该包含一个截距,除非您已经标准化了所有变量,在这种情况下截距为零。

既然 one-hot 编码会产生更多的变量,那么它是否比虚拟编码具有更大的自由度?

截距是一个额外的自由度,因此在一个明确指定的模型中,它都等于。

对于第二个,如果有k个分类变量怎么办?k 变量在虚拟编码中被删除。自由度还是一样吗?

您无法拟合使用两个分类变量的所有级别的模型,无论是否截距。因为,一旦您对模型中一个变量中的所有级别进行单热编码,例如使用二进制变量,那么您就有一个等于常数向量的预测变量的线性组合x1,x2,,xn

x1+x2++xn=1

如果您然后尝试将另一个分类的所有级别输入到模型中,您最终会得到一个不同的线性组合,该组合等于一个常数向量x

x1+x2++xk=1

所以你创建了一个线性依赖

x1+x2+xnx1x2xk=0

因此,您必须在第二个变量中省略一个级别,并且所有内容都正确排列。

比如说,我有 3 个分类变量,每个变量有 4 个级别。在虚拟编码中,3*4-3=9 个变量由一个截距构建。在 one-hot 编码中,3*4=12 个变量是在没有截距的情况下构建的。我对么?

第二件事实际上不起作用。您创建的列设计矩阵将是单数的。您需要从三个不同的分类编码中删除三列,以恢复设计的非奇异性。3×4=12

在@MatthewDrury 关于这个问题的回答中添加一点:

比如说,我有 3 个分类变量,每个变量有 4 个级别。在虚拟编码中,3*4-3=9 个变量由一个截距构建。在 one-hot 编码中,3*4=12 个变量是在没有截距的情况下构建的。我对么?

model.matrix我们可以通过使用from R来检查设计矩阵在有和没有截距的情况下是什么样子的。

拦截:

> df <- expand.grid(w = letters[1:4], x = letters[5:8], y = letters[9:12])
> model.matrix(~ w + x + y, df)
   (Intercept) wb wc wd xf xg xh yj yk yl
1            1  0  0  0  0  0  0  0  0  0
2            1  1  0  0  0  0  0  0  0  0
3            1  0  1  0  0  0  0  0  0  0
4            1  0  0  1  0  0  0  0  0  0
5            1  0  0  0  1  0  0  0  0  0
6            1  1  0  0  1  0  0  0  0  0
7            1  0  1  0  1  0  0  0  0  0
8            1  0  0  1  1  0  0  0  0  0
9            1  0  0  0  0  1  0  0  0  0
10           1  1  0  0  0  1  0  0  0  0
11           1  0  1  0  0  1  0  0  0  0
12           1  0  0  1  0  1  0  0  0  0
13           1  0  0  0  0  0  1  0  0  0
14           1  1  0  0  0  0  1  0  0  0
15           1  0  1  0  0  0  1  0  0  0
16           1  0  0  1  0  0  1  0  0  0
17           1  0  0  0  0  0  0  1  0  0
18           1  1  0  0  0  0  0  1  0  0
19           1  0  1  0  0  0  0  1  0  0
20           1  0  0  1  0  0  0  1  0  0
21           1  0  0  0  1  0  0  1  0  0
22           1  1  0  0  1  0  0  1  0  0
23           1  0  1  0  1  0  0  1  0  0
24           1  0  0  1  1  0  0  1  0  0
25           1  0  0  0  0  1  0  1  0  0
26           1  1  0  0  0  1  0  1  0  0
27           1  0  1  0  0  1  0  1  0  0
28           1  0  0  1  0  1  0  1  0  0
29           1  0  0  0  0  0  1  1  0  0
30           1  1  0  0  0  0  1  1  0  0
31           1  0  1  0  0  0  1  1  0  0
32           1  0  0  1  0  0  1  1  0  0
33           1  0  0  0  0  0  0  0  1  0
34           1  1  0  0  0  0  0  0  1  0
35           1  0  1  0  0  0  0  0  1  0
36           1  0  0  1  0  0  0  0  1  0
37           1  0  0  0  1  0  0  0  1  0
38           1  1  0  0  1  0  0  0  1  0
39           1  0  1  0  1  0  0  0  1  0
40           1  0  0  1  1  0  0  0  1  0
41           1  0  0  0  0  1  0  0  1  0
42           1  1  0  0  0  1  0  0  1  0
43           1  0  1  0  0  1  0  0  1  0
44           1  0  0  1  0  1  0  0  1  0
45           1  0  0  0  0  0  1  0  1  0
46           1  1  0  0  0  0  1  0  1  0
47           1  0  1  0  0  0  1  0  1  0
48           1  0  0  1  0  0  1  0  1  0
49           1  0  0  0  0  0  0  0  0  1
50           1  1  0  0  0  0  0  0  0  1
51           1  0  1  0  0  0  0  0  0  1
52           1  0  0  1  0  0  0  0  0  1
53           1  0  0  0  1  0  0  0  0  1
54           1  1  0  0  1  0  0  0  0  1
55           1  0  1  0  1  0  0  0  0  1
56           1  0  0  1  1  0  0  0  0  1
57           1  0  0  0  0  1  0  0  0  1
58           1  1  0  0  0  1  0  0  0  1
59           1  0  1  0  0  1  0  0  0  1
60           1  0  0  1  0  1  0  0  0  1
61           1  0  0  0  0  0  1  0  0  1
62           1  1  0  0  0  0  1  0  0  1
63           1  0  1  0  0  0  1  0  0  1
64           1  0  0  1  0  0  1  0  0  1

没有拦截:

> model.matrix(~ w + x + y - 1, df)
   wa wb wc wd xf xg xh yj yk yl
1   1  0  0  0  0  0  0  0  0  0
2   0  1  0  0  0  0  0  0  0  0
3   0  0  1  0  0  0  0  0  0  0
4   0  0  0  1  0  0  0  0  0  0
5   1  0  0  0  1  0  0  0  0  0
6   0  1  0  0  1  0  0  0  0  0
7   0  0  1  0  1  0  0  0  0  0
8   0  0  0  1  1  0  0  0  0  0
9   1  0  0  0  0  1  0  0  0  0
10  0  1  0  0  0  1  0  0  0  0
11  0  0  1  0  0  1  0  0  0  0
12  0  0  0  1  0  1  0  0  0  0
13  1  0  0  0  0  0  1  0  0  0
14  0  1  0  0  0  0  1  0  0  0
15  0  0  1  0  0  0  1  0  0  0
16  0  0  0  1  0  0  1  0  0  0
17  1  0  0  0  0  0  0  1  0  0
18  0  1  0  0  0  0  0  1  0  0
19  0  0  1  0  0  0  0  1  0  0
20  0  0  0  1  0  0  0  1  0  0
21  1  0  0  0  1  0  0  1  0  0
22  0  1  0  0  1  0  0  1  0  0
23  0  0  1  0  1  0  0  1  0  0
24  0  0  0  1  1  0  0  1  0  0
25  1  0  0  0  0  1  0  1  0  0
26  0  1  0  0  0  1  0  1  0  0
27  0  0  1  0  0  1  0  1  0  0
28  0  0  0  1  0  1  0  1  0  0
29  1  0  0  0  0  0  1  1  0  0
30  0  1  0  0  0  0  1  1  0  0
31  0  0  1  0  0  0  1  1  0  0
32  0  0  0  1  0  0  1  1  0  0
33  1  0  0  0  0  0  0  0  1  0
34  0  1  0  0  0  0  0  0  1  0
35  0  0  1  0  0  0  0  0  1  0
36  0  0  0  1  0  0  0  0  1  0
37  1  0  0  0  1  0  0  0  1  0
38  0  1  0  0  1  0  0  0  1  0
39  0  0  1  0  1  0  0  0  1  0
40  0  0  0  1  1  0  0  0  1  0
41  1  0  0  0  0  1  0  0  1  0
42  0  1  0  0  0  1  0  0  1  0
43  0  0  1  0  0  1  0  0  1  0
44  0  0  0  1  0  1  0  0  1  0
45  1  0  0  0  0  0  1  0  1  0
46  0  1  0  0  0  0  1  0  1  0
47  0  0  1  0  0  0  1  0  1  0
48  0  0  0  1  0  0  1  0  1  0
49  1  0  0  0  0  0  0  0  0  1
50  0  1  0  0  0  0  0  0  0  1
51  0  0  1  0  0  0  0  0  0  1
52  0  0  0  1  0  0  0  0  0  1
53  1  0  0  0  1  0  0  0  0  1
54  0  1  0  0  1  0  0  0  0  1
55  0  0  1  0  1  0  0  0  0  1
56  0  0  0  1  1  0  0  0  0  1
57  1  0  0  0  0  1  0  0  0  1
58  0  1  0  0  0  1  0  0  0  1
59  0  0  1  0  0  1  0  0  0  1
60  0  0  0  1  0  1  0  0  0  1
61  1  0  0  0  0  0  1  0  0  1
62  0  1  0  0  0  0  1  0  0  1
63  0  0  1  0  0  0  1  0  0  1
64  0  0  0  1  0  0  1  0  0  1

我们可以看到,当我们使用截距时,对每个变量、 、model.matrix使用虚拟编码,并变成 3 个虚拟变量,加上一个截距列。所以一共有10个自由度。wxy

当我们不使用截距时,model.matrix为 and 创建 4 个虚拟变量w和 3 个虚拟变量(x并且y没有截距列)。所以自由度的数量仍然是 10。

我完全同意@Matthew Drury 和@Cameron Bieganek 对完美共线性和自由度的分析。

但是,我想在这里争辩说,如果我们使用诸如梯度下降之类的方法作为我们的优化器,我们不需要避免完美的共线性

我们可能想要避免完美共线性的原因是当我们使用线性回归并且我们的损失函数是MSE时,我们可以解决封闭形式的解决方案,这涉及到关于的矩阵的逆,而完美的共线性会使这个矩阵非可逆的。然而,在实践中,由于逆矩阵的计算非常昂贵,我们可以使用其他更快的方法,例如梯度下降来计算近似解,这个过程不涉及矩阵的逆. 因此,这里可以容忍完美的共线性。(也许这就是他们没有警告的原因)所以,我们可以使用:XO(n3)

  1. 一个热拦截或
  2. 一个没有拦截的热或
  3. 带拦截的假人。

他们会产生非常相似的结果。

我使用带有 sklearn 的 mpg dateset 运行回归,将一个 hot 或 dummy 应用于具有三个类别的“origin”特征,结果非常相似:(请注意残差和 coef 彼此相似,我枚举类别词前面用红色矩形标记的参数,其他是一些连续特征的参数。详细代码可以看这里,抱歉没有注释,大部分代码来自tensorflow教程。) 一个热回归结果 虚拟回归结果

三个结果之间的关系:

  1. 顺便说一句,我们还可以观察到,对于一热没有截距的情况,最后三个一热特征前面的参数实际上等于截距的基本和,最后三个一热特征前面的参数一热截距,可以用完美共线性来解释

  2. 我们还可以注意到,在 dummy with intercept 情况下,intercept 实际上是 one-hot without intercept 情况下第二个分类编码项的参数(-13.71778...)。而虚拟情况下分类编码项的两个参数是对应参数与第二项参数之间的差异,这与计量经济学中对分类项之前的参数的解释是一致的:其他每个类别的差异有多大将输出与基本类别进行比较。