有没有办法获取 rpart 模型中每个节点的基尼指数(Gini index)值?

数据挖掘 r 决策树
2022-02-06 19:01:27
# Two-row toy data set: one categorical predictor and a binary response.
df <- tibble(
  x = factor(c("A", "B")),
  y = factor(c(1, 0))
)

# Classification tree on all predictors; minsplit is lowered to 2 so that
# the two-observation root node is still eligible for a split.
model <- rpart(
  y ~ .,
  data = df,
  method = "class",
  control = rpart.control(minsplit = 2)
)

这里模型将有 1 个父节点和两个子节点。如何从 rpart 模型对象中获取这些节点的基尼指数值?

2个回答

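每个节点的基尼不纯度可以按 $1 - p_1^2 - p_2^2$ 计算。例如,如果节点 1 包含 40% 的“1”和 60% 的“0”,则 gini = 1 - 0.4^2 - 0.6^2。节点大小 `n` 和“0”的个数 `dev` 的信息存储在 `model$frame` 中(严格来说,对分类树而言 `dev` 是该节点中被误分类的样本数;在二分类情形下,由于公式关于两个类别对称,用它算出的基尼值是一样的)。因此,每个节点的基尼指数都可以用 `model$frame` 中的节点大小 `n` 和 `dev` 计算出来: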

# Per-node Gini impurity for a two-class tree: with p = dev / n the impurity
# is 1 - p^2 - (1 - p)^2, evaluated for every row of the frame at once.
frame <- model[["frame"]]
frame[["gini"]] <- with(frame, 1 - (dev / n)^2 - (1 - dev / n)^2)

frame[, c("var", "n", "dev", "gini")]
>      var  n dev      gini
> 1     x3 10   5 0.5000000
> 2 <leaf>  4   1 0.3750000
> 3 <leaf>  6   2 0.4444444

每次拆分的 Gini 改进是通过父节点和子节点之间的加权差来计算的。

# Gini improvement of each split: n * gini of the parent node minus
# n * gini of each of its two children.
# The row names of the frame are the node numbers; the children of node k
# are looked up as nodes 2k and 2k + 1.
node_id <- as.numeric(rownames(frame))
weighted_gini <- frame[["n"]] * frame[["gini"]]
left_child <- match(2 * node_id, node_id)
right_child <- match(2 * node_id + 1, node_id)

# Leaves have no split, so their improvement is NA.
frame[["improve"]] <- ifelse(
  frame[["var"]] == "<leaf>",
  NA_real_,
  weighted_gini - weighted_gini[left_child] - weighted_gini[right_child]
)

frame[, c("var", "n", "dev", "gini", "improve")]
>      var  n dev      gini   improve
> 1     x3 10   5 0.5000000 0.8333333
> 2 <leaf>  4   1 0.3750000        NA
> 3 <leaf>  6   2 0.4444444        NA

# Cross-check: compare with the `improve` column that rpart stores itself.
model[["splits"]]
>    count ncat   improve index  adj
> x3    10    2 0.8333333     1 0.00
> x2    10    2 0.2380952     2 0.00
> x2     0    2 0.7000000     3 0.25

以下代码可以为具有任意数量类别的 rpart 分类树计算每个节点的基尼指数:

#' Gini index of every node of an `rpart` classification tree
#'
#' Reads the per-node class summary stored in `tree$frame$yval2` and computes
#' `1 - sum(p^2)` from the class probabilities of each node. Works for any
#' number of classes.
#'
#' Each row of `yval2` is read as: column 1 = the node's class, the next
#' `nclass` columns = class counts, the following `nclass` columns = class
#' probabilities (`nclass` is the length of the `ylevels` attribute).
#'
#' @param tree A fitted classification tree. It must carry a `ylevels`
#'   attribute and a `frame` with `var` and `yval2` entries, and `labels()`
#'   must return one label per node.
#' @return A data frame with one row per node and the character columns
#'   `Name` (split variable and node label), `GiniIndex` (rounded to 5
#'   digits), `Class` (first column of `yval2`), `Items` (class counts) and
#'   `ItemProbs` (class probabilities).
gini <- function(tree) {
  frame <- tree[["frame"]]
  yval2 <- frame[["yval2"]]
  ylevels <- attributes(tree)[["ylevels"]]
  # Fail early with a clear message instead of a cryptic subscript error
  # when the tree carries no per-class information.
  if (is.null(yval2) || is.null(ylevels)) {
    stop(
      "`tree` must be a classification tree (rpart with method = \"class\")",
      call. = FALSE
    )
  }

  n_class <- length(ylevels)
  split_vars <- frame[["var"]]
  node_labels <- labels(tree)

  col_names <- c("Name", "GiniIndex", "Class", "Items", "ItemProbs")
  out <- data.frame(
    matrix(nrow = length(node_labels), ncol = length(col_names))
  )
  colnames(out) <- col_names

  # Column positions inside one row of yval2 (identical for every node).
  count_cols <- 1L + seq_len(n_class)
  prob_cols <- count_cols + n_class

  for (i in seq_along(split_vars)) {
    node <- yval2[i, ]
    class_probs <- node[prob_cols]
    impurity <- round(1 - sum(class_probs^2), 5)
    # c() coerces the whole row to character, so every column of the
    # result is a character column.
    out[i, ] <- c(
      paste(split_vars[i], " (", node_labels[i], ")"),
      impurity,
      node[1],
      toString(round(node[count_cols], 5)),
      toString(round(class_probs, 5))
    )
  }
  out
}


> df <- data.frame(x=factor(c("A", "B", "C", "C", "D")), y=factor(c(1, 2, 3, 3, 4)))
> model <- rpart(formula=y~., data=df, method="class", control=rpart.control(minsplit=2))
> gini(model)
             Name GiniIndex Class      Items                    ItemProbs
1     x  ( root )      0.72     3 1, 1, 2, 1           0.2, 0.2, 0.4, 0.2
2    x  ( x=abd )   0.66667     1 1, 1, 0, 1 0.33333, 0.33333, 0, 0.33333
3 <leaf>  ( x=a )         0     1 1, 0, 0, 0                   1, 0, 0, 0
4     x  ( x=bd )       0.5     2 0, 1, 0, 1               0, 0.5, 0, 0.5
5 <leaf>  ( x=b )         0     2 0, 1, 0, 0                   0, 1, 0, 0
6 <leaf>  ( x=d )         0     4 0, 0, 0, 1                   0, 0, 0, 1
7 <leaf>  ( x=c )         0     3 0, 0, 2, 0                   0, 0, 1, 0


# don't know how to publish plots on StackExchange:
# rpart.plot(model, extra=104, box.palette="Blues", fallen.leaves=FALSE)