如何打印 RandomForestClassifier 的决策树

机器算法验证 Python 随机森林 scikit-学习
2022-03-03 18:13:34

最近,我注意到这里sklearn.tree.export_graphviz记录了一种方法

但是,我不知道如何将其应用于RandomForestClassifier.

我尝试了以下幼稚的代码,但它不起作用,而且我不知道如何从 a 中获取其中一棵树RandomForestClassifier

print('Training...')
forest = RandomForestClassifier(n_estimators=100)
forest = forest.fit( train_data[0::,1::], train_data[0::,0] )


print('Predicting...')
output = forest.predict(test_data).astype(int)

if sys.version_info >= (3,0,0):
    predictions_file = open("myfirstforest.csv", 'w', newline='')
else:
    predictions_file = open("myfirstforest.csv", 'wb')


tree.export_graphviz(forest, out_file='tree.dot')
1个回答

您的 RandomForest 创建了 100 棵树,因此您无法一步打印这些。尝试遍历森林中的树木并一一打印出来:

from sklearn import tree
i_tree = 0
for tree_in_forest in forest.estimators_:
    with open('tree_' + str(i_tree) + '.dot', 'w') as my_file:
        my_file = tree.export_graphviz(tree_in_forest, out_file = my_file)
    i_tree = i_tree + 1

如果你想知道树的实际参数,如分裂属性(特征)、分裂值(阈值)、节点样本(n_node_samples)等,可以print getmembers(tree_in_forest.tree_)在for循环中使用。要使用这些参数之一,例如。阈值,使用这个:tree_in_forest.tree_.threshold它返回一个列表。