GridSearchCV 在完成评估需要这么长时间的参数组合的性能后在做什么?

数据挖掘 Python scikit-学习 交叉验证 网格搜索
2021-10-14 13:52:24

我正在运行 GridSearchCV 来调整一些参数。例如:

params = {
    'max_depth':[18,21]
}

gscv = GridSearchCV(
    xgbc,
    params,
    scoring='roc_auc',
    verbose=50,
    cv=StratifiedKFold(n_splits=2, shuffle=True,random_state=42)
)

gscv.fit(df.drop('LAPSED', axis=1), df.LAPSED)
print('best score: ', gscv.best_score_, 'best params: ', gscv.best_params_)

一切都好。因为我已经指定了一些详细信息,所以它会输出一些关于它正在做什么的信息,如下所示:

Fitting 2 folds for each of 2 candidates, totalling 4 fits
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[CV] max_depth=18 ....................................................
[CV] ........... max_depth=18, score=0.9453140690301272, total= 8.2min
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  8.3min remaining:    0.0s
[CV] max_depth=18 ....................................................
[CV] ........... max_depth=18, score=0.9444119097669363, total= 7.9min
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed: 16.3min remaining:    0.0s
[CV] max_depth=21 ....................................................
[CV] ........... max_depth=21, score=0.9454705777130412, total= 8.4min
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed: 24.8min remaining:    0.0s
[CV] max_depth=21 ....................................................
[CV] ........... max_depth=21, score=0.9443863821843195, total= 8.3min
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed: 33.2min remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed: 33.2min finished

然而,一旦它完成了所有折叠的运行,它需要很长时间(至少只要它需要为一个参数组合拟合和评估一个折叠)才能将输出返回到print('best score: ', gscv.best_score_, 'best params: ', gscv.best_params_),即使我可以手动计算很容易从它作为拟合过程的一部分输出的数据中。我认为这意味着该算法在完成拟合和评估不同模型后会挂起做其他事情,但我不确定那可能是什么。

nb 实际上,我突然想到,这可能是花时间在模型确定为提供最佳性能的参数上重新训练模型,以便它可用于.predict()etc 方法。我现在只是通过传递refit=False来检查它以防止这种情况发生,如果它有效,我将回答我自己的问题。

1个回答

是的,我想通了。答案是默认情况下 GridSearchCV 的最后一步是公开您传递的估计器对象的 API,以便您可以直接调用GridSearchCV 对象本身之类的.predict()东西。.score()它通过针对交叉验证期间找到的最佳参数重新训练估计器来做到这一点。如果您想跳过此步骤(例如,因为之后您将继续进行更多开发或交叉验证),那么您可以通过refit=False以防止这种情况发生。