是否有一种灵活的方法可以从混淆矩阵的每个单元格中获取原始数据索引?

数据挖掘 机器学习 Python scikit-学习 机器学习模型 混淆矩阵
2022-03-07 03:21:15

假设我有模型 A 和模型 B,我想通过检查混淆矩阵来比较它们的性能。它们都产生一个预测列表,pred_A并且pred_B对应于一组真实值列表ground_truths我可以使用类似的东西sklearn.metrics.confusion_matrix来为每个模型生成混淆矩阵:

cm_A = confusion_matrix(ground_truths, pred_A) # e.g. [[5,1], [2,3]]
cm_B = confusion_matrix(ground_truths, pred_B) # e.g. [[6,0], [3,2]]

我想弄清楚的是如何从混淆矩阵的每个单元格中获取特定的索引,对应于地面实况类。这将让我调查哪些特定数据点是真阳性、假阳性、假阴性和真阴性,并让我根据它们的输入特征搜索模式。在上面的示例中,两个假阴性是否与和中的假阴性相同还是在两个模型预测之间所有的真阳性都与所有的假阴性都发生了转换?如果我知道他们在列表中的索引,我可以很容易地回答这个问题。cm_Acm_Bground_truths

我可以想到愚蠢/低效的方法来做到这一点,但我很想找到一个可以扩展到任意数量的标签的解决方案。

1个回答

您绝对可以获取此信息,但不能从混淆矩阵中获取。您想要比较预测向量本身,而不是混淆矩阵,因为正如您正确识别的那样,混淆矩阵将所有假正/负转储到相同的桶中(这样我们也可以看到每个桶的填充量有用的信息)。

您可以对预测向量进行各种比较以获得所需的信息(假设您使用的是 numpy 数组):

# Which examples did classifier A get wrong?
A_mistakes = np.invert(pred_A == ground_truths)

# Which examples did classifier B get wrong?
B_mistakes = np.invert(pred_B == ground_truths)

# Where did the classifiers make the same mistakes?
common_mistakes = A_mistakes == B_mistakes

# Where did the classifiers make a wrong prediction that the other didn't?
unique_mistakes = np.logical_xor(A_mistakes, B_mistakes)

然后,您可以将独特的错误与 A_mistakes 或 B_mistakes 进行比较,以将它们追溯到分类器(或您感兴趣的任何其他内容)。您也可以只np.sum使用二进制向量来计算唯一错误、常见错误等的数量。