SGDClassifier - 为什么我需要使用 argmax 而不是 argmin 来找到满足给定精度的最低阈值?

数据挖掘 Python 分类 scikit-学习 麻木的
2022-02-15 12:29:16

我是一位经验丰富的程序员,但对 Python 和数据科学不熟悉。我正在关注 Aurelien Gerone 的书,但我不明白一件事。

我创建 SGDClassifier 并计算它的precision_recall_curve()。然后我试图找到满足精度等于 90% 的最低阈值:

precisions, recalls, thresholds = precision_recall_curve(y_train, y_scores)
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]

如果我需要找到最小阈值,为什么我要搜索 arg max ?如果我尝试使用 argmin,我会得到错误的值,精度等于 0.1。

据我了解:

  • 精度 >= 0.90 创建一个精度分数仅高于或等于 0.90 的数组,
  • argmax 返回一个索引,在该索引处我找到给定数组中的最大值(所以这应该尽可能远离 90%,但事实并非如此!),
  • 然后我选择一个返回索引的阈值。

我错过了什么?

1个回答

好的,我自己解决了这个问题。

precisions >= 0.90 不会创建精度分数仅高于 90% 的数组,而是将此数组转换为布尔数组,其中 90% 以下的精度将变为 False,其他为 True。

argmax,如果有多个相同的最大值(此处为最大值)返回此事件的第一个索引

我有时讨厌这本书,为什么他不使用“array.first_equals(True)”之类的方法?