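I'm using a CNN on the CIFAR-10 dataset to build ROC curves and compute the AUC for multi-class classification. My overall accuracy is about 90%, and my precision and recall are as follows: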
              precision    recall  f1-score   support

    airplane       0.93      0.90      0.91      1000
  automobile       0.93      0.96      0.95      1000
        bird       0.88      0.87      0.87      1000
         cat       0.86      0.72      0.79      1000
        deer       0.88      0.91      0.89      1000
         dog       0.88      0.81      0.84      1000
        frog       0.83      0.97      0.89      1000
       horse       0.94      0.94      0.94      1000
        ship       0.95      0.93      0.94      1000
       truck       0.90      0.95      0.92      1000

    accuracy                           0.90     10000
   macro avg       0.90      0.90      0.90     10000
weighted avg       0.90      0.90      0.90     10000
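The table has the layout printed by scikit-learn's `classification_report`. As a minimal sketch, assuming the per-class predictions are taken as the argmax of the softmax output (the helper name `report_from_scores` and the toy arrays are made up for illustration and are not CIFAR-10 results), a report like this could be produced as follows:

```python
import numpy as np
from sklearn.metrics import classification_report


def report_from_scores(y_true, y_score, target_names):
    """Hypothetical helper: turn softmax scores into hard labels, then summarise them."""
    y_pred = np.argmax(y_score, axis=1)  # one decision per image: the most probable class
    return classification_report(y_true, y_pred, target_names=target_names, output_dict=True)


# Toy example with three classes (made-up scores, not real model output)
y_true = np.array([0, 0, 1, 1, 2, 2])
y_score = np.array([
    [0.40, 0.55, 0.05],  # a 'cat' image whose highest score is 'dog'
    [0.70, 0.20, 0.10],
    [0.30, 0.60, 0.10],
    [0.10, 0.80, 0.10],
    [0.20, 0.10, 0.70],
    [0.05, 0.15, 0.80],
])
report = report_from_scores(y_true, y_score, ['cat', 'dog', 'other'])
print(report['cat'])  # precision 1.0, recall 0.5: one cat image was predicted as 'dog'
```

Precision and recall in such a report come from a single decision rule (argmax over classes), whereas an ROC curve sweeps every threshold over one class's score.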
My code for computing the ROC curves and the AUC is as follows:
from itertools import cycle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize
from tensorflow.keras.models import Model, load_model


def assess_model_from_pb(model_file_path: Path, xtest: np.ndarray, ytest: np.ndarray, save_plot_path: Path):
    class_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    model = load_model(model_file_path)  # load model from file path
    # Keras names the first Dense layer 'dense' by default, so check model.summary()
    # that 'dense' really is the final softmax layer before treating its output as probabilities.
    feature_extractor = Model(inputs=model.inputs, outputs=model.get_layer('dense').output)
    y_score = feature_extractor.predict(xtest, batch_size=64)  # shape (n_samples, 10): one score per class
    ytest_binary = label_binarize(np.ravel(ytest), classes=list(range(10)))  # one-hot encode the true labels
    n_classes = y_score.shape[1]  # classes lie along axis 1 of the (n_samples, n_classes) score array
    fpr, tpr, roc_auc = {}, {}, {}
    # one-vs-rest: compute fpr and tpr per class from the true labels and that class's scores
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(ytest_binary[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    # plot each class curve on a single graph for multi-class one-vs-rest classification
    colors = cycle(['blue', 'red', 'green', 'brown', 'purple', 'pink', 'orange', 'black', 'yellow', 'cyan'])
    for i, color, lbl in zip(range(n_classes), colors, class_labels):
        plt.plot(fpr[i], tpr[i], color=color, lw=1.5,
                 label='ROC Curve of class {0} (area = {1:0.3f})'.format(lbl, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=1.5)  # chance-level diagonal
    plt.xlim([-0.05, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve for CIFAR-10 Multi-Class Data')
    plt.legend(loc='lower right', prop={'size': 6})
    fullpath = save_plot_path.joinpath(save_plot_path.stem + '_roc_curve.png')
    plt.savefig(fullpath)
    plt.show()
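One way to check that the per-class `roc_curve` + `auc` loop is not doing something wrong is to compare it with scikit-learn's `roc_auc_score` on the same one-hot labels; with `average=None` it returns one one-vs-rest AUC per column. A minimal sketch on random synthetic scores (the helper name `per_class_auc_two_ways` is hypothetical, and the data is not CIFAR-10):

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize


def per_class_auc_two_ways(y_true, y_score, n_classes):
    """Hypothetical helper: per-class AUC via a roc_curve loop and via roc_auc_score."""
    y_onehot = label_binarize(y_true, classes=list(range(n_classes)))
    loop_aucs = []
    for i in range(n_classes):
        fpr, tpr, _ = roc_curve(y_onehot[:, i], y_score[:, i])
        loop_aucs.append(auc(fpr, tpr))
    # average=None returns one AUC per column of the one-hot indicator matrix
    direct_aucs = roc_auc_score(y_onehot, y_score, average=None)
    return np.array(loop_aucs), direct_aucs


# Synthetic stand-in for model output (random numbers, not CIFAR-10)
rng = np.random.default_rng(0)
n_classes = 10
y_true = rng.integers(0, n_classes, size=500)
logits = rng.normal(size=(500, n_classes))
logits[np.arange(500), y_true] += 2.0  # push the true class's score up on average
y_score = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over classes

loop_aucs, direct_aucs = per_class_auc_two_ways(y_true, y_score, n_classes)
print(np.round(loop_aucs, 3))
print(np.allclose(loop_aucs, direct_aucs))
```

If the two sets of numbers agree on real predictions too, the loop itself is computing the one-vs-rest AUC as intended.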
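I guess I'm confused about how my AUC can be close to 1 when my precision and recall are not perfect. I understand that many thresholds are used to decide what counts as the positive class and what counts as the negative class. For example, at the start of the curve, where the threshold is very high (say around 0.99999), how is my TPR close to 1? Is it simply because at that threshold, I only give a positive classification to the very highest softmax probabilities?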
I'd just like some more explanation or intuition on this topic, to make sure I'm not doing anything wrong.
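To make the threshold sweep concrete, here is a small sketch with made-up softmax scores for three classes (0 = cat, 1 = dog, 2 = other; the numbers are invented for illustration, not taken from the CIFAR-10 model). It prints the one-vs-rest TPR and FPR that `roc_curve` reports for the cat class at each threshold, and compares the cat AUC with the cat recall obtained from argmax:

```python
import numpy as np
from sklearn.metrics import auc, recall_score, roc_curve

# Made-up softmax scores for 3 classes (0 = cat, 1 = dog, 2 = other); not real model output
y_true = np.array([0, 0, 1, 1, 2, 2])
y_score = np.array([
    [0.40, 0.55, 0.05],  # a cat image that argmax labels as 'dog'
    [0.70, 0.20, 0.10],
    [0.30, 0.60, 0.10],
    [0.10, 0.80, 0.10],
    [0.20, 0.10, 0.70],
    [0.05, 0.15, 0.80],
])

# One-vs-rest view of class 0: positives are cat images, the score is the cat probability
is_cat = (y_true == 0).astype(int)
fpr, tpr, thresholds = roc_curve(is_cat, y_score[:, 0], drop_intermediate=False)
for t, f, r in zip(thresholds[1:], fpr[1:], tpr[1:]):  # thresholds[0] is an artificial start point
    print(f"score >= {t:.2f}: FPR = {f:.2f}, TPR = {r:.2f}")

cat_auc = auc(fpr, tpr)
cat_recall = recall_score(y_true, np.argmax(y_score, axis=1), average=None)[0]
print(f"cat AUC = {cat_auc:.3f}, cat recall (argmax) = {cat_recall:.2f}")
```

In this made-up case the highest threshold (0.70) marks only one image as positive, so TPR starts at 0.50, not near 1; TPR reaches 1.00 at 0.40 while FPR is still 0. Every cat image has a higher cat score than every non-cat image, so the cat AUC is 1.000, yet argmax still sends one cat image to 'dog' because its dog score is larger, giving a cat recall of 0.50. The one-vs-rest AUC measures how well one class's score ranks its positives above its negatives, which is a different question from whether that class wins the argmax.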
