我正在使用 python 的 scikit learn 中的决策树。与许多用例不同,此时我对分类器的准确性不太感兴趣,因为我正在提取数据点在调用.predict()它时穿过树的特定路径。有没有人这样做过?我想构建一个包含(,路径)对的数据框,用于下游分析。
通过sklearn中的决策树提取数据点的“路径”
数据挖掘
机器学习
Python
scikit-学习
决策树
2022-02-08 17:32:33
1个回答
看起来这在 R 中更容易做到,将rpart库与库结合使用partykit。理想情况下,我希望找到一种在 python 中执行此操作的方法,但这里是代码,供任何感兴趣的人使用(取自此处):
pathpred <- function(object, ...){
## coerce to "party" object if necessary
if(!inherits(object, "party")) object <- as.party(object)
## get standard predictions (response/prob) and collect in data frame
rval <- data.frame(response = predict(object, type = "response", ...))
rval$prob <- predict(object, type = "prob", ...)
## get rules for each node
rls <- partykit:::.list.rules.party(object)
## get predicted node and select corresponding rule
rval$rule <- rls[as.character(predict(object, type = "node", ...))]
return(rval)
}
使用 iris 数据和 rpart() 的说明:
library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
屈服,
response prob.setosa prob.versicolor prob.virginica
1 setosa 1.00000000 0.00000000 0.00000000
51 versicolor 0.00000000 0.90740741 0.09259259
101 virginica 0.00000000 0.02173913 0.97826087
rule
1 Petal.Length < 2.45
51 Petal.Length >= 2.45 & Petal.Width < 1.75
101 Petal.Length >= 2.45 & Petal.Width >= 1.75
这看起来是我至少可以用来派生共享父节点信息的东西。
其它你可能感兴趣的问题