使用随机森林 Sklearn 进行批量学习

机器算法验证 机器学习 随机森林 scikit-学习 大数据
2022-04-03 03:10:34

我有一个大约 500 万行的数据集,想运行一个 RandomForestClassifier。我只用 50 棵树运行了我的 RandomForestClassifier,我尝试使用 fit 函数,但收到内存错误。我尝试在具有 64GB 内存的 AWS 机器上运行它,但我仍然遇到了这个问题。

我想知道是否可以使用某种批量学习来使用 sklearn 来克服这个问题?如果有人有任何建议,我愿意接受其他建议。

1个回答

是的,在 scikit-learn 中批量学习当然是可能的。首次初始化 RandomForestClassifier 对象时,您需要将 warm_start 参数设置为 True。这意味着连续调用model.fit将不适合全新的模型,而是添加连续的树。

这里有一些伪代码可以帮助您入门。这将为数据的每个子块构建一棵树。

# split your data into an iterable of (X,y) pairs
# size each one so that it can fit into memory
data_splits = ... 

clf = RandomForestClassifier(warm_start = True, n_estimators = 1)
for _ in range(10): # 10 passes through the data
    for X, y in data_splits: 
        clf.fit(X,y)
        clf.n_estimators += 1 # increment by one so next  will add 1 tree

我很惊讶地发现subsampleRandomForestClassifier 中没有一个参数类似于 GradientBoostingClassifier 中的参数来控制每棵树可见的观察次数。如果您切换到 GradientBoostingClassifier,您可能可以简单地设置subsample为一个非常小的数字来获得相同的结果。