多类 CNN 的 ROC 曲线上的 AUC 接近 1.0,但精度/召回率并不完美?

数据挖掘 Python 分类 多类分类 奥克
2022-02-18 15:09:12

我正在使用 CNN 在 CIFAR-10 数据集上构建 ROC 曲线并计算多类分类的 AUC。我的整体准确率约为 90%,我的准确率和召回率如下:

              precision    recall  f1-score   support

    airplane       0.93      0.90      0.91      1000
  automobile       0.93      0.96      0.95      1000
        bird       0.88      0.87      0.87      1000
         cat       0.86      0.72      0.79      1000
        deer       0.88      0.91      0.89      1000
         dog       0.88      0.81      0.84      1000
        frog       0.83      0.97      0.89      1000
       horse       0.94      0.94      0.94      1000
        ship       0.95      0.93      0.94      1000
       truck       0.90      0.95      0.92      1000

    accuracy                           0.90     10000
   macro avg       0.90      0.90      0.90     10000
weighted avg       0.90      0.90      0.90     10000

ROC曲线

我计算ROC Curveand的代码AUC如下:

def assess_model_from_pb(model_file_path: Path, xtest: np.ndarray, ytest: np.ndarray, save_plot_path: Path):

    class_labels = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
    model = load_model(model_file_path) # load model from filepath
    feature_extractor = Model(inputs = model.inputs, outputs = model.get_layer('dense').output) # extract dense output layer (will be softmax probabilities)
    y_score = feature_extractor.predict(xtest, batch_size = 64) # one hot encoded softmax predictions
    ytest_binary = label_binarize(ytest, classes = [0,1,2,3,4,5,6,7,8,9]) # one hot encode the test data true labels
    n_classes = y_score.shape[2]

    fpr = dict()
    tpr = dict()
    roc_auc = dict() 
    # compute fpr and tpr with roc_curve from the ytest true labels to the scores
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(ytest_binary[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # plot each class  curve on single graph for multi-class one vs all classification
    colors = cycle(['blue', 'red', 'green', 'brown', 'purple', 'pink', 'orange', 'black', 'yellow', 'cyan'])
    for i, color, lbl in zip(range(n_classes), colors, class_labels):
        plt.plot(fpr[i], tpr[i], color = color, lw = 1.5,
        label = 'ROC Curve of class {0} (area = {1:0.3f})'.format(lbl, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw = 1.5)
    plt.xlim([-0.05, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve for CIFAR-10 Multi-Class Data')
    plt.legend(loc = 'lower right', prop = {'size': 6})
    fullpath = save_plot_path.joinpath(save_plot_path.stem +'_roc_curve.png')
    plt.savefig(fullpath)
    plt.show()

我想当我的精确度和召回率不完美时,我对我的 AUC 如何接近 1 感到困惑。我知道许多thresholds用于确定什么是正类和负类。例如,在曲线的开始处,如果阈值非常高(比如 0.99999 左右),那么 mytpr是如何接近 1 的?仅仅是因为在那个阈值下,我只会对绝对最高的 softmax 概率给出正面分类吗?

只想对这个话题有更多的解释或直觉,以确保我没有做错什么。

0个回答
没有发现任何回复~