为使用 Keras 中的功能 API 开发的模型预测类的更好方法是什么

数据挖掘 喀拉斯 多标签分类 预测 功能API
2022-02-20 22:07:27

我们可以使用 Keras 中的顺序分类模型使用 predict_classes() 函数来预测新数据实例的类别。预测使用功能 API开发的模型的类的方法是什么?

例如,我有一个模型(基于函数式 API),在最后一层使用 sigmoid 激活来获得多标签分类中的概率。当我应用 model.predict() 时,我得到了一系列概率,即使损失是 binary_crossentropy。

我知道我可以手动进行此分类,例如以下方法。

test_predict_proba = model.predict(x_test, batch_size=batch_size)
class_predict = (test_predicted_proba > 0.5).astype(int)

我想知道是否有任何标准程序可以做到这一点?

1个回答

这是多标签问题还是纯分类问题?如果只是分类,则将最后一层的激活函数更改为 softmax。当您进行预测时,选择概率最高的输出作为类。或者使用 model.evaluate(.....)。与 model.predict 文档在这里