class_weight 在决策树中如何工作

数据挖掘 scikit-学习 决策树 阶级失衡
2021-09-30 11:41:54

scikit-learn实现DecisionTreeClassifier有一个参数 as class_weight根据文档:

与 {class_label: weight} 形式的类关联的权重。如果没有给出,所有的类都应该有一个权重。

“平衡”模式使用 y 的值自动调整权重,与输入数据中的类频率成反比,如 n_samples / (n_classes * np.bincount(y))

我的理解是它应该用于不平衡类的情况。

问题: DT(分类)算法在确定给定节点的理想分割时如何使用此参数?它是否考虑了预测空间中某个区域的某种“加权”多数类而不是简单多数类?

1个回答

当决定一个节点的分裂时,该算法基本上为给定节点和分裂后的两个结果左右节点计算一些度量(熵或基尼杂质)。比较它们可以告诉您拆分会在多大程度上改善结果。

子节点的统计信息分别由左右节点中的样本数加权。

当您使用sample_weight它时,它会调整计数并将其替换为样本权重的总和。class_weight根据其类别比例,基于其类别为每个样本赋予相等的样本权重。

例如,杂质的改善计算为

NparentNtotal(impurityparentNrightNparentimpurityrightchildNleftNparentimpurityleftchild)

没有class_weightor sample_weights,则Ns 只是计数。class_weight相应的权重替换它们。

熵的想法是相同的,尽管计算方式不同。

源代码