为什么我们需要 Gumbel-Softmax 技巧中的温度?

机器算法验证 自动编码器 口香糖分布
2022-04-05 11:17:52

假设离散变量具有非归一化概率,一种采样方法是应用 argmax(softmax( )),另一种是使用 Gumbel 技巧 argmax( ),其中是 gumbel 产生的噪声。如果我们想做变分自动编码(即将输入编码为潜在离散变量),则第二种方法很有用。然后,如果目标是对的可能结果进行完整分布,我们可以在 Gumbel 噪声扰动之上使用 softmax 变换: zjαjαjlogαj+gjgjxjzjzj

πj=elogαj+gjk=1k=Kelogαk+gk   where  gk=log(log(ϵU(0,1))).
为什么这还不够?为什么我们需要在其中包含温度项?并重写, 我知道温度使矢量更平滑或更粗糙(即高温只会使所有相同,并生成更平坦的分布,并且τ
πj=elogαj+gjτk=1k=Kelogαk+gkτ   where  gk=log(log(ϵU(0,1)))
π=[π1,...,πk]πiτ=1只是使两个方程相同)但为什么我们在实践中需要它?我们想要的(即,在 VAE 中)是解耦采样的随机方面(即,将其随机部分移动到输入),这是通过 Gumbel 技巧实现的,然后以某种方式将 one-hot vector draw 替换为一个连续向量,我们通过使用第一个方程得到的 softmax(我确定我错过了一些基本的东西,但看不到它是什么......logαj+gj

1个回答

一种采样方法是应用 argmax(softmax( ))αj

这几乎不是“抽样”,因为您每次都确定性地选择最大的(另外,你说是未归一化的概率,但是当对数概率进入 softmax 时,这是没有意义的)。正确的采样方法是 sample(softmax( )),其中是 logits。实际上,gumbel-softmax 的目标不是替换您编写的 softmax 操作,而是替换采样操作:αjαxx

我们可以 ),其中是概率向量,其中是 gumbel 噪声。当然,这等价于 argmax( ),其中又是 logits。总而言之,sample(softmax( )) 和 argmax(是等价的过程。pplogp+ggx+gxxx+g)

可能结果的完整分布,我们可以在 Gumbel 噪声扰动之上使用 softmax 变换。zj

事实上,您已经对所有可能的结果进行了分布。

然而,argmax(是不可微的,因此为了反向传播,我们将其梯度替换为 softmax( ) 的梯度。时,表达式接近 argmax。x+gx(x+g)τ1τ0

选择一个合理的、小的值将确保对梯度的良好估计,同时确保梯度在数值上表现良好。τ

只是使两个方程相同τ=1

实际上,并没有什么特别的意义。相反,使梯度估计无偏但方差高,其中的较大值会为梯度估计增加更多偏差但降低方差。τ=1τ0τ