我没有使用ImageDataGenerator,因为我使用的是 hdf5 文件。我使用 DataGenerator 类 (1) 将数据提供给model.fit_generator. 鉴于我有一批随机数据,如何获得混淆矩阵?我知道如何model.predict使用 sklearn 来获得它。
(1) https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly#
我没有使用ImageDataGenerator,因为我使用的是 hdf5 文件。我使用 DataGenerator 类 (1) 将数据提供给model.fit_generator. 鉴于我有一批随机数据,如何获得混淆矩阵?我知道如何model.predict使用 sklearn 来获得它。
(1) https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly#
loaded_model.predict_generator(generator=test_generator)
会给我们一组概率
y_true = test_generator.classes
会给我们真正的标签
因为这是一个二元分类问题,所以你必须找到预测的标签。为此,您可以使用:
y_pred = probabilities > 0.5
然后我们在测试数据集上有真实标签和预测标签。因此,混淆矩阵由下式给出:
font = {
'family': 'Times New Roman',
'size': 12
}
matplotlib.rc('font', **font)
mat = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(conf_mat=mat, figsize=(8, 8), show_normed=False)