如何确定最佳阈值以达到最高精度

机器算法验证 优化 临界点
2022-04-04 19:27:19

我有一个分类器在平衡数据集上输出的概率列表。我想最大化的指标是准确性()。给定概率及其真实标签,有没有办法计算最佳阈值(无需迭代许多阈值来选择最佳阈值)。TP+TNP+N

3个回答

我怀疑答案是否定的,即没有这种方法。

这是一个插图,我们根据真实标签绘制预测概率:

准确性

由于精度公式中的分母不变,因此您要做的是将水平红线向上或向下移动(高度是您感兴趣的阈值)以最大化 "线上方的“正”点加上线下方的“负”点的数量这条最佳线的位置完全取决于两个点云的形状,即每个真实标签的预测概率的条件分布。P+N

您最好的选择可能是二分搜索

也就是说,我建议你看看

同意@StephanKolassa,我将从算法的角度来看。如果您有数据样本,则需要根据产生的概率对样本进行排序,即然后,您真正的类标签将按 的顺序排列 然后,我们将放置一个分隔符在这个数组的某个位置;这将代表您的阈值。最多有位置可以放置它。即使你计算每个位置的准确度,也不会比排序复杂度差。在得到最大准确率后,阈值可以直接选为相邻样本的平均值。O(nlogn)n

0 0 1 0 0 1 ... 1 1 0 1
|n+1

我在 python 中实现了 Stephan Kolassa 提出的解决方案:

def opt_threshold_acc(y_true, y_pred):
    A = list(zip(y_true, y_pred))
    A = sorted(A, key=lambda x: x[1])
    total = len(A)
    tp = len([1 for x in A if x[0]==1])
    tn = 0
    th_acc = []
    for x in A:
        th = x[1]
        if x[0] == 1:
            tp -= 1
        else:
            tn += 1
        acc = (tp + tn) / total
        th_acc.append((th, acc))
    return max(th_acc, key=lambda x: x[1])