How to use the SHAP KernelExplainer with a pipeline model?

data-mining machine-learning machine-learning-model data-science-model python
2021-10-01 13:35:49

I have a pandas DataFrame X. I want to get prediction explanations for a particular model.

My model is as follows:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

pipeline = Pipeline(steps=[
        ('imputer', imputer_function()),
        ('classifier', RandomForestClassifier())
    ])
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=0)
y_pred = pipeline.fit(x_train, y_train).predict(x_test)

Now, for the prediction explainer, I use the KernelExplainer from SHAP.

Here is the code:

# use Kernel SHAP to explain test set predictions
shap.initjs()

explainer = shap.KernelExplainer(pipeline.predict_proba, x_train, link="logit")

shap_values = explainer.shap_values(x_test, nsamples=10)

# plot the SHAP values for the first output class of the first instance
shap.force_plot(explainer.expected_value[0], shap_values[0][0,:], x_test.iloc[0,:], link="logit")

When I run the code, I get this error:

Provided model function fails when applied to the provided data set.

ValueError: Specifying the columns using strings is only supported for pandas DataFrames

Can anyone help me? I'm really stuck on this. Both x_train and x_test are pandas DataFrames.

2 Answers

The reason is that Kernel SHAP sends the data to the model as a numpy array, without column names. So we need to fix that as follows:

import pandas as pd

def model_predict(data_asarray):
    # Kernel SHAP passes plain numpy arrays, so rebuild the DataFrame
    # with the original column names before calling the fitted pipeline
    data_asframe = pd.DataFrame(data_asarray, columns=feature_names)
    return estimator.predict_proba(data_asframe)

然后,

import numpy as np

shap_kernel_explainer = shap.KernelExplainer(model_predict, x_train, link='logit')
shap_values_single = shap_kernel_explainer.shap_values(x_test.iloc[0, :])
shap.force_plot(shap_kernel_explainer.expected_value[0], np.array(shap_values_single[0]),
                x_test.iloc[0, :], link='logit')
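To see why the wrapper is needed, here is a minimal, self-contained reproduction of the error. The toy `X`, `y`, and the `ColumnTransformer`-based imputer are illustrative assumptions (the question's `imputer_function()` is not shown), chosen because a step that selects columns by name is exactly what raises this `ValueError` when it receives a bare numpy array:

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline

# Toy stand-ins for the question's X and y
X = pd.DataFrame({"a": [1.0, 2.0, np.nan, 4.0], "b": [0.5, 1.5, 2.5, 3.5]})
y = [0, 1, 0, 1]

# An imputer that selects its input columns by *name* -- one way the
# question's pipeline can end up requiring a DataFrame as input
pipe = Pipeline([
    ("imputer", ColumnTransformer([("imp", SimpleImputer(), ["a", "b"])])),
    ("classifier", RandomForestClassifier(n_estimators=10, random_state=0)),
]).fit(X, y)

# Kernel SHAP hands the model a bare numpy array, which loses the names:
try:
    pipe.predict_proba(X.to_numpy())
except ValueError as e:
    print("fails:", e)

# The wrapper restores the column names before calling the pipeline
def model_predict(data_asarray):
    return pipe.predict_proba(pd.DataFrame(data_asarray, columns=X.columns))

print(model_predict(X.to_numpy()).shape)  # (4, 2)
```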

I tried creating a function as suggested, but it did not work with my code. However, I found the following solution, suggested by an example on Kaggle:

import shap

#load JS vis in the notebook
shap.initjs() 

#set the tree explainer as the model of the pipeline
explainer = shap.TreeExplainer(pipeline['classifier'])

#apply the preprocessing to x_test
observations = pipeline['imputer'].transform(x_test)

#get Shap values from preprocessed data
shap_values = explainer.shap_values(observations)

#plot the feature importance
shap.summary_plot(shap_values, x_test, plot_type="bar")
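The snippet above applies the single 'imputer' step by hand. If the pipeline had several preprocessing steps, slicing the pipeline keeps the same pattern working: `pipeline[:-1]` is everything except the final estimator, and scikit-learn pipelines have supported slicing since 0.21. A minimal sketch with illustrative toy data (the extra 'scaler' step is an assumption for demonstration):

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Toy stand-ins for the question's data
X = pd.DataFrame({"a": [1.0, 2.0, np.nan, 4.0], "b": [0.5, 1.5, 2.5, 3.5]})
y = [0, 1, 0, 1]

pipe = Pipeline([
    ("imputer", SimpleImputer()),
    ("scaler", StandardScaler()),
    ("classifier", RandomForestClassifier(n_estimators=10, random_state=0)),
]).fit(X, y)

# pipe[:-1] is the pipeline minus its final estimator; its transform()
# runs every preprocessing step, not just one named step
observations = pipe[:-1].transform(X)
print(observations.shape)  # (4, 2)
```

The resulting `observations` array can then be passed to `explainer.shap_values` exactly as in the answer above.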