XGBoost 用于多标签图像分类

数据挖掘 xgboost 多标签分类
2022-03-08 16:41:41

我正在尝试将 xgboost 分类器用于多标签和多类图像分类任务。我有一个图像列表,每个图像最多可以有 5 个不同的标签。在我使用分类器之前,我还想应用图像增强。

import keras
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from xgboost.sklearn import XGBClassifier

train_idx, val_idx = train_test_split(mask_df.index,  test_size=0.2,random_state=28)

train_datagen=ImageDataGenerator(zoom_range=0.1,
                              fill_mode='constant',
                              rotation_range=10,
                              height_shift_range=0.1,
                              width_shift_range=0.1,
                              horizontal_flip=True,
                              vertical_flip=True,
                              rescale=1/255.)

train_generator=train_datagen.flow_from_dataframe(
                dataframe=mask_df.loc[train_idx],
                directory="home/DATA/train_images/",
                x_col="ImageId",
                y_col=columns,
                color_mode='grayscale',
                batch_size=32,
                seed=32,
                shuffle=True,
                class_mode="other",
                target_size=(100,100)) 

model = XGBClassifier()
history=model.fit_generator(generator=train_generator,
                steps_per_epoch=100,
                validation_data=validation_generator,
                validation_steps=100,
                epochs=5)

最后一个命令给了我一个错误:

AttributeError                            Traceback (most recent call last)
<ipython-input-8-8c4c0504d559> in <module>
----> 1 history=model.fit_generator(generator=train_generator,
      2                     steps_per_epoch=100,
      3                     validation_data=validation_generator,
      4                     validation_steps=100,
      5                     epochs=5

AttributeError: 'XGBClassifier' object has no attribute 'fit_generator'

由于我无法使用 fit_generator,是否有人对如何进行有任何建议?

1个回答

我希望从配合中移除发电机会起作用。

history=model.fit(generator=train_generator,
                steps_per_epoch=100,
                validation_data=validation_generator,
                validation_steps=100,
                epochs=5)