如何在 keras 中使用带有 one-hot 编码的 class_weights?

数据挖掘 机器学习 神经网络 喀拉斯
2022-02-09 09:30:41

是否可以使用class_weightsone-hot 编码?

我已经尝试过sparse_categorical_crossentropy了,由于某种原因,它比我的经典categorical_crossentropy热编码模型差得多。

这就是我用稀疏计算 class_weights 的方式:

unique_class_weights = np.unique(labels)
class_weights = class_weight.compute_class_weight('balanced', unique_class_weights, labels)
class_weights_dict = { unique_class_weights[i]: w for i,w in enumerate(class_weights) }

训练如:

full_model.compile(loss='sparse_categorical_crossentropy',
                   optimizer='rmsprop',
                   metrics=['accuracy'])


full_model.fit_generator(line_generator(data_train, labels_train),
                         validation_data=line_generator(data_test, labels_test),
                         validation_steps=1,
                         steps_per_epoch=len(data)/GENERATOR_BATCH_SIZE,
                         class_weight=class_weights_dict,
                         epochs=1)
1个回答

我不允许发表评论,但是您是否尝试过使用从 class_weight.compute_class_weight() 获得的 numpy 数组,而不是将其转换为字典?我一直跳过那部分,在你的情况下,我会说 class_weight=class_weights。对不起,如果我建议你已经排除的东西。祝你好运。