GAN - 我看到模式崩溃了吗?常见修复不起作用

数据挖掘 机器学习 神经网络 深度学习 火炬
2021-09-19 23:24:46

我有一个两部分的问题。


上下文 我正在学习 GAN 并从最简单的对抗性学习示例(1 参数节点)开始编写自己的内容,然后实现一个非常简单的 1 维模式(1010)学习 GAN .. 现在我正在尝试实现一个MNIST 学习 GAN,然后再进行更逼真的照片。

我有一些机器学习和数据挖掘的背景(很久以前的大师),并且对神经网络的工作原理有一定的了解。

您可以在此处阅读我在初始步骤中的进展:

我已经阅读了很多最新的博客、文章并观看了 youtube 教程,但仍然无法解决实现 MNIST 学习 GANS 的两个关键问题。


Q 1. 我看到模式崩溃了吗?

经过一些调整和迭代后,我有了一个 GAN,它确实学会了生成看起来可能来自 MNIST 数据集的图像。实际上它们还不是数字,但它们是可识别的笔划,当然不是随机噪音。

您可以在此处查看我的 pytorch 代码的最新迭代:github notebook

当输入随机噪声时,经过训练的 GAN 总是生成相同或极其相似的图像。喂它 1-shot noise (00001000...) 也会生成类似的图像。

GAN 是否应该只生成一张图像?我们看到的不同图像是否来自 GAN 的单独训练?我认为这个想法是一个经过训练的 GAN 可以从随机输入噪声中生成许多不同的图像。我误解了吗?


Q 2. 如何逃脱模式-折叠?

如果上面的答案是训练 GAN 应该输出许多不同但有效的图像,那么我有模式崩溃。

我已经广泛阅读并尝试了许多避免模式崩溃的方法,但都没有奏效:

  • 有/无批量标准化
  • 有/没有最大池化
  • 有/无辍学
  • 带/不带标签软化
  • 在输入和目标标签中添加/不添加噪声,包括随训练时间衰减的噪声
  • 生成器的各种宽度和深度,鉴别器则更少
  • 增加训练时间(但较差的计算能力只允许我在完整数据集上进行大约 6-10 个时期)

我观察到或注意到的是:

  • 首先测试鉴别器的宽度/深度/架构,以确保它在用于 GAN 之前具有学习多类 MNIST 的能力,以避免出现鉴别器实际上无法学习鉴别 MNIST 的情况
  • 绘制训练进行时的误差(D 误差,G 输入上的 D 误差)有助于显示 D 误差接近 1/2,而 G 输入上的 D 误差接近 1(更像是 0.8)。
  • 绘图错误还显示某种稳定性或崩溃,这有助于调整学习率,例如
  • 我本以为向输入或目标标签添加噪声可能会将 GAN 踢出任何模式崩溃的局部最小值,但我怀疑理论比这更复杂
  • 对优化器参数尝试不同的建议没有帮助,我必须找到自己的最佳调整,并且学习率比其他人使用的要低得多,例如 Adam lr=0.00002 而不是 0.001 这会导致不稳定
  • 更高的训练时期会产生高对比度的图像,这些图像看起来不像具有柔和边缘笔划的 MNIST 数据。我希望更高的训练时期能够提高输出的多样性

我找不到太多指导的一个领域是鉴别器和生成器的实际架构:

  • 它们必须匹配但相反吗?我的不是——鉴别器被证明具有学习能力,仅此而已。之后越浅越好,以确保更容易地反向传播到生成器。
  • 反卷积的使用在生成器中很常见,但是网上一些例子使用了简单的全连接映射。在计算上,反卷积具有较少的学习参数,并且直观地在构建图像时有意义。

我欢迎你的想法和建议。

我认为看起来不错但模式折叠的示例输出: 在此处输入图像描述


2个回答

自从这个问题以来,我取得了重大进展,我将其写为该系列的第 3 部分:

http://makeyourownalgorithmicart.blogspot.com/2019/05/generation-adversarial-networks-part-iii.html

虽然我的 GAN 现在没有模式崩溃 - 原因仍不清楚。这篇文章直观地报告了 SGD->Adam、none->LayerNorm、Sigmoid->ReLU 组合的影响,我希望其他人觉得有用:

GAN优化

免责声明:

我对 GAN 也很陌生,但我一直在广泛地玩各种东西,并尝试各种想法来获得可用的东西(我也在使用 PyTorch)。所以我绝不是专家,但在看到你的问题后,我想我会分享一些我在此过程中学到的东西,希望你会发现它们有用。我还没有彻底查看您的代码,所以我假设您的代码通常是正确的(这意味着网络模型、损失计算等没有不必要的愚蠢错误......)。

另外,请注意我没有使用 MNIST,而且我的架构有循环层。因此,并非所有这些建议都适用于您和 YMMV...


Q 1. 我看到模式崩溃了吗?

经过一些调整和迭代后,我有了一个 GAN,它确实学会了生成看起来可能来自 MNIST 数据集的图像实际上它们还不是数字,但它们是可识别的笔划,当然不是随机噪音。

请记住,当您的网络无法生成足够多样化的输出集(大多数/所有样本看起来相同)时,就会发生模式崩溃。

查看您的示例图像,我认为您甚至还没有处于模式崩溃的地步。您的示例图像看起来不太像“真实”数字。根据我的经验,当模式崩溃发生时,您的生成器会生成有效几乎令人信服的示例,但仅此而已。所有生成的示例看起来都差不多。

话虽如此,我认为您的问题是您的生成器网络尚未学会生成真实(或半真实)外观的样本。我建议训练更长时间,同时确保你的损失计算等都是正确的。当我的网络开始工作时,我可以立即看出它正在产生有效的输出。

几点建议:

  • 确保你的学习率足够小。我的第一个问题是学习率过大(我对 Adam 使用了 0.001,然后意识到我的模型只适用于像 0.0002 这样的小东西)。
  • 确保学习正在发生。随着时间的推移跟踪损失值并确保它们有意义。损失值不应该有任何峰值,否则会出现问题。

我在将特征匹配作为损失指标方面取得了巨大成功。如果没有特征匹配,我的网络就从来没有真正运作良好。


Q 2. 如何逃脱模式-折叠?

那是百万美元的问题!在过去的一个月里,我一直在努力寻找模式崩溃的有效解决方案。

你看,我现在的问题是所有生成的样本看起来都非常有说服力,但它们看起来几乎完全相同

如果上面的答案是训练 GAN 应该输出许多不同但有效的图像,那么我有模式崩溃。

我已经广泛阅读并尝试了许多避免模式崩溃的方法,但都没有奏效:

有/没有批量归一化 有/没有最大池化 有/没有辍学 有/没有标签软化 有/没有噪声添加到输入和目标标签,包括随着训练时间衰减的噪声 生成器的不同宽度和深度,因此判别器增加训练时间(但糟糕的计算能力只允许我在完整数据集上进行大约 6-10 个时期)

很高兴您尝试了所有这些(我也尝试了),但我发现对我来说,没有任何架构更改真的有任何区别。到目前为止,唯一对我有用的是使用 WGAN 损失。让它工作起来非常棘手,但在大约 2000 个 epoch 之后,我看到的样本看起来很真实(有一些明显的缺陷),但看起来却彼此非常不同。

我建议你同时尝试 WGAN 和 WGAN-GP 作为你的损失指标。由于这个讨论之外的原因,我不能轻易使用 WGAN-GP。


最后:

它们必须匹配但相反吗?我的不是——鉴别器被证明具有学习能力,仅此而已。之后越浅越好,以确保更容易地反向传播到生成器。

我的鉴别器和生成器几乎匹配并且彼此相反。我尝试了很多变化,但没有发现任何显着的差异/改进。我只是决定保持简单并暂时匹配它们。一旦我的模式崩溃问题消失,我将重新审视这一点。

希望这可以帮助。我真的希望得到一些专家的反馈。