如何在 gridsearchSV 中计算 AUC(多类问题)

数据挖掘 scikit-学习 多类分类 网格搜索 计分 奥克
2022-03-12 12:29:08

我正在研究一个多类分类问题,比较 SVM 和随机森林分类器的结果。我想使用 gridsearchCV 进行超参数调优,发现 AUC 是此类问题最常用的指标。

我知道如何使用其他评分指标,如准确性等,但默认的 ROC_AUC 仅适用于二进制类。有没有一种方法可以在 gridsearchCV 中使用 AUC 来解决多类问题?

2个回答

指标独立于 ML 算法,因此您使用哪种算法并不重要。

要计算多类 AUC ,您可以在R中使用 lib pRoc或使用此链接的代码(在 Python 中)。

资料来源:

sklearnroc_auc_score实际上确实使用其averagemulticlass参数处理多类和多标签问题。默认average='macro'值很好,但您应该考虑替代方案。但是multiclass='raise'需要覆盖默认值。要在 a 中使用它GridSearchCV,您可以对函数进行 curry,例如

import functools

multiclass_roc_auc = functools.partial(roc_auc_score, multiclass='ovr')
search = GridSearchCV(estimator=...,
                      eval=multiclass_roc_auc,
                      ...)