我们是否还需要为 VAE 的解码器建模概率分布?

人工智能 计算机视觉 生成模型 自动编码器 潜变量
2021-10-30 01:41:45

我正在努力理解 VAE,主要是通过斯坦福 cs231n 的视频讲座,特别是关于这个主题的第 13 讲讲课,我认为我对理论有很好的把握。

但是,当查看实际的实现代码时,例如 来自VAEs博客的这段代码,我看到了一些我不太理解的差异。

请从课堂上看一下这个 VAE架构可视化,特别是解码器部分。从这里呈现的方式,我了解到解码器网络输出数据分布的均值和协方差。为了获得实际输出(即图像),我们需要从通过均值和协方差参数化的分布中进行采样——解码器的输出。

现在,如果您查看 Keras 博客 VAE 实现中的代码,您会发现没有这样的东西。解码器从潜在空间中获取样本,并将其输入(采样的 z)直接映射到输出(例如图像),而不是映射到要从中采样输出的分布参数。

我是否遗漏了什么,或者这个实现与讲座中介绍的不对应?一段时间以来,我一直试图理解它,但似乎仍然无法理解其中的差异。

2个回答

感谢@nbro 指出这一点。

幻灯片中的图形架构使用高斯损失,当与最大似然估计结合使用时,会给出平方误差损失(不消除任何易处理性问题)。我们使用编码器高斯技巧的主要原因是强制潜在变量z正常以便我们可以申请KL Divergence优化一个原本难以处理的积分。您可以在此视频中获得更好的直觉和推理

图形架构基本上采用高斯损失,使最终损失有效地变为平方误差损失。此外,您的博客链接中使用的损失术语与原始论文中使用的损失术语完全相同,但博客使用的是 CE 损失(这是用于分类的更常见的损失)。我不确定他们是如何使用 CE 损失的,因为它仅适用于01值和 AFAIK MNIST 数据集具有灰度图像。

我不确定他们如何在解码器结构中实现高斯损失的随机性,但在最简单的情况下,他们只采用 MSE

看看这个关于VAE 的博客(他们在哪里取平均值Σ他们将其缩写为平均值,我没有检查他们的实现细节以了解它们的确切含义)以及关于 VAE 4实现的数据科学的答案(两者都给出了更一般的损失形式)。此外,对于精确的数学,请查看原始论文的附录 C。

cs231n 类的 VAE 架构只是 Keras 提供的代码的更通用版本,其中协方差矩阵为0. 您可以从重新参数化技巧中看到这一点

x=μ+Σϵ=μif Σ=0