R中随机森林分类中一组预测变量的相对重要性

机器算法验证 r 机器学习 分类 随机森林
2022-02-07 19:31:09

我想确定变量集对randomForestR 中的分类模型的相对重要性。该importance函数MeanDecreaseGini为每个单独的预测变量提供了度量——它就像在一组中的每个预测变量中求和一样简单吗?

例如:

# Assumes df has variables a1, a2, b1, b2, and outcome
rf <- randomForest(outcome ~ ., data=df)
importance(rf)
# To determine whether the "a" predictors are more important than the "b"s,
# can I sum the MeanDecreaseGini for a1 and a2 and compare to that of b1+b2?
2个回答

首先,我想澄清重要性指标实际衡量的是什么。

MeanDecreaseGini是基于 Gini 杂质指数的变量重要性度量,用于在训练期间计算拆分。一个常见的误解是变量重要性度量是指用于断言模型性能的基尼系数,它与 AUC 密切相关,但这是错误的。这是 Breiman 和 Cutler 编写的 randomForest 包的解释:

Gini 重要性
每次对变量 m 进行节点拆分时,两个后代节点的 gini 杂质标准都小于父节点。将森林中所有树木的每个单独变量的 gini 减少相加得出一个快速变量重要性,这通常与排列重要性度量非常一致。

基尼杂质指数定义为 G 其中是目标变量中的类数,是该类的比率。

G=i=1ncpi(1pi)
ncpi

对于两类问题,这会产生以下曲线,该曲线对于 50-50 样本最大化,对于同质集最小化: 2 级的基尼杂质

然后将重要性计算为 对涉及相关预测器的森林中的所有拆分进行平均。由于这是一个平均值,它可以很容易地扩展到对组中包含的变量的所有拆分进行平均。

I=GparentGsplit1Gsplit2

仔细观察,我们知道每个变量的重要性是所使用变量的平均条件,并且该组的 meanDecreaseGini 将只是这些重要性的平均值,加权该变量在森林中使用的份额与同一组中的其他变量相比。这是因为塔的属性

E[E[X|Y]]=E[X]

现在,要直接回答您的问题,它并不像简单地总结每个组中的所有重要性以获得组合的 MeanDecreaseGini 那样简单,而是计算加权平均值将为您提供您正在寻找的答案。我们只需要找到每个组内的可变频率。

这是一个从 R 中的随机森林对象中获取这些的简单脚本:

var.share <- function(rf.obj, members) {
  count <- table(rf.obj$forest$bestvar)[-1]
  names(count) <- names(rf.obj$forest$ncat)
  share <- count[members] / sum(count[members])
  return(share)
}

只需将组中变量的名称作为成员参数传入。

我希望这回答了你的问题。如果感兴趣,我可以编写一个函数来直接获取组的重要性。

编辑:
这是一个给定randomForest对象和具有变量名称的向量列表的组重要性的函数。var.share按照前面定义的方式使用。我没有进行任何输入检查,因此您需要确保使用正确的变量名。

group.importance <- function(rf.obj, groups) {
  var.imp <- as.matrix(sapply(groups, function(g) {
    sum(importance(rf.obj, 2)[g, ]*var.share(rf.obj, g))
  }))
  colnames(var.imp) <- "MeanDecreaseGini"
  return(var.imp)
}

使用示例:

library(randomForest)                                                          
data(iris)

rf.obj <- randomForest(Species ~ ., data=iris)

groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
               Petal=c("Petal.Width", "Petal.Length"))

group.importance(rf.obj, groups)

>

      MeanDecreaseGini
Sepal         6.187198
Petal        43.913020

它也适用于重叠组:

overlapping.groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
                           Petal=c("Petal.Width", "Petal.Length"),
                           Width=c("Sepal.Width", "Petal.Width"), 
                           Length=c("Sepal.Length", "Petal.Length"))

group.importance(rf.obj, overlapping.groups)

>

       MeanDecreaseGini
Sepal          6.187198
Petal         43.913020
Width          30.513776
Length        30.386706

上面定义为 G=sum over classes[pi(1−pi)] 的函数实际上是熵,这是评估拆分的另一种方式。子节点和父节点的熵之差就是信息增益。GINI 杂质函数是 G = 1- 类的总和[pi^2]。