编辑:事实证明,即使是模型的初始创建者也无法成功对其进行微调。这很可能是一个实现问题,或者可能与Keras 批量标准化层工作的非直观方式有关。
我正在尝试在源自非增强Pascal VOC 2012 基准数据集(1449 个训练示例)的自定义数据集上微调谷歌 DeepLab v3+ 模型的 Keras实现,以解决我的研究问题。
我想我首先尝试自己在原始 Pascal VOC 数据集上重新训练它,并尝试获得接近论文的结果。回购的作者显然成功地做到了这一点,因此 Keras 模型不正确的可能性很小。我成功加载了在 ImageNet 上预训练的模型(来自 Google 的官方模型 zoo),特征图清楚地表明该模型能够区分它所输入的图片中的所有对象(见下图)。
我冻结了与初始主干相对应的 356 个第一层(在我的情况下为 Xception)。我在模型中添加了一个额外的最后softmax
一层,因为文章中的模型最初输出 logits。与此选择相关的是,数据集确实有一个背景类。我正在使用tf.keras.optimizers.Adadelta
优化器。
然而,经过数周的调整和探索,我仍然无法让模型学习或在分割方面做任何有价值的事情。
我尝试使用网络上随处可见的数十种不同的损失和准确度函数,主要是逐像素交叉熵和软骰子损失的变体,以及调整学习率 到 (作者使用 在原始论文中),每次我得到的都是一样的;损失值基本上在一个相当小的值附近振荡,然后我正在使用的早期停止回调在之后停止训练阶段 或者 时代。
如果我不停止该过程,这是一个典型指标的演变(在这种情况下,学习率设置为 , 批量大小 ):
我决定在每个 epoch 之后对同一张图像进行预测,这是第一个 epoch 之后的样子(右下角的“标签”图片只是argmax
特征图上的):
在 20 个 epoch 之后:
所有中间结果看起来都像,显然独立于超参数。
我什至尝试使用 repo 的作者所说的准确度和损失函数,他曾经完全按照我的意图去做,但我得到了相同的混乱度量曲线。
我对这可能来自哪里的想法严重不足。我很想得到关于我下一步应该在哪里寻找我可能犯的错误的提示。
数据流详细信息
我正在使用 TensorFlow 的数据集 API(基本上遵循这个非常好的指南)将数据集加载到内存中。所述数据集事先被打乱并分成 的碎片 示例,这是我可以在我的硬件上使用的最大批量大小。然后我选择一组打乱的碎片并通过重新缩放/填充/裁剪它们来预处理其中的示例 强度值介于 和 ,将它们转换为tf.float32
张量并生成数据集的每个类的二进制掩码。
- 输入张量是一批形状与值] 并编码为
float32
; - 相关的基本事实是一个形状张量值是, 或者 (后一个值用于“模糊”或填充区域;依次忽略部分图像)。
损失函数和精度函数
我首先忽略被忽略区域中的标签和预测(参见值多于):
def get_valid_labels_and_logits(y_true, y_pred):
valid_labels_mask = tf.not_equal(y_true, 255.0)
indices_to_keep = tf.where(valid_labels_mask)
valid_labels = tf.gather_nd(params=y_true, indices=indices_to_keep)
valid_logits = tf.gather_nd(params=y_pred, indices=indices_to_keep)
return valid_labels, valid_logits
我在一个小习惯上检查了三次图像,它按预期工作。
接下来,我计算本文定义的所有类的平均骰子损失:
def soft_dice_loss(y_true, y_pred):
y_true, y_pred = get_valid_labels_and_logits(y_true, y_pred)
# Next tensors are of shape (num_batches, num_classes)
interception_volume = tf.reduce_sum(tf.reduce_sum(y_true * y_pred, axis=1), axis=1)
y_true_sum_per_class = tf.reduce_sum(tf.reduce_sum(y_true, axis=1), axis=1)
y_pred_sum_per_class = tf.reduce_sum(tf.reduce_sum(y_pred, axis=1), axis=1)
return tf.reduce_mean(1.0 - 2.0 * interception_volume / (y_true_sum_per_class + y_pred_sum_per_class))
我尝试了这些的不同变体,包括本机交叉熵和二进制交叉熵,但它并没有太大改变行为。
默认的准确度函数似乎不起作用,所以我实现了一个自定义的平均 IoU 准确度函数,该函数在我手动给出的示例中运行良好。
为这堵文字墙道歉,但我想澄清一下情况。非常感谢您的帮助和建议!