有人可以向我解释一下 GAN 中的生成器是如何进行反向支持的吗?

数据挖掘 机器学习 神经网络 反向传播
2022-03-13 22:24:12

我对神经网络不是很熟悉,但是,尽管我将反向传播的概念理解为从输出层中的错误开始。比如说,我们在输出层有 3 个神经元,它们各自的值最终是:

[1 0.5 0.3]

我们希望获得价值

[0 1 0]

所以我们可以计算两者之间的误差向量:

[-1, +0.5, -0.3] (Not necessarily with the - operation, but you get the point)

并从那里反向传播。然而,在 GAN 的生成器中,在我看来,输出层有一堆神经元(代表我们想要生成的实体的大小),但错误仅基于判别器分类的图像的百分比错误地。那么我们究竟如何为生成器做反向支持呢?

我发现的唯一人类可读的 GAN 示例使用像 autoGrad 这样的框架和库,这对我来说会混淆问题:/

有没有一种简单的方法来解释这个反向传播是如何完成的?(例如,在代码中,不是“直观的例子”或“简单的 4 个变量方程,其中 4 个变量中的每一个都包含 50 年的数学知识”)

1个回答

生成对抗网络 (GAN) 中的反向传播算法并没有什么特别之处。它与卷积神经网络 (CNN) 相同,因为 CNN 通常是 GAN 的生成器和判别器的组成部分。我将假设 MNIST 玩具示例进行解释,并提供代码以使 GAN 在下面工作。

GAN

GAN 由鉴别器和生成器组成。在训练另一个时,这些中的每一个都保持不变。因此,我们将交替训练鉴别器和生成器。这是单独进行的。训练鉴别器要容易得多,所以让我们看一下,然后我们将按照您在问题中提出的问题来训练生成器。

在此处输入图像描述

训练判别器

鉴别器有两个输出节点,用于区分真实实例和人工实例。为了训练判别器,我们将生成m使用来自生成器的前向传递的实例,这些是人工实例,它们的标签将是y=0. 为了生成这些,我们只需将噪声向量作为模型的输入传递。我们还将使用m来自真实数据的实例,这些实例将带有标签y=1.

在此处输入图像描述

然后我们可以看到鉴别器的训练方式与具有 2 个输出节点的基本分类 CNN 完全相同。我将在下面描述一个 CNN 的训练过程。

训练生成器

当我们训练生成器时,我们将保持鉴别器固定,这是为了不使我们的参数饱和并使鉴别器过于强大而无法击败是必要的。所以我们本质上拥有的是一个连接到另一个 CNN(鉴别器)的 CNN(生成器)。这两个模型之间的连接节点将被输出,一旦训练,它将生成所需的图像。请注意,生成器希望它将在这种情况下生成的实例被鉴别器分类为来自真实分布,因此它们将具有标签y=1.

在此处输入图像描述

总之,这只是一个 CNN,反向传播将以完全相同的方式计算。首先我们通过一个噪声向量,它通过生成器,在它的输出端生成一些随机图像,然后通过鉴别器并被分类为artificial. 但是,我们预计在这种情况下鉴别器会被愚弄,所以这是一个错误,它应该被标记为real. 然后我们使用反向传播来获得每个模型参数的误差贡献。然后我们将使用梯度下降来更新与生成器相关的所有参数。


反向传播

这是一种用于计算每个参数对误差项的贡献的方法。然后我们使用梯度下降来更新这些参数,以便下一次通过应该导致更低的损失率。选择正确的损失函数对于这个过程至关重要。对于分类任务,就像 GAN 一样,我们通常选择由下式定义的二元交叉熵

L=ylog(y^)(1y)log(1y^)

及以上N与随机梯度下降的典型情况一样,损失函数为

L(w)=1Nn=1N[ynlog(y^n)+(1yn)log(1y^n)]

在哪里y是真正的标签并且y^是预测的标签。

深度神经网络中的反向传播

看看这里的这个答案,它描述了使用反向传播和梯度下降来训练单个神经元感知器,然后是多层网络的过程。

唯一的区别是我们在这里使用二元熵损失函数,它具有不同的导数y^. 这变成

Ly^=1Nn=1N[yy^1y1y^]

然后,您将使用链式规则和您在每一层选择的激活函数,通过网络反向传播此损失。

CNN中的反向传播

有关详细信息和 CNN 反向传播的推导,请参阅此答案。该过程与深度神经网络的过程非常相似。但是CNN在每一层都使用了互相关函数,所以需要通过这个函数的导数来反向传播损失函数。这个问题询问了具有两个输出的 CNN 的内部工作,非常类似于我们的鉴别器。

最后,为了训练生成器,想象一下前面有更多卷积层的相同过程。当我们有相关的梯度L(w)对于每个参数,仅将梯度下降算法应用于作为生成器模型一部分的参数。我相信通过这些推导你会看到层之间的中间节点的数量不会影响反向传播算法。它保持相同的过程。您还应该确信,对于非人类来说,作为 GAN 生成图像的中间节点与模型中的任何中间节点没有什么不同。我们简单地训练这些中间节点,使我们可以感知一些含义,例如生成 MNIST 数据集的实例。