sklearn 的 cross_validate 不适用于 catboost

数据挖掘 scikit-学习 交叉验证 助推
2021-10-04 16:27:04

我想将交叉验证与catboost. 由于我不仅想使用catboost而且还采样,所以我正在使用管道,因此不能使用catboost's自己的交叉验证(如果我只使用catboost而不是管道,则可以使用)。所以我想使用sklearn's交叉验证,如果我只使用数字变量,它就可以正常工作,但只要我还包括分类变量(cat_features)并使用catboost's编码,cross_validate就不再工作了。即使我不使用管道,而只是catboost单独使用,我也会收到一条KeyError: 0带有cross_validate. 但我不明白为什么。这是我不起作用的代码的一部分:

from sklearn.model_selection import cross_validate
model = cb.CatBoostClassifier(**params, cat_features=cat_features)
cv_score = cross_validate(model, X_train, y_train, scoring='roc_auc', cv=5, return_train_score=True)
1个回答

我发现添加xtrain.astype('O')工作。

显然 catboost 不适用于 pandas Categorical dtypes:https ://github.com/catboost/catboost/issues/814