在 GridSearchCV 中使用 f1 分数

机器算法验证 机器学习 scikit-学习
2022-03-27 19:05:03

我想使用 sklearn.model_selection.GridSearchCV 使用 F1 分数指标进行交叉验证。我的问题是一个多类分类问题。我想在 F1 分数中使用选项 average='micro'。

另见: https ://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score

我已经检查了以下帖子: https ://stackoverflow.com/questions/34221712/grid-search-with-f1-as-scoring-function-several-pages-of-error-message

如果我确切地尝试这篇文章中的内容,但我总是会收到此错误:

TypeError: f1_score() missing 2 required positional arguments: 'y_true' and 'y_pred'

我的问题基本上只是关于语法:如何在 GridSearchCV 中使用带有 average='micro' 的 f1_score?

我将非常感谢任何答案。

编辑:这是一个可执行的例子:

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.metrics import f1_score, make_scorer
from sklearn.preprocessing import RobustScaler
from sklearn.svm import SVC


data = load_breast_cancer()
X = data['data']
y = data['target']

#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

scaler = RobustScaler()
estimator = SVC()

pipeline_steps = [('scaler', scaler), ('estimator', estimator)]
pipeline_steps = Pipeline(steps=pipeline_steps)

grid = [{'estimator__C': [0.1, 0.5, 1.5, 2, 2.5, 3]}]

gridsearch = GridSearchCV(estimator=pipeline_steps,
                       param_grid=grid,
                       n_jobs=-1,
                       cv=5,
                       scoring=make_scorer(f1_score(average='micro')))

# now perform full fit on whole pipeline
gridsearch.fit(X, y)
print("Best parameters from gridsearch: {}".format(gridsearch.best_params_))
print("CV score=%0.3f" % gridsearch.best_score_)
cv_results = gridsearch.cv_results_
#print(cv_results)
3个回答

好的,我发现了:

如果您根据https://scikit-learn.org/stable/modules/model_evaluation.html使用 score='f1_micro' ,您将得到我想要的。

您可以按照此处提供的示例进行操作,只需将 average='micro' 传递给 make_scorer。https://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html

gridsearch = GridSearchCV(estimator=pipeline_steps, param_grid=grid, n_jobs=-1, cv=5, score='f1_micro')

您可以检查以下链接并使用分类列中的所有评分。

链接:https ://scikit-learn.org/stable/modules/model_evaluation.html