CART:当杂质减少的增益相等时,选择用于分裂的最佳预测因子?

机器算法验证 r 机器学习 分类 数据挖掘 大车
2022-03-01 17:26:53

我的问题涉及分类树。考虑 Iris 数据集中的以下示例:

在此处输入图像描述

我想为第一次拆分手动选择最佳预测器。根据 CART 算法,进行拆分的最佳特征是最大化减少分区杂质的特征,也称为基尼增益:

GiniGain(N,X)=Gini(N)|N1||N|Gini(N1)|N2||N|Gini(N1)

其中是给定特征,是要在其上进行拆分的节点,创建的两个子节点是节点中元素的数量。XNN1N2N|.|

,其中是节点中的类别数Gini(N)=1k=1Kpk2K

现在,由于根据花瓣宽度(轴#1)和花瓣长度(轴#2)进行分割会产生相同的分区(所有 Setosa 花都与非 Setosa 分开),每个的 GinGain 分数将完全相同预测器。那么 CART 算法如何决定哪一个是最好的呢?

直观地,可以看到花瓣长度(2) 的分割与最大的“边距”相关,因此应该选择花瓣长度(实际上是rpart在 R 中实现时发生的情况),但中没有任何东西测量边距,所以决定必须基于别的东西。GiniGain

相关线程,但没有我的问题的答案。

相关线程没有任何答案。

1个回答

我承认自己是一个平庸的 c 代码解释器,而且这个旧代码不是用户友好的。也就是说,我浏览了源代码并进行了这些观察,这让我非常肯定地说:“rpart 从字面上选择了第一个也是最好的变量列”。由于第 1 列和第 2 列产生较差的拆分,因此petal.length 将是第一个拆分变量,因为此列在 data.frame/matrix 中的 petal.width 之前。最后,我通过反转列顺序来展示这一点,这样花瓣.with 将是第一个拆分变量。

rpart 源代码的 c 源文件“bsplit.c”中,我从第 38 行引用:

 * test out the variables 1 at at time
me->primary = (pSplit) NULL;
for (i = 0; i < rp.nvar; i++) {

...因此在从 i=1 到 rp.nvar 的 for 循环中迭代,将调用一个损失函数来扫描由一个变量分割的所有内容,在 gini.c 中查找“非分类分割”第 230 行,找到的最佳分割是如果新拆分更好,则更新。(这也可以是用户定义的损失函数)

if (temp < best) {
        best = temp;
        where = i;
        direction = lmean < rmean ? LEFT : RIGHT;
}

最后一行 323,计算了通过变量进行最佳分割的改进......

*improve = total_ss - best

...回到 bsplit.c 中,检查改进是否大于之前看到的,并且仅在更大时更新。

if (improve > rp.iscale)
rp.iscale = improve;        /* largest seen so far */

我对此的印象是,将选择第一个和最好的(可能的平局),因为只有新断点的得分更高,它才会被保存。这涉及找到的第一个最佳断点和找到的第一个最佳变量。在 gini.c 中似乎不是从左到右简单地扫描断点,因此第一个发现的绑定断点可能很难预测。但是变量从第一列扫描到最后一列是非常可预测的。

此行为与在 classTree.c 中使用以下解决方案的randomForest 实现不同:

/* Break ties at random: */
if (crit == critmax) {
    if (unif_rand() < 1.0 / ntie) {
        *bestSplit = j;
        critmax = crit;
        *splitVar = mvar;
    }
    ntie++;
}

最后,我通过翻转 iris 的列来确认这种行为,这样首先选择petal.width

library(rpart)
data(iris)
iris = iris[,5:1]  #flip/flop", invert order of columns columns
obj = rpart(Species~.,data=iris)
print(obj) #now petal width is first split 


1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
  2) Petal.Width< 0.8 50   0 setosa (1.00000000 0.00000000 0.00000000) *
  3) Petal.Width>=0.8 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
    6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
    7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *

...然后再次翻转

iris = iris[,5:1]  #flop/flip", revert order of columns columns
obj = rpart(Species~.,data=iris)
print(obj) #now petal length is first split 
1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
  2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
  3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
    6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
    7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *