我的分类器还需要多少数据?

数据挖掘 scikit-学习 多类分类
2022-02-28 14:21:45

我有一个大约 10 万篇新闻文章的数据集,我正在尝试根据每篇文章的标题和线索构建一个分类器。数据集没有预先标记,因此我手动标记要训练的文章子集。到目前为止,我有 9 个类别,每个类别有大约 600 篇带标签的文章。我正在使用doc2vec创建文档向量和SVC分类器sklearn来进行预测。

我的交叉验证分数(有 10 次拆分和数据洗牌)徘徊在 0.89 左右。

当我为训练集和测试集绘制学习曲线时,我将其解释为具有高方差的分类器,并且我需要收集更多数据。但是有什么方法可以估算出需要多少数据才能获得例如 0.95 的交叉验证分数?

这是我的学习曲线和平均分数:

在此处输入图像描述

训练分数:1.0、0.99919679、0.9991984、0.99791667、0.99615385、0.99550155、0.99487179。

测试分数:0.40769912、0.72283119、0.78529511、0.83461723、0.85527486、0.86151579、0.86471912。

编辑:我对 SVC 模型的 C 和 gamma 参数进行了网格搜索,并将 gamma 调整为 0.015,这使得两条线更加收敛。添加情节和新分数。

调整游戏后的学习曲线

新的训练分数:0.97758621、0.98703072、0.98177172、0.97540872、0.96717739、0.96479244、0.96241702。

新的测试分数:0.45800254、0.78917394、0.83526468、0.853326、0.86849859、0.87232875、0.87508134。

1个回答

我想你已经有足够的数据了。您的问题似乎是泛化问题之一,这通常是一个难以解决的问题。对于 SVM,根据您拥有的类型,您可以尝试修改参数和内核。看看这里。-- https://stats.stackexchange.com/questions/35276/svm-overfitting-curse-of-dimensionality?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa