如何检测 OneVsRestClassifier 的匹配精度

数据挖掘 分类 nlp scikit-学习 多类分类
2022-03-08 12:38:48

我已经将我的文本分类改进为主题模块,从简单的 word2vec 到管道 tfidf 和 OneVsRestClassifier(使用 sklearn)。它确实改善了分类,但使用 word2vec 我能够计算每个主题的匹配百分比,并且使用 OneVsRestClassifier 我得到与特定主题的匹配或不匹配。有没有办法用 OneVsRestClassifier 查看分类的百分比是多少?

PS我不是在谈论评估训练的表现,而是实际的实时匹配百分比。

1个回答

是的当然。

假设您使用了 sklearn 的OneVsRestClassifier,因此您有一个决策函数,例如带有线性内核的支持向量分类器。用于将set_params更改为,默认为在 OneVsRestClassifier 分类器中使用它,然后使用内置函数,如probabilityTrueFalsepredict_proba

from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
mod = OneVsRestClassifier(SVC(kernel='linear').set_params(probability=True)).fit(samples,classes)
print mod.predict_proba(np.array([your_sample_vector]).reshape(1,-1))

编辑:

您可以使用旧LinearSVC的 withdecision_function来查找与超平面的距离并将它们转换为概率,例如

mod = OneVsRestClassifier(LinearSVC()).fit(sample,clas)
proba = (1./(1.+np.exp(-mod.decision_function(np.array(your_test_array).reshape(1,-1)))))
proba /= proba.sum(axis=1).reshape((proba.shape[0], -1))\
print proba

我猜现在你不需要调整参数了。:)