ValueError:类的数量必须大于一;在多标签分类问题中获得 1 类

数据挖掘 Python 多标签分类
2022-03-10 16:33:30

我正在使用 Python 解决多标签分类问题。我有一个包含文本和大约 20k 个唯一标签的数据集。我将文本转换为词嵌入,现在我在 ChainClassifier 中使用它来预测每个文本可能具有的标签。

from skmultilearn.problem_transform import LabelPowerset, ClassifierChain, BinaryRelevance
from sklearn.svm import SVC

df = pd.read_csv('sample_fos.csv')
df = df.dropna()
df.fos = df.fos.str.split(',')

mlb = MultiLabelBinarizer()

data = df.join(pd.DataFrame(mlb.fit_transform(df.pop('fos')),
                          columns=mlb.classes_,
                          index=df.index))

labels = mlb.classes_

X_train, X_test, y_train, y_test = train_test_split(data[features], data[labels], test_size=0.20, shuffle=True)

clf = ClassifierChain(classifier=SVC(gamma="auto"))
clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)

我收到以下错误:

  File "<ipython-input-54-976064cc6433>", line 30, in <module>
    clf.fit(X_train,y_train)

  File "/lib/python3.7/site-packages/skmultilearn/problem_transform/cc.py", line 155, in fit
    X_extended), self._ensure_output_format(y_subset))

  File "/lib/python3.7/site-packages/sklearn/svm/base.py", line 147, in fit
    y = self._validate_targets(y)

  File "/lib/python3.7/site-packages/sklearn/svm/base.py", line 521, in _validate_targets
    " class" % len(cls))

ValueError: The number of classes has to be greater than one; got 1 class

据我所知,这意味着其中一个标签列只有一个类(0或者1

但是,运行以下代码段,我发现我的所有列都有 2 个唯一值

b = []
for i in labels:
    b.append(len(data[i].unique()))
print(min(b),max(b))
>> 2, 2
1个回答

你能看到有多少课程在训练和测试中吗?

y_train.unique()
y_test.unique()

有时,您可能会得到一个只有一个标签的拆分,如果您的数据很小或严重不平衡,则有更多机会。解决它的一种方法是对目标进行分层。

train_test_split(data[features], data[labels], test_size=0.20, shuffle=True,stratify=data[labels])