library(rpart)
library(tibble)

# Two observations, one per level of x, each with a different class of y.
# minsplit = 2 allows a split to be attempted in a node with two observations.
df <- tibble(x = factor(c("A", "B")), y = factor(c(1, 0)))
model <- rpart(
  formula = y ~ .,
  data = df,
  method = "class",
  control = rpart.control(minsplit = 2)
)
这里模型将有 1 个父节点和两个子节点。如何从 rpart 模型对象中获取这些节点的基尼指数值?
# Two-row example data: one observation per level of x.
df <- tibble(
  x = factor(c("A", "B")),
  y = factor(c(1, 0))
)
# Classification tree; minsplit = 2 so a two-observation node may be split.
model <- rpart(
  formula = y ~ .,
  data = df,
  method = "class",
  control = rpart.control(minsplit = 2)
)
这里模型将有 1 个父节点和两个子节点。如何从 rpart 模型对象中获取这些节点的基尼指数值?
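每个节点的基尼不纯度可以按 $G = 1 - \sum_k p_k^2$ 计算,其中 $p_k$ 是该节点中第 $k$ 类观测所占的比例。例如,如果节点 1 包含 40% 的“1”和 60% 的“0”,则 $G = 1 - 0.4^2 - 0.6^2 = 0.48$。
节点大小 `n` 和节点的错分损失 `dev` 都存储在 `model$frame` 中。在二分类且使用默认损失的情况下,`dev` 就是少数类的观测数。由于二分类的基尼系数关于 $p$ 对称,每个节点的基尼系数可以直接用 `model$frame` 中的 `n` 和 `dev` 计算: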
frame <- model$frame
frame[['gini']] = 1 - (frame[['dev']] / frame[['n']])^2 - (1 - frame[['dev']] / frame[['n']])^2
frame[,c('var','n','dev','gini')]
> var n dev gini
> 1 x3 10 5 0.5000000
> 2 <leaf> 4 1 0.3750000
> 3 <leaf> 6 2 0.4444444
每次拆分的 Gini 改进是父节点与其两个子节点之间按节点大小加权的基尼不纯度之差:$\Delta = n_P G_P - n_L G_L - n_R G_R$。
# Gini improvement of each split: the size-weighted impurity of the parent
# minus that of its two children, n_P * G_P - n_L * G_L - n_R * G_R.
# rpart numbers the children of node k as 2k and 2k + 1 and uses the node
# numbers as row names of `model$frame`, so children are located by matching
# those ids (vectorized instead of looping over rows). Leaves get NA.
node_id <- as.numeric(rownames(frame))
weighted_gini <- frame[["n"]] * frame[["gini"]]
left <- match(2 * node_id, node_id)
right <- match(2 * node_id + 1, node_id)
frame[["improve"]] <-
  weighted_gini - weighted_gini[left] - weighted_gini[right]
frame[["improve"]][frame[["var"]] == "<leaf>"] <- NA
frame[, c("var", "n", "dev", "gini", "improve")]
> var n dev gini improve
> 1 x3 10 5 0.5000000 0.8333333
> 2 <leaf> 4 1 0.3750000 NA
> 3 <leaf> 6 2 0.4444444 NA
# Compare with the improvement values rpart itself reports for each split.
model[["splits"]]
> count ncat improve index adj
> x3 10 2 0.8333333 1 0.00
> x2 10 2 0.2380952 2 0.00
> x2 0 2 0.7000000 3 0.25
以下代码可以计算具有任意数量类别的 `rpart` 分类树中各节点的基尼指数:
gini <- function(tree) {
  # Tabulate the Gini impurity of every node of a classification `rpart` tree.
  #
  # The impurity of a node is 1 - sum_k p_k^2, where p_k are the class
  # probabilities stored for that node in `tree$frame$yval2`. Works for any
  # number of classes.
  #
  # Args:
  #   tree: an `rpart` model fitted with `method = "class"`; it must carry
  #         the "ylevels" attribute and a `frame$yval2` matrix.
  #
  # Returns:
  #   A data.frame with one row per node, in `tree$frame` order:
  #     Name      - "<split variable> (<node label>)", e.g. "x (root)"
  #     GiniIndex - Gini impurity (numeric, rounded to 5 digits)
  #     Class     - fitted class of the node, as a response level label
  #     Items     - per-class counts, comma-separated
  #     ItemProbs - per-class probabilities, comma-separated, 5 digits
  ylevels <- attr(tree, "ylevels", exact = TRUE)
  frame <- tree[["frame"]]
  yval2 <- frame[["yval2"]]
  if (is.null(ylevels) || !is.matrix(yval2)) {
    stop(
      "`tree` must be an rpart model fitted with method = \"class\".",
      call. = FALSE
    )
  }
  n_class <- length(ylevels)

  # Layout of each yval2 row: fitted-class index, n_class counts,
  # n_class probabilities, then the node probability.
  counts <- yval2[, 1L + seq_len(n_class), drop = FALSE]
  probs <- yval2[, 1L + n_class + seq_len(n_class), drop = FALSE]

  # One comma-separated string per matrix row.
  collapse_rows <- function(m) {
    vapply(
      seq_len(nrow(m)),
      function(i) toString(round(m[i, ], 5)),
      character(1)
    )
  }

  data.frame(
    # paste0 avoids the doubled spaces paste() would put around " (".
    Name = paste0(as.character(frame[["var"]]), " (", labels(tree), ")"),
    GiniIndex = unname(round(1 - rowSums(probs^2), 5)),
    # yval2[, 1] is an index into ylevels, not the class label itself.
    Class = ylevels[yval2[, 1L]],
    Items = collapse_rows(counts),
    ItemProbs = collapse_rows(probs),
    row.names = NULL,
    stringsAsFactors = FALSE
  )
}
# Four classes: x = "C" occurs twice (both with y = 3); every other level once.
df <- data.frame(
  x = factor(c("A", "B", "C", "C", "D")),
  y = factor(c(1, 2, 3, 3, 4))
)
model <- rpart(
  formula = y ~ .,
  data = df,
  method = "class",
  control = rpart.control(minsplit = 2)
)
gini(model)
Name GiniIndex Class Items ItemProbs
1 x ( root ) 0.72 3 1, 1, 2, 1 0.2, 0.2, 0.4, 0.2
2 x ( x=abd ) 0.66667 1 1, 1, 0, 1 0.33333, 0.33333, 0, 0.33333
3 <leaf> ( x=a ) 0 1 1, 0, 0, 0 1, 0, 0, 0
4 x ( x=bd ) 0.5 2 0, 1, 0, 1 0, 0.5, 0, 0.5
5 <leaf> ( x=b ) 0 2 0, 1, 0, 0 0, 1, 0, 0
6 <leaf> ( x=d ) 0 4 0, 0, 0, 1 0, 0, 0, 1
7 <leaf> ( x=c ) 0 3 0, 0, 2, 0 0, 0, 1, 0
# Plot the tree (the rendered plot is not included here):
# rpart.plot(model, extra=104, box.palette="Blues", fallen.leaves=FALSE)