具有单位高斯的 KL 损失

机器算法验证 推理 kullback-leibler 自动编码器 变分贝叶斯
2022-03-08 08:01:20

我一直在实现 VAE,并且我注意到简化的单变量高斯 KL 散度的两种不同的在线实现。这里的原始分歧 如果我们假设我们的先验是单位高斯,即,这将简化为 这就是我的困惑所在。虽然我发现了一些具有上述实现的晦涩的 github 存储库,但我发现更常用的是:

KLloss=log(σ2σ1)+σ12+(μ1μ2)22σ2212
μ2=0σ2=1
KLloss=log(σ1)+σ12+μ12212
KLloss=12(2log(σ1)σ12μ12+1)

=12(log(σ1)σ1μ12+1)
例如在官方的Keras 自动编码器教程中。那么我的问题是,我在这两者之间缺少什么?主要区别是在对数项上删除因子 2 而不是平方方差。从分析上讲,我已经成功地使用了后者,因为它的价值。提前感谢您的帮助!

2个回答

请注意,通过在最后一个等式中将 \sigma_1 替换\,您可以恢复前一个(即)。让我想到,在第一种情况下,编码器用于预测方差,而在第二种情况下,它用于预测标准偏差。σ1σ12log(σ1)σ12log(σ1)σ12

两种表述是等价的,目标不变。

我相信答案更简单。在 VAE 中,人们通常使用多元正态分布,它具有协方差矩阵而不是方差这在一段代码中看起来令人困惑,但具有所需的形式。Σσ2

在这里,您可以找到多元正态分布的 KL 散度推导:Deriving the KL散度损失的 VAE