对于变分自动编码器,重建损失应该计算为输入的总和还是平均值?

机器算法验证 损失函数 喀拉斯 自动编码器 kullback-leibler 变分贝叶斯
2022-03-25 21:27:44

我正在关注这个变分自动编码器教程:https ://keras.io/examples/generation/vae/ 。我在下面的代码中包含了损失计算部分。

我知道VAE的损失函数包括比较原始图像和重建的重建损失,以及KL损失。但是,我对重建损失以及它是在整个图像(平方差之和)还是每个像素(平方差的平均和)上有点困惑。我的理解是重建损失应该是每像素(MSE),但我遵循的示例代码将 MSE 乘以 28 x 28,即 MNIST 图像尺寸。那是对的吗?此外,我的假设是这会使重建损失项显着大于 KL 损失,我不确定我们是否想要这样。

我尝试通过 (28x28) 删除乘法,但这导致重建极差。无论输入如何,基本上所有的重建看起来都是一样的。我可以使用 lambda 参数来捕获 kl 散度和重建之间的权衡,或者它不正确,因为损失具有精确的推导(而不是仅仅添加正则化惩罚)。

reconstruction_loss = tf.reduce_mean(
    keras.losses.binary_crossentropy(data, reconstruction)
)
reconstruction_loss *= 28 * 28
kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
kl_loss = tf.reduce_mean(kl_loss)
kl_loss *= -0.5
total_loss = reconstruction_loss + kl_loss
2个回答

要直接找到答案,损失确实有一个精确的推导(但这并不意味着你不一定能改变它)。

重要的是要记住,变分自动编码器的核心是一种对我们假设正在生成数据的一些潜在变量进行变分推断的方法。在这个框架中,我们的目标是最小化潜在变量上的一些近似后验与真实后验之间的 KL 散度,我们也可以选择最大化证据下限 (ELBO),详情请参阅VAE 论文这为我们提供了 VAE 的目标:

L(θ,ϕ)=Eqϕ[logpθ(x|z)]Reconstruction LossDKL(qϕ(z)||p(z))KL Regulariser

现在重建损失是给定潜在变量的数据的预期对数似然。对于由多个像素组成的图像,总对数似然将是所有像素的对数似然之和(假设独立),而不是每个单独像素的平均对数似然,这就是它的原因例子中的情况。

是否可以添加额外参数的问题是一个有趣的问题。例如,DeepMind 引入了 -VAE,它确实做到了这一点,尽管目的略有不同——他们表明,这个额外的参数可以导致更解耦的潜在空间,从而允许更多可解释的变量。这种目标改变的原则性有待商榷,但它确实有效。话虽这么说,很容易通过简单地更改潜在变量的先验()以有原则的方式更改 KL 正则化项,原来的先验是一个非常无聊的标准正态分布,因此只需交换其他内容即可改变损失函数。尽管我自己没有检查过,您甚至可以指定一个新的先验(βp(z)p(z) ) 使得:

DKL(qϕ(z)||p(z))=λDKL(qϕ(z)||p(z)),

这将完全符合您的要求。

所以基本上答案是肯定的 - 如果它可以帮助您完成您想要的任务,请随意更改损失函数,只需了解您所做的与原始案例有何不同,这样您就不会提出任何您不应该的主张吨。

据我了解 VAE 是如何工作的,KL 损失可以被视为正则化器,而重建损失是驱动模型权重以产生正确输出的损失。

要回答您的具体问题:“我可以使用 lambda 参数来捕获 kl 散度和重建之间的权衡”;是的,你可以使用一个参数而不是一个多人游戏,比如但是,必须假设一个较小的值()。reconstructionloss+λ×kllossλ1/282

我发现这篇论文对于掌握 VAE 中的一般概念很有用。 https://arxiv.org/abs/1606.05908