XGBoost 似乎默认使用回归树作为基础学习器。XGBoost(或一般的梯度提升)通过组合多个这些基础学习器来工作。回归树无法推断训练数据中的模式,因此在您的情况下,任何高于 3 或低于 1 的输入都不会被正确预测。您的模型经过训练可以预测区间 中输入的输出,[1,3]
大于 3 的输入将获得与 3 相同的输出,小于 1 的输入将获得与 1 相同的输出。
此外,回归树并不真正将您的数据视为直线,因为它们是非参数模型,这意味着它们理论上可以拟合比直线更复杂的任何形状。粗略地说,回归树的工作原理是将您的新输入数据分配给它在训练期间看到的一些训练数据点,并据此生成输出。
这与参数回归器(如线性回归)形成对比,后者实际上寻找超平面的最佳参数(在您的情况下为直线)以适合您的数据。线性回归确实将您的数据视为具有斜率和截距的直线。
"booster":"gblinear"
您可以通过添加到您的模型将 XGBoost 模型的基础学习器更改为 GLM(广义线性模型)params
:
import pandas as pd
import xgboost as xgb
df = pd.DataFrame({'x':[1,2,3], 'y':[10,20,30]})
X_train = df.drop('y',axis=1)
Y_train = df['y']
T_train_xgb = xgb.DMatrix(X_train, Y_train)
params = {"objective": "reg:linear", "booster":"gblinear"}
gbm = xgb.train(dtrain=T_train_xgb,params=params)
Y_pred = gbm.predict(xgb.DMatrix(pd.DataFrame({'x':[4,5]})))
print Y_pred
通常,要调试 XGBoost 模型为何以特定方式运行,请参阅模型参数:
gbm.get_dump()
如果您的基础学习器是线性模型,则 get_dump 输出为:
['bias:\n4.49469\nweight:\n7.85942\n']
在上面的代码中,由于您创建了基础学习器,因此输出将是:
['0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=2.85\n\t\t4:leaf=5.85\n\t2:leaf=8.85\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=1.995\n\t\t4:leaf=4.095\n\t2:leaf=6.195\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=1.3965\n\t\t4:leaf=2.8665\n\t2:leaf=4.3365\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=0.97755\n\t\t4:leaf=2.00655\n\t2:leaf=3.03555\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=0.684285\n\t\t4:leaf=1.40458\n\t2:leaf=2.12489\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=0.478999\n\t\t4:leaf=0.983209\n\t2:leaf=1.48742\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=0.3353\n\t\t4:leaf=0.688247\n\t2:leaf=1.04119\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=0.23471\n\t\t4:leaf=0.481773\n\t2:leaf=0.728836\n',
'0:[x<3] yes=1,no=2,missing=1\n\t1:[x<2] yes=3,no=4,missing=3\n\t\t3:leaf=0.164297\n\t\t4:leaf=0.337241\n\t2:leaf=0.510185\n',
'0:[x<2] yes=1,no=2,missing=1\n\t1:leaf=0.115008\n\t2:[x<3] yes=3,no=4,missing=3\n\t\t3:leaf=0.236069\n\t\t4:leaf=0.357129\n']
提示:我实际上更喜欢使用 xgb.XGBRegressor 或 xgb.XGBClassifier 类,因为它们遵循sci-kit learn API。并且因为 sci-kit learn 有这么多机器学习算法实现,所以使用 XGB 作为附加库不会干扰我的工作流程,只有当我使用 XGBoost 的 sci-kit 接口时。