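I am working on a CNN for X-ray image classification, but I can't seem to get it to train properly. I am trying to implement the following paper in Keras: https://arxiv.org/pdf/1801.09927.pdf
The global branch is a ResNet or DenseNet pretrained on ImageNet (I use DenseNet), which is then trained on the CheXNet dataset. After training, the activations from the last convolutional layer are used to create a heatmap, which is then used to crop the images in the original dataset and create a new dataset.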
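To make the cropping step concrete, here is a minimal, dependency-free sketch of one way a heatmap can be turned into a crop box. It is not the code used for the results in this post. It assumes the heatmap is the maximum absolute activation across channels and that the crop is the bounding box of all pixels above a fraction of the peak; the function names, the 0.7 threshold, and the toy 4x4 feature map are hypothetical, and the box covers every above-threshold pixel rather than only the largest connected region.

```python
def attention_heatmap(features):
    """Collapse an H x W x C activation map (nested lists) into an H x W heatmap.

    Assumption: each heatmap value is the max absolute activation over channels.
    """
    return [[max(abs(v) for v in pixel) for pixel in row] for row in features]


def crop_box(heatmap, threshold=0.7):
    """Return (top, left, bottom, right) around pixels >= threshold * peak.

    Simplification: no connected-component step, so the box spans every
    above-threshold pixel. The threshold value is illustrative.
    """
    peak = max(max(row) for row in heatmap)
    cutoff = threshold * peak
    rows = [y for y, row in enumerate(heatmap) if any(v >= cutoff for v in row)]
    cols = [x for x in range(len(heatmap[0]))
            if any(row[x] >= cutoff for row in heatmap)]
    return rows[0], cols[0], rows[-1] + 1, cols[-1] + 1


def scale_box(box, heatmap_hw, image_hw):
    """Map a box from heatmap coordinates to input-image pixel coordinates."""
    top, left, bottom, right = box
    sy = image_hw[0] / heatmap_hw[0]
    sx = image_hw[1] / heatmap_hw[1]
    return int(top * sy), int(left * sx), int(bottom * sy), int(right * sx)


# Toy 4 x 4 x 2 feature map with two strongly activated positions.
features = [[[0.1, -0.1] for _ in range(4)] for _ in range(4)]
features[1][2] = [0.2, -2.0]
features[2][3] = [1.8, 0.3]

heatmap = attention_heatmap(features)
box = crop_box(heatmap, threshold=0.7)
image_box = scale_box(box, heatmap_hw=(4, 4), image_hw=(224, 224))
print(box, image_box)
```

The resulting `image_box` would then be used to slice the original image before resizing it to the local branch's input size.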
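The local branch is a ResNet or DenseNet without pretraining (I used DenseNet). It is trained on the cropped dataset.
Finally, the fusion branch takes the two global pooling layers of the global and local branches as input.
I have successfully trained the global branch (AUC: 0.77), generated the crops, and trained the local branch (AUC: 0.67). But when I try to train the fusion branch, the val_loss does not decrease.
Here is my code: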
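from keras.layers import Dense, Dropout, concatenate
from keras.models import Model, model_from_json


def load_model_from_json(models_folder):
    """Rebuild a model from model.json and load its weights from model_weights.h5."""
    print(" --- Reading model from ", models_folder)
    with open(models_folder + 'model.json', 'r') as f:
        model_json = f.read()
    print("Read: ", models_folder + 'model.json')
    model = model_from_json(model_json)
    model.load_weights(models_folder + "model_weights.h5")
    print("Read: ", models_folder + "model_weights.h5")
    return model
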
global_branch_model = load_model_from_json(self.global_branch_path)
local_branch_model = load_model_from_json(self.local_branch_path)

# Freeze both branches and prefix their layer names so the two models
# do not clash when combined into one graph.
for layer in global_branch_model.layers:
    layer.trainable = False
    layer.name = 'global_' + layer.name
for layer in local_branch_model.layers:
    layer.name = 'local_' + layer.name
    layer.trainable = False

# Fusion head: concatenate the pooled features of both branches.
global_pooling = global_branch_model.get_layer('global_global_average_pooling2d_1')
local_pooling = local_branch_model.get_layer('local_global_average_pooling2d_1')
merged = concatenate([global_pooling.output, local_pooling.output])
dense = Dense(512, activation='relu')(merged)
dropout = Dropout(self.hyperparameters.dropout)(dense)
out = Dense(1, activation='sigmoid')(dropout)
fusion_model = Model(inputs=[global_branch_model.input, local_branch_model.input], outputs=out)

loss_function = unweighted_binary_crossentropy
optimizer = AdamW(lr=5e-5)
fusion_model.compile(optimizer=optimizer, loss=loss_function)
fusion_model.fit_generator(
    generator=FusionDataGenSequence(self.labels, self.partition['train'],
                                    current_state='train',
                                    batch_size=self.batch_size,
                                    hyperparameters=self.hyperparameters,
                                    num_classes=self.number_of_classes),
    epochs=self.hyperparameters.epochs,
    verbose=1,
    callbacks=callbacks,
    workers=self.num_workers,
    # max_queue_size=32,
    # shuffle=False,
    validation_data=FusionDataGenSequence(self.labels,
                                          self.partition['valid'],
                                          current_state='validation',
                                          batch_size=self.batch_size,
                                          hyperparameters=self.hyperparameters,
                                          num_classes=self.number_of_classes)
    # validation_steps=1
)
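The loss function `unweighted_binary_crossentropy` is not shown in the code above. Assuming it is plain binary cross-entropy averaged over the batch (an assumption the post does not confirm), the sketch below shows how the quantity behind val_loss would be computed. The function name `binary_crossentropy` and the sample labels and predictions are hypothetical.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over labels and predicted probabilities.

    Assumption: this mirrors an unweighted loss; the real helper may differ.
    """
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)


labels = [1, 0, 1, 0]
confident = binary_crossentropy(labels, [0.9, 0.1, 0.9, 0.1])
uninformed = binary_crossentropy(labels, [0.5, 0.5, 0.5, 0.5])
print(round(confident, 4))   # 0.1054
print(round(uninformed, 4))  # 0.6931
```

Under this assumption, a model that outputs 0.5 for every sample scores ln 2 (about 0.693), so that value can serve as a rough baseline when reading a val_loss curve that does not move.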
Can you help me figure out what I'm doing wrong? Thanks!