Keras Attention Guided CNN 问题

数据挖掘 喀拉斯 张量流 美国有线电视新闻网 注意机制
2022-03-08 01:36:26

我正在研究用于 XRay 图像分类的 CNN,但我似乎无法正确训练它。我正在尝试在 Keras 中实现以下论文:https ://arxiv.org/pdf/1801.09927.pdf

简而言之,论文描述了3个网络架构。 注意力引导CNN

Global Branch 是在 imagenet 上预训练的 ResNet 或 DenseNet(我使用 DenseNet),然后在 CheXNet 数据集上进行训练。训练完成后,来自最后一个卷积层的激活用于创建热图,然后用于裁剪原始数据集中的图像并创建新数据集。

本地分支是没有预训练的 ResNet 或 DenseNet(我使用了 DenseNet)。它在裁剪的数据集上进行训练。

最后,融合分支将全局和局部分支的 2 个全局池化层作为输入。

我已经成功地训练了全球分支(AUC:0.77),生成作物并训练了本地分支(AUC:0.67)。但是当我尝试训练融合分支时,val_loss 并没有减少: val_loss 表

这是我的代码:

def load_model_from_json(models_folder):
    print(" --- Reading model from ", models_folder)
    with open(models_folder + 'model.json', 'r') as f:
        json = f.read()
    print("Read: ", models_folder + 'model.json')
    model = model_from_json(json)

    model.load_weights(models_folder + "model_weights.h5")
    print("Read: ", models_folder + "model_weights.h5")

    return model

global_branch_model = load_model_from_json(self.global_branch_path)

local_branch_model = load_model_from_json(self.local_branch_path)

for l in global_branch_model.layers:
    l.trainable = False
    l.name = 'global_'+l.name
for l in local_branch_model.layers:
    l.name = 'local_'+l.name
    l.trainable = False

global_pooling = global_branch_model.get_layer('global_global_average_pooling2d_1')

local_pooling = local_branch_model.get_layer('local_global_average_pooling2d_1')

merged = concatenate([global_pooling.output, local_pooling.output])
dense = Dense(512, activation='relu')(merged)
dropout = Dropout(self.hyperparameters.dropout)(dense)
out = Dense(1, activation='sigmoid')(dropout)

fusion_model = Model(inputs=[global_branch_model.input, local_branch_model.input], outputs=out)
loss_function = unweighted_binary_crossentropy

optimizer = AdamW(lr=5e-5)

fusion_model.compile(optimizer=optimizer, loss=loss_function)

fusion_model.fit_generator(
            generator=FusionDataGenSequence(self.labels, self.partition['train'],
                                      current_state='train',
                                      batch_size=self.batch_size,
                                      hyperparameters=self.hyperparameters,
                                      num_classes=self.number_of_classes),
            epochs=self.hyperparameters.epochs,
            verbose=1,
            callbacks=callbacks,
            workers=self.num_workers,
            # max_queue_size=32,
            # shuffle=False,
            validation_data=FusionDataGenSequence(self.labels,
                                            self.partition['valid'],
                                            current_state='validation',
                                            batch_size=self.batch_size,
                                            hyperparameters=self.hyperparameters,
                                            num_classes=self.number_of_classes)
            # validation_steps=1
        )

你能帮我弄清楚我做错了什么吗?谢谢!

2个回答

也许我遗漏了一些东西,但是全球分支机构和本地分支机构之间的联系在哪里?

在本地分支的训练过程中,您使用了一个由蒙版图像构建的数据集。但是现在该分支应该将来自全局分支的 7x7x2048 卷积层的输出作为输入。

也许您的训练集是由 2 个分支的耦合输入构建的?我认为您应该添加“FusionDataGenSequence”的实现。

您在训练期间的准确性如何?你尝试过不同的学习率吗?

我发现了问题。在 FusionDataGenSequence 函数中,我不小心对用于训练分支的图像使用了不同的归一化函数。