无论如何要知道在 scikit-learn 中使用 RandomForestClassifier 种植的树木的所有细节?

数据挖掘 scikit-学习 随机森林 决策树
2021-10-07 20:04:51

我正在使用 scikit-learn 包构建一个标准的 RandomForest 分类器(命名模型,请参见下面的代码)。现在,我想获取一个 Randomforest 分类器的所有参数(包括它的树(估计器)),以便我可以手动绘制 RandomForest 分类器的每棵树的流程图。我想知道是否有人知道如何做到这一点?

先感谢您。

#Import Library
from sklearn.ensemble import RandomForestClassifier #use RandomForestRegressor for regression problem
#Assumed you have, X (predictor) and Y (target) for training data set and x_test(predictor) of test_dataset
# Create Random Forest object
model= RandomForestClassifier(n_estimators=10, max_depth=5) #n_estimators=1000 oob_score = True
#====
#X, y = input_X, input_y

from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.2, random_state = 4)

# Train the model using the training sets and check score
model.fit(X_train, y_train)

#Predict Output
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

#accuracy
from sklearn.metrics import accuracy_score
print(accuracy_score(y_train,y_pred_train))
print(accuracy_score(y_test,y_pred_test))
3个回答

我认为特伦斯帕尔的回答现在已经部分过时了。您可以通过以下方式获得相同(甚至更多):

print("Tree depths: ", [t.get_depth() for t in model.estimators_])
print("Tree number of leaves: ", [t.get_n_leaves() for t in model.estimators_])

最大深度是一个非常有用的指标,我在 API 中没有找到,所以我写了这个:

def dectree_max_depth(tree):
    n_nodes = tree.node_count
    children_left = tree.children_left
    children_right = tree.children_right

    def walk(node_id):
        if (children_left[node_id] != children_right[node_id]):
            left_max = 1 + walk(children_left[node_id])
            right_max = 1 + walk(children_right[node_id])
            return max(left_max, right_max)
        else: # leaf
            return 1

    root_node_id = 0
    return walk(root_node_id)

您可以在森林中的所有树木上使用它 ( rf),如下所示:

[dectree_max_depth(t.tree_) for t in rf.estimators_]

您可以从随机森林中选择和可视化单个树:

# Extract individual tree from forest
tree_id = 5
tree = model.estimators_[tree_id]

# Draw individual tree flowchart
from sklearn.tree import export_graphviz

export_graphviz(tree)