通过sklearn中的决策树提取数据点的“路径”

数据挖掘 机器学习 Python scikit-学习 决策树
2022-02-08 17:32:33

我正在使用 python 的 scikit learn 中的决策树。与许多用例不同,此时我对分类器的准确性不太感兴趣,因为我正在提取数据点在调用.predict()它时穿过树的特定路径。有没有人这样做过?我想构建一个包含(,路径)对的数据框,用于下游分析。Xii

1个回答

看起来这在 R 中更容易做到,将rpart库与库结合使用partykit理想情况下,我希望找到一种在 python 中执行此操作的方法,但这里是代码,供任何感兴趣的人使用(取自此处):

pathpred <- function(object, ...){
    ## coerce to "party" object if necessary
    if(!inherits(object, "party")) object <- as.party(object)

    ## get standard predictions (response/prob) and collect in data frame
    rval <- data.frame(response = predict(object, type = "response", ...))
    rval$prob <- predict(object, type = "prob", ...)

    ## get rules for each node
    rls <- partykit:::.list.rules.party(object)

    ## get predicted node and select corresponding rule
    rval$rule <- rls[as.character(predict(object, type = "node", ...))]

    return(rval)
}

使用 iris 数据和 rpart() 的说明:

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]

屈服,

       response prob.setosa prob.versicolor prob.virginica
 1       setosa  1.00000000      0.00000000     0.00000000
 51  versicolor  0.00000000      0.90740741     0.09259259
 101  virginica  0.00000000      0.02173913     0.97826087
                                           rule
 1                          Petal.Length < 2.45
 51   Petal.Length >= 2.45 & Petal.Width < 1.75
 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

这看起来是我至少可以用来派生共享父节点信息的东西。