在一张图中绘制多条精确召回曲线

数据挖掘 Python scikit-学习 绘图
2022-02-20 07:32:35

我有一个不平衡的数据集,我正在阅读这篇文章,它研究了 SMOTE 和 RUS 以解决不平衡问题。所以我定义了以下3个模型:

    # AdaBoost
    ada = AdaBoostClassifier(n_estimators=100, random_state=42)
    ada.fit(X_train,y_train)
    y_pred_baseline = ada.predict(X_test) 
    
    # SMOTE    
    sm = SMOTE(random_state=42)
    X_train_sm, y_train_sm = sm.fit_sample(X_train, y_train)
    ada_sm = AdaBoostClassifier(n_estimators=100, random_state=42)
    ada_sm.fit(X_train_sm,y_train_sm)
    y_pred_sm = ada_sm.predict(X_test) 
    
    #RUS
    rus = RandomUnderSampler(random_state=42)
    X_train_rus, y_train_rus = rus.fit_resample(X, y)
    ada_rus = AdaBoostClassifier(n_estimators=100, random_state=42)
    ada_rus.fit(X_train_rus,y_train_rus)
    y_pred_rus = ada_rus.predict(X_test) 

然后我绘制了这 3 个模型的精确召回曲线。我选择这条曲线是因为我想可视化模型的表现,并且我对真正的否定不是很感兴趣(否定类是多数类)。

为了绘制曲线,我使用了 ScikitLearn 的 plot_precision_recall_curve 方法,如下所示:


    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import plot_precision_recall_curve
    import matplotlib.pyplot as plt
    
    disp = plot_precision_recall_curve(ada, X_test, y_test)
    disp.ax_.set_title('Precision-Recall curve')

    disp = plot_precision_recall_curve(ada_sm, X_test, y_test)
    disp.ax_.set_title('Precision-Recall curve')

    disp = plot_precision_recall_curve(ada_rus, X_test, y_test)
    disp.ax_.set_title('Precision-Recall curve')

这导致了 3 个单独的地块。

在此处输入图像描述

但是,我希望将这 3 条曲线放在一个图中,以便可以轻松比较它们。所以我想要一个像文章中的情节:

在此处输入图像描述

但我不确定如何执行此操作,因为 plot_precision_recall_curve 方法仅将一个分类器作为输入。

一些帮助将不胜感激。

1个回答

尝试以这种方式使用Matplotlib gca() 方法,您可以指示要在哪个轴上绘图

from sklearn.metrics import precision_recall_curve
from sklearn.metrics import plot_precision_recall_curve
import matplotlib.pyplot as plt

plot_precision_recall_curve(ada, X_test, y_test, ax = plt.gca(),name = "AdaBoost")

plot_precision_recall_curve(ada_sm, X_test, y_test, ax = plt.gca(),name = "SMOTE")

plot_precision_recall_curve(ada_rus, X_test, y_test, ax = plt.gca(),name = "RUS")

plt.title('Precision-Recall curve')