变分自动编码器中的潜在损失淹没了生成损失

数据挖掘 深度学习 自动编码器 生成模型 小批量梯度下降
2021-09-27 01:44:51

我正在尝试在 CIFAR-10 数据集上运行变分自动编码器,为此我在 TensorFlow 中组合了一个简单的网络,编码器和解码器各有 4 层,编码向量大小为 256。用于计算潜在损失,我强制网络的编码器部分输出对数方差而不是标准偏差,因此潜在损失函数如下所示:

latent_loss = -0.5 * tf.reduce_sum(1 + log_var_vector - tf.square(mean_vector) - tf.exp(log_var_vector), axis=1)

我发现这个公式比直接使用 KL 散度公式中的对数更稳定,因为后者通常会导致无限的损失值。我在解码器的最后一层应用了 sigmoid 激活函数,生成损失是使用均方误差计算的。组合损失是潜在损失和生成损失的简单总和。我使用 Adam Optimizer 以 0.001 的学习率为 40 个批次训练网络。

问题是我的网络没有训练。潜在损失立即降至零,而生成损失并没有下降。但是,当我仅针对生成损失进行优化时,损失确实会按预期减少。在此设置下,潜在损失的值迅速跃升至非常大的值(10e4 - 10e6 的顺序)。

我有一种预感,罪魁祸首是两种损失的幅度之间的极端不匹配。KL-divergence 是无界的,而均方误差始终保持 <1,因此当对两者进行优化时,生成损失基本上变得无关紧要。

欢迎任何解决问题的建议。

2个回答

我不喜欢reduce_sumkl-loss 的版本,因为它取决于潜在向量的大小。我的建议是改用平均值。

此外,众所周知的事实是,训练具有 kl 损失的 VAE 是很困难的。您可能需要逐步增加 kl 损失在您的总损失中的贡献。添加一个w_kl控制贡献的权重:

Loss = recons_loss + w_kl * kl_loss

您从每个时期(或批次)开始w_kl=0并逐渐将其增加到 1。这是一个经典的技巧。你的学习率看起来不错,也许你可以尝试更高一点(4e-4)。

如果您不喜欢这些技巧,Wasserstein 自动编码器可能是您的朋友。

我觉得你的预感是对的。生成损失无法改善,因为网络为减少它所做的任何运动都会以潜在损失的形式带来巨大的损失。看起来你正在通过 sigmoid 压缩生成损失,也许尝试对潜在损失做同样的事情?