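Is it possible to use class_weights with one-hot encoding?
I have already tried sparse_categorical_crossentropy, and for some reason it performs much worse than my classic categorical_crossentropy model with one-hot encoded labels.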
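For reference, sparse_categorical_crossentropy on an integer label and categorical_crossentropy on the matching one-hot vector compute the same per-sample quantity, -log p[true class]. The minimal plain-Python sketch below checks this on one example. The helper names `categorical_ce` and `sparse_categorical_ce` are hypothetical and are not Keras functions.

```python
import math

def categorical_ce(one_hot, probs):
    # -sum_k y_k * log(p_k); only the true-class term is non-zero
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y)

def sparse_categorical_ce(label, probs):
    # Same loss, indexed directly by the integer class id
    return -math.log(probs[label])

probs = [0.1, 0.7, 0.2]   # predicted class probabilities for one sample
label = 1                 # integer label
one_hot = [1.0 if k == label else 0.0 for k in range(len(probs))]

print(categorical_ce(one_hot, probs))
print(sparse_categorical_ce(label, probs))
```

Because the two formulas agree, a large gap between the two models is more likely caused by something else in the pipeline (for example, the label format the generator yields) than by the loss itself.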
This is how I compute class_weights for the sparse (integer) labels:
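import numpy as np
from sklearn.utils import class_weight
# labels: 1-D array of integer class ids; recent scikit-learn requires keyword arguments here
classes = np.unique(labels)
class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=classes, y=labels)
class_weights_dict = {classes[i]: w for i, w in enumerate(class_weights)}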
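scikit-learn's 'balanced' mode assigns each class the weight n_samples / (n_classes * count[c]). With one-hot labels, taking the argmax of each row recovers the integer class ids, so the same dictionary can be built. The plain-Python sketch below shows this. `balanced_class_weights` is a hypothetical helper, and, like `np.unique(labels)` above, it only counts classes that actually occur in the labels.

```python
from collections import Counter

def balanced_class_weights(one_hot_labels):
    # Recover integer class ids from one-hot rows (argmax of each row)
    ids = [max(range(len(row)), key=row.__getitem__) for row in one_hot_labels]
    counts = Counter(ids)
    n_samples = len(ids)
    n_classes = len(counts)
    # Same formula as compute_class_weight(class_weight='balanced', ...)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

labels_one_hot = [
    [1, 0, 0],
    [1, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 1],
]
print(balanced_class_weights(labels_one_hot))
```

Here class 0 appears 3 times, class 1 once and class 2 twice, so the rarest class gets the largest weight (6 / (3 * 1) = 2.0).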
And I train the model like this:
full_model.compile(loss='sparse_categorical_crossentropy',
                   optimizer='rmsprop',
                   metrics=['accuracy'])
# fit_generator is deprecated in TF 2.x; fit() accepts generators directly
full_model.fit(line_generator(data_train, labels_train),
               validation_data=line_generator(data_test, labels_test),
               validation_steps=1,
               # steps_per_epoch must be an integer, based on the training split
               steps_per_epoch=len(data_train) // GENERATOR_BATCH_SIZE,
               class_weight=class_weights_dict,
               epochs=1)
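With one-hot targets, a class_weight dictionary can be applied by turning it into one weight per sample (the weight of that row's argmax class) and scaling each sample's categorical cross-entropy by it. As far as I know, Keras resolves class_weight for 2-D one-hot targets in a similar argmax-based way, but that is worth checking for your version. The plain-Python sketch below only illustrates the arithmetic. `sample_weights_from_one_hot` and `weighted_categorical_ce` are hypothetical helpers, and dividing the weighted sum by the number of samples is an assumption about the reduction.

```python
import math

def sample_weights_from_one_hot(one_hot_labels, class_weights):
    # Weight of each sample = class weight of its argmax class
    return [class_weights[max(range(len(row)), key=row.__getitem__)]
            for row in one_hot_labels]

def weighted_categorical_ce(one_hot_labels, probs, class_weights):
    weights = sample_weights_from_one_hot(one_hot_labels, class_weights)
    # Unweighted per-sample categorical cross-entropy
    losses = [
        -sum(y * math.log(p) for y, p in zip(row, pred) if y)
        for row, pred in zip(one_hot_labels, probs)
    ]
    # Weighted sum divided by the number of samples (assumed reduction)
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)

y_true = [[1, 0], [0, 1]]
y_pred = [[0.8, 0.2], [0.4, 0.6]]
class_weights = {0: 1.0, 1: 3.0}
print(weighted_categorical_ce(y_true, y_pred, class_weights))
```

The second sample belongs to the up-weighted class 1, so its loss counts three times as much as it would without weighting.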