我正在尝试在 CIFAR-10 数据集上运行变分自动编码器,为此我在 TensorFlow 中组合了一个简单的网络,编码器和解码器各有 4 层,编码向量大小为 256。用于计算潜在损失,我强制网络的编码器部分输出对数方差而不是标准偏差,因此潜在损失函数如下所示:
latent_loss = -0.5 * tf.reduce_sum(1 + log_var_vector - tf.square(mean_vector) - tf.exp(log_var_vector), axis=1)
我发现这个公式比直接使用 KL 散度公式中的对数更稳定,因为后者通常会导致无限的损失值。我在解码器的最后一层应用了 sigmoid 激活函数,生成损失是使用均方误差计算的。组合损失是潜在损失和生成损失的简单总和。我使用 Adam Optimizer 以 0.001 的学习率为 40 个批次训练网络。
问题是我的网络没有训练。潜在损失立即降至零,而生成损失并没有下降。但是,当我仅针对生成损失进行优化时,损失确实会按预期减少。在此设置下,潜在损失的值迅速跃升至非常大的值(10e4 - 10e6 的顺序)。
我有一种预感,罪魁祸首是两种损失的幅度之间的极端不匹配。KL-divergence 是无界的,而均方误差始终保持 <1,因此当对两者进行优化时,生成损失基本上变得无关紧要。
欢迎任何解决问题的建议。