如何确定分类器的最佳阈值并生成 ROC 曲线?

机器算法验证 机器学习 支持向量机
2022-01-21 11:24:02

假设我们有一个 SVM 分类器,我们如何生成 ROC 曲线?(就像理论上一样)(因为我们使用每个阈值生成 TPR 和 FPR)。我们如何确定这个 SVM 分类器的最佳阈值?

4个回答

使用 SVM 分类器对一组带注释的示例进行分类,基于示例的一个预测可以识别 ROC 空间上的“一个点”。假设示例数为200,首先统计四个案例的示例数。

labeled truelabeled falsepredicted true7128predicted false5744


然后计算 TPR(真阳性率)和 FPR(假阳性率)。TPR=71/(71+57)=0.5547, 和FPR=28/(28+44)=0.3889在 ROC 空间上,x 轴为 FPR,y 轴为 TPR。所以点(0.3889,0.5547)获得。

要绘制 ROC 曲线,只需

  1. 调整一些阈值来控制标记为真或假的示例数量
    例如,如果某些蛋白质的浓度高于 α% 表示疾病,则不同的 α 值会产生不同的最终 TPR 和 FPR 值。可以通过类似于网格搜索的方式简单地确定阈值;用不同的阈值标注训练样例,用不同的标注样例集训练分类器,在测试数据上运行分类器,计算FPR值,选择覆盖低(接近0)和高(接近1)FPR的阈值值,即接近 0, 0.05, 0.1, ..., 0.95, 1
  2. 生成多组带注释的示例
  3. 在示例集上运行分类器
  4. 为它们中的每一个计算一个 (FPR, TPR) 点
  5. 绘制最终的 ROC 曲线

可以在http://en.wikipedia.org/wiki/Receiver_operating_characteristic中查看一些详细信息。

此外,这两个链接对于如何确定最佳阈值很有用。一种简单的方法是取真阳性和假阴性率之和最大的那个。其他更精细的标准可能包括涉及不同阈值的其他变量,例如财务成本等。
http://www.medicalbiostatistics.com/roccurve.pdf
http://www.kovcomp.co.uk/support/XL-Tut/life-ROC -curves-receiver-operating-characteristic.html

阈值的选择取决于 TPR 和 FPR 分类问题的重要性。例如,如果您的分类器将决定哪些犯罪嫌疑人将被判处死刑,则误报非常糟糕(无辜者将被杀死!)。因此,您将选择一个在保持合理 TPR 的同时产生低 FPR 的阈值(这样您实际上会抓住一些真正的罪犯)。如果没有关于低 TPR 或高 FPR 的外部担忧,一种选择是通过选择最大化的阈值来平等地加权它们TPRFPR.

选择最接近 ROC 空间左上角的点。现在用于生成该点的阈值应该是最佳阈值。

####################################

最佳截止将是 tpr 高而 fpr 低

tpr - (1-fpr) 为零或接近于零是最佳截止点

####################################

def plot_roc_curve(fpr, tpr):
    plt.plot(fpr, tpr, color='orange', label='ROC')
    plt.plot([0, 1], [0, 1], color='darkblue', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend()
    plt.show()

y_true = np.array([0,0, 1, 1,1])
y_scores = np.array([0.0,0.09, .05, .75,1])

fpr, tpr, thresholds = roc_curve(y_true, y_scores)
print(tpr)
print(fpr)
print(thresholds)
print(roc_auc_score(y_true, y_scores))
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]
print("Threshold value is:", optimal_threshold)
plot_roc_curve(fpr, tpr)

阈值为:0.75