xgboost 中的 eta 和 learning_rate 不同

数据挖掘 Python xgboost
2022-03-01 02:18:12

我正在使用 python 中的 xgboost 创建分类模型。我正在使用不同的eta值来检查它对模型的影响。我的代码是-

for eta in np.arange(0.2, 0.51, 0.03):
    xgb_model = xgboost.XGBClassifier(objective = 'multi:softmax', num_class = 5, eta = eta)
    xgb_model.fit(x_train, y_train)
    xgb_out = xgb_model.predict(x_test)
    print("For eta %f, accuracy is %2.3f" %(eta,metrics.accuracy_score(y_test, xgb_out)*100))

我期望某些 eta 值具有不同的精度,但令我惊讶的是,每个 eta 的精度都相同。当我打印模型时,我得到了这个-

>>> print(xgb_model)
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, eta=0.5, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None, num_class=5,
       objective='multi:softprob', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=True,
       subsample=1)

在这里你可以看到eta = 0.5,但是learning_rate = 0.1xgboost docs中, learning_rate 是 eta 的别名。那么这两者有不同的价值怎么可能呢?

1个回答

似乎 eta 只是一个占位符,尚未实现,而默认值仍然是learning_rate,基于源代码。接得好。

我们可以从sklearn.py 的源代码中看到,似乎存在一个名为“XGBModel”的类,它从 sklearn 的 API 继承了 BaseModel 的属性。

追踪到 compat.py,我们看到有一个 import 语句:

from sklearn.base import RegressorMixin, ClassifierMixin

这些是用于分类器/回归器的 sklearn 的 Mixin 类。基本上,它们是生成任何特定分类器/回归器的外壳,以保证后代类具有 score() 方法。

仔细查看 init 函数,我们看到:

__init__($self, /, *args, **kwargs)

来自 Yasoob 的 PythonTips(有关 *args/**kwargs 的示例,请参见链接):

*args 和 **kwargs 允许您将可变数量的参数传递给函数。这里的变量意味着您事先不知道用户可以将多少个参数传递给您的函数,因此在这种情况下您使用这两个关键字......

*args 用于向函数发送非关键字可变长度参数列表。

**kwargs 允许您将关键字可变长度的参数传递给函数。如果你想在函数中处理命名参数,你应该使用 **kwargs。

现在进行一个有趣的实验来证明这一点。在您自己的笔记本/控制台中尝试以下操作:

>>> XGBClassifier(eta=5, potato=1234)
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
   colsample_bytree=1, eta=5, gamma=0, learning_rate=0.1,
   max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
   n_estimators=100, n_jobs=1, nthread=None,
   objective='binary:logistic', potato=1234, random_state=0,
   reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
   silent=True, subsample=1)

看到我们的参数存储在对象中:

`...objective='binary:logistic', potato=1234, random_state=0,`

我们很确定potato这不是超参数,对吧?:)