如何限制要传递给 PMML 管道中的最终分类器的列

数据挖掘 Python scikit-学习 管道
2022-02-14 16:02:07

我正在使用 SKLearn 和 SKLearn2PMML 构建 XGBoost PMML。我有一些数字、一些分类和日期时间列,我在管道中创建新功能。当我尝试训练模型时,它会失败,因为默认情况下原始分类特征也会传递给最终分类器。有没有办法通过指定功能名称来限制功能?

1个回答

在挖掘了太多内容并得到了 sklearn2pmml 创建者的帮助后,我设法过滤了要传递给分类器的最终列。

注意:这里的记录器是 DataFrameMapper 对象。

1.获取分类列索引。

cat_cols = [recorder.transformed_names_.index(c) for c in categoricalCols if c in recorder.transformed_names_]

2.添加 ColumnTransformer 以借助它们的索引过滤这些列。

pipeline = PMMLPipeline([
    ("mapper", recorder),
    ("select", ColumnTransformer([("drop", "drop", cat_cols)], remainder='passthrough')),
    ("classifier", xgb.XGBClassifier())
])

3.将数据拟合到管道中。

pipeline.fit(X_train,y_train)

4.从Pipeline中创建PMML文件。

out_file = "XGBoost.pmml"
sklearn2pmml(pipeline, out_file)