随机森林相对于最大深度的复杂性

机器算法验证 随机森林 算法 scikit-学习 时间复杂度
2022-04-05 05:58:27

我希望随机森林的训练时间是:在树的数量上呈线性(显然),在列数上呈线性(或平方根)(取决于对列的子采样的选择),并且当它来到树的最大深度......我迷路了。

我认为找到分裂(在叶子内)与叶子中的元素数量成线性关系。

因此,增加树的最大深度只是找到更多的分裂,其中是之前的叶子数附加操作是(是叶的大小): nnniLi

i=1nsplit(Li)=i=1nθ(ni)=θ(n)

这或多或少等同于找到第一个拆分。

因此,我认为随机森林的训练时间在最大深度是线性的。然而,数字示例显示(使用 sk-learn 执行)我完全错了(希望我能够发布深度 20 时间):

TIME 0.15m (1-fold)
{'n_estimators': 10, 'max_depth': 5}

TIME 0.60m (1-fold)
{'n_estimators': 50, 'max_depth': 5}

TIME 0.69m (1-fold)
{'n_estimators': 10,'max_depth': 10}

TIME 3.18m (1-fold)
{'n_estimators': 50, 'max_depth': 10}

TIME 3.85m (1-fold)
{'n_estimators': 10, 'max_depth': 15}

TIME 15.59m (1-fold)
{'n_estimators': 50, 'max_depth': 15}
1个回答

对于下面模拟的较小数据集,该过程应该是线性的。正如@EngrStudent 所指出的,这可能是 L1、L2 和 RAM 时钟速度的问题。随着模型复杂性的增加,随机森林算法可能无法计算 L1 和/或 L2 缓存中的整个树(...或树的子分支)。

我尝试使用 R randomForest 进行类似的测试,实际上它似乎是线性的。我不能在 randomForest 中选择 maxdepth,而只能选择最大终端节点(maxnodes),但这实际上是相同的。

最大终端节点 =2(maxdepth1)

请注意,我按对数刻度绘制 maxnodes (1,2,4,8,16,32,64),然后按 x 轴线性绘制深度 (0,1,2,3,4,5,6)。时间消耗似乎随着深度线性增加。

在此处输入图像描述

library(randomForest)
library(ggplot2)
set.seed(1)

#make some data
vars=10
obs = 4000
X = data.frame(replicate(vars,rnorm(obs)))
y = with(X, X1+sin(X2*2*pi)+X3*X4)

#wrapper function to time a model
time_model = function(model_function,...) {
  this_time = system.time({this_model_obj = do.call(model_function,list(...))})
  this_time['elapsed']
}

#generate jobs to simulate, jobs are sets of parameters (pars)
fixed_pars = alist(model_function=randomForest,x=X,y=y) #unevaluated to save memory
iter_pars = list(maxnodes=c(1,2,4,8,16,32,64),ntree = c(10,25,50),rep=c(1:5))
iter_pars_matrix = do.call(expand.grid,iter_pars)

#combine fixed and iterative pars and shape as list of jobs
job_list = apply(iter_pars_matrix,1,c,fixed_pars)

#do jobs and collect results in a data.frame
times = sapply(job_list,function(aJob) do.call(time_model,aJob))
r_df = data.frame(times,iter_pars_matrix)

#plot the results
ggplot(r_df, aes (x = maxnodes,y = times,colour = factor(ntree))) +
geom_point() + scale_x_log10()