如何从 randomForest::getTree() 实际绘制样本树?

机器算法验证 r 数据可视化 随机森林 大车
2022-01-15 01:34:13

任何人都得到了关于如何实际绘制几个示例树的库或代码建议:

getTree(rfobj, k, labelVar=TRUE)

(是的,我知道您不应该在操作上执行此操作,RF 是一个黑盒等。我想直观地检查一棵树,看看是否有任何变量的行为违反直觉,需要调整/组合/离散化/转换,检查我的编码因子的工作情况等)


没有像样答案的先前问题:

我实际上想绘制一个示例树所以不要和我争论,已经。我不是在问varImpPlot(Variable Importance Plot) or partialPlotorMDSPlot这些其他图,我已经有了这些,但它们不能替代查看样本树。是的,我可以目视检查getTree(...,labelVar=TRUE).

(我想一个plot.rf.tree()贡献会很受欢迎。)

3个回答

第一个(也是最简单的)解决方案:如果您不热衷于坚持使用 Andy Liaw 中实现的经典 RF,randomForest您可以尝试提供原始 RF 的不同实现的派对包算法(使用条件树和基于单位权重平均的聚合方案)。然后,正如R-help 帖子中所报告的那样,您可以绘制树列表中的单个成员。据我所知,它似乎运行顺利。下面是由 生成的一棵树的图cforest(Species ~ ., data=iris, controls=cforest_control(mtry=2, mincriterion=0))

在此处输入图像描述

第二种(几乎同样简单)的解决方案: R 中的大多数基于树的技术(treerpartTWIX等)都提供了类似tree的结构来打印/绘制单个树。这个想法是将 的输出转换randomForest::getTree为这样的 R 对象,即使从统计的角度来看它是无意义的。基本上,从对象访问树结构很容易tree,如下所示。请注意,它会根据任务的类型略有不同——回归与分类——在后一种情况下,它将添加特定于类的概率作为obj$frame(即 a data.frame)的最后一列。

> library(tree)
> tr <- tree(Species ~ ., data=iris)
> tr
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

 1) root 150 329.600 setosa ( 0.33333 0.33333 0.33333 )  
   2) Petal.Length < 2.45 50   0.000 setosa ( 1.00000 0.00000 0.00000 ) *
   3) Petal.Length > 2.45 100 138.600 versicolor ( 0.00000 0.50000 0.50000 )  
     6) Petal.Width < 1.75 54  33.320 versicolor ( 0.00000 0.90741 0.09259 )  
      12) Petal.Length < 4.95 48   9.721 versicolor ( 0.00000 0.97917 0.02083 )  
        24) Sepal.Length < 5.15 5   5.004 versicolor ( 0.00000 0.80000 0.20000 ) *
        25) Sepal.Length > 5.15 43   0.000 versicolor ( 0.00000 1.00000 0.00000 ) *
      13) Petal.Length > 4.95 6   7.638 virginica ( 0.00000 0.33333 0.66667 ) *
     7) Petal.Width > 1.75 46   9.635 virginica ( 0.00000 0.02174 0.97826 )  
      14) Petal.Length < 4.95 6   5.407 virginica ( 0.00000 0.16667 0.83333 ) *
      15) Petal.Length > 4.95 40   0.000 virginica ( 0.00000 0.00000 1.00000 ) *
> tr$frame
            var   n        dev       yval splits.cutleft splits.cutright yprob.setosa yprob.versicolor yprob.virginica
1  Petal.Length 150 329.583687     setosa          <2.45           >2.45   0.33333333       0.33333333      0.33333333
2        <leaf>  50   0.000000     setosa                                  1.00000000       0.00000000      0.00000000
3   Petal.Width 100 138.629436 versicolor          <1.75           >1.75   0.00000000       0.50000000      0.50000000
6  Petal.Length  54  33.317509 versicolor          <4.95           >4.95   0.00000000       0.90740741      0.09259259
12 Sepal.Length  48   9.721422 versicolor          <5.15           >5.15   0.00000000       0.97916667      0.02083333
24       <leaf>   5   5.004024 versicolor                                  0.00000000       0.80000000      0.20000000
25       <leaf>  43   0.000000 versicolor                                  0.00000000       1.00000000      0.00000000
13       <leaf>   6   7.638170  virginica                                  0.00000000       0.33333333      0.66666667
7  Petal.Length  46   9.635384  virginica          <4.95           >4.95   0.00000000       0.02173913      0.97826087
14       <leaf>   6   5.406735  virginica                                  0.00000000       0.16666667      0.83333333
15       <leaf>  40   0.000000  virginica                                  0.00000000       0.00000000      1.00000000

然后,有一些方法可以漂亮地打印和绘制这些对象。关键函数是一个依赖(图形显示)和(计算节点坐标)的通用tree:::plot.tree方法(我放了一个三元组,允许您直接查看 R 中的代码)。这些函数期望树的表示。其他微妙的问题:(1)默认绘图方法中的参数,有助于管理节点之间的垂直距离(意味着它与偏差成正比,意味着它是固定的);(2) 您需要通过调用来补充以将文本标签添加到节点和拆分,在这种情况下,这意味着您还必须查看.:tree:::treepltree:::treecoobj$frametype = c("proportional", "uniform")tree:::plot.treeproportionaluniformplot(tr)text(tr)tree:::text.tree

getTreefrom 方法返回不同randomForest结构,在线帮助中对此进行了说明。典型输出如下所示,终端节点由status代码 (-1) 指示。(同样,输出会根据任务的类型而有所不同,但仅限于statusprediction列。)

> library(randomForest)
> rf <- randomForest(Species ~ ., data=iris)
> getTree(rf, 1, labelVar=TRUE)
   left daughter right daughter    split var split point status prediction
1              2              3 Petal.Length        4.75      1       <NA>
2              4              5 Sepal.Length        5.45      1       <NA>
3              6              7  Sepal.Width        3.15      1       <NA>
4              8              9  Petal.Width        0.80      1       <NA>
5             10             11  Sepal.Width        3.60      1       <NA>
6              0              0         <NA>        0.00     -1  virginica
7             12             13  Petal.Width        1.90      1       <NA>
8              0              0         <NA>        0.00     -1     setosa
9             14             15  Petal.Width        1.55      1       <NA>
10             0              0         <NA>        0.00     -1 versicolor
11             0              0         <NA>        0.00     -1     setosa
12            16             17 Petal.Length        5.40      1       <NA>
13             0              0         <NA>        0.00     -1  virginica
14             0              0         <NA>        0.00     -1 versicolor
15             0              0         <NA>        0.00     -1  virginica
16             0              0         <NA>        0.00     -1 versicolor
17             0              0         <NA>        0.00     -1  virginica

如果您能设法将上面的表格转换为由 生成的表格tree您可能可以自定义满足您的需求,尽管我没有这种方法的示例。特别是,您可能希望摆脱在 RF 中没有意义的偏差、类别概率等的使用。您只需要设置节点坐标和拆分值。你可以使用它,但老实说,我不确定这是正确的方法。tree:::treepltree:::treecotree:::text.treefixInNamespace()

第三个(当然也是聪明的)解决方案:编写一个真正的as.tree辅助函数来缓解上述所有“补丁”。然后,您可以使用 R 的绘图方法,或者可能更好的是Klimt(直接来自 R)来显示单个树。

我迟到了四年,但如果你真的想坚持使用这个randomForest包(并且有一些很好的理由这样做),并且想要真正可视化树,你可以使用reprtree包。

该软件包的文档记录不是很好(您可以在此处找到文档),但一切都非常简单。要安装软件包,请参考repo 中的initialize.R,因此只需运行以下命令:

options(repos='http://cran.rstudio.org')
have.packages <- installed.packages()
cran.packages <- c('devtools','plotrix','randomForest','tree')
to.install <- setdiff(cran.packages, have.packages[,1])
if(length(to.install)>0) install.packages(to.install)

library(devtools)
if(!('reprtree' %in% installed.packages())){
  install_github('araastat/reprtree')
}
for(p in c(cran.packages, 'reprtree')) eval(substitute(library(pkg), list(pkg=p)))

然后继续制作你的模型和树:

library(randomForest)
library(reprtree)

model <- randomForest(Species ~ ., data=iris, importance=TRUE, ntree=500, mtry = 2, do.trace=100)

reprtree:::plot.getTree(model)

你去吧!美丽而简单。

从 plot.getTree(model) 生成的树

您可以查看 github repo 以了解包中的其他方法。实际上,如果您检查plot.getTree.R,您会注意到作者使用他自己的实现as.tree()chl♦建议您可以在他的答案中建立自己。这意味着您可以这样做:

tree <- getTree(model, k=1, labelVar=TRUE)
realtree <- reprtree:::as.tree(tree, model)

然后可能realtree与其他树绘图包一起使用,例如tree

我创建了一些函数来提取树的规则。

#**************************
#return the rules of a tree
#**************************
getConds<-function(tree){
  #store all conditions into a list
  conds<-list()
  #start by the terminal nodes and find previous conditions
  id.leafs<-which(tree$status==-1)
	  j<-0
	  for(i in id.leafs){
		j<-j+1
		prevConds<-prevCond(tree,i)
		conds[[j]]<-prevConds$cond
		while(prevConds$id>1){
		  prevConds<-prevCond(tree,prevConds$id)
		  conds[[j]]<-paste(conds[[j]]," & ",prevConds$cond)
        }
		if(prevConds$id==1){
			conds[[j]]<-paste(conds[[j]]," => ",tree$prediction[i])
    }
    }

  }

  return(conds)
}

#**************************
#find the previous conditions in the tree
#**************************
prevCond<-function(tree,i){
  if(i %in% tree$right_daughter){
		id<-which(tree$right_daughter==i)
		cond<-paste(tree$split_var[id],">",tree$split_point[id])
	  }
	  if(i %in% tree$left_daughter){
    id<-which(tree$left_daughter==i)
		cond<-paste(tree$split_var[id],"<",tree$split_point[id])
  }

  return(list(cond=cond,id=id))
}

#remove spaces in a word
collapse<-function(x){
  x<-sub(" ","_",x)

  return(x)
}


data(iris)
require(randomForest)
mod.rf <- randomForest(Species ~ ., data=iris)
tree<-getTree(mod.rf, k=1, labelVar=TRUE)
#rename the name of the column
colnames(tree)<-sapply(colnames(tree),collapse)
rules<-getConds(tree)
print(rules)