我也一直在寻找这个问题的答案,我给出了我对 Gumbel softmax 的不同看法,只是因为我认为这是一个很好的问题。
从一般的角度来看:我们正常使用softmax是因为我们需要一个所谓的score,或者distribution π1..πn用于表示大小为 n 的分类变量的 n 个概率;我们使用 Gumbel-softmax 从这个分布中采样一个热样本[0..1..0]。
在更具体的示例中:通常在 NLP 网络中(将输出分类为不同的单词标记),softmax 用于计算当前文本位置的不同的分布,例如 5000 个单词选择。交叉熵损失,给出了关于softmax预测分布和真实词分布之间差异的度量;对于 Gumbel-softmax,它通常用于生成样本 one-hot 向量以构建以下网络,就像在一些基于 VAE 的模型中一样。这就是为什么温度因素τ 对于 Gumbel-softmax 来说是必须的,在大多数情况下在 softmax 中不需要。
至此,我们知道了它们用例的区别,现在是令人困惑的部分:两个公式如此相似,为什么它们做不同的事情?
第一个关键因素:不同之处在于giGumbel-softmax 公式中的术语。它表示从分布中采样的一个点Gumbel(0,1). 添加项log(πi) 和缩放项 τ 只是用来重新参数化它 Gumbel(log(πi),τ)(重新参数化技巧是为了使其可微)。这个 Gumbel 分布是在Gumbel-max方法中对分类变量进行采样的关键分布(因为存在 argmax,所以很难且不可微),这也是 Gumbel-softmax 中 Gumbel 名称的来源。每次我们使用 Gumbel-softmax 时,我们都需要从Gumbel(0,1) 并进行重新参数化技巧,这是与 softmax 最不同的部分。
我宁愿将它命名为soft-Gumbel-max,以表明它有动机制作 Gumbel-max 的软版本,而不是仅仅打算在 softmax 中添加 Gumbel 术语。
第二个最显着的区别:是使用τ. 在大多数神经网络中,softmax 不与这个术语耦合。因为我们通常需要一个分布,而不是一个近乎单一的向量。更重要的是,在某些情况下,比如束搜索,我们需要获得第二或第三个最可能的选择来探索全局最优搜索。在softmax中,τ通常会添加一些特定领域的知识以使分布更陡峭。更小τ并不总是意味着更好,我们需要调整它以最适合模型。然而对于 Gumbel-softmax,在模型训练的后期,我们需要 Gumbel-softmax 尽可能接近 one-hot 向量。这就是为什么我们需要在训练期间将其退火得越来越小。(τ在训练开始时不会太小,因为这会使训练更稳定)。在某些实现中,如torch.nn.functional.gumbel_softmax
,它使用直通技巧hard - (detached soft) + soft
将输出值保持为硬 Gumbel-max 中的单热向量,但具有明确定义的梯度,如软可微 Gumbel-softmax 中。
我不同意前面的答案的大部分是,Gumbel-softmax不会给你确切的 one-hot 向量。它仍然是一个热向量的估计。即使是直通,前向传播也是 Gumbel-max,只有后向传播是 Gumbel-softmax。
这几乎就是我对 Gumbel-softmax、Gumbel-max 和 softmax 的理解。如果有任何不清楚或不正确的地方,请发表评论。
MORE 如果你仍然困惑
写完这个答案后,我发现,虽然softmax和Gumbel-softmax的用例不同,但是我们仍然可以在应用Gumbel-softmax的地方强制应用softmax,而不会遇到任何算术问题,反之亦然。因为它们都是软的,不精确的,而且它们都是可微的。为了更清楚为什么这是一个问题。让我们举两个非常严重的滥用例子。
假设有一个网络需要一个样本而不是一个分类变量的分布,例如,一个网络需要一个分类变量来表示我想生成谁的演讲:奥巴马或特朗普。假设训练数据有 70% 的机会来自特朗普,30% 的时间来自奥巴马,那么 70% 的时间变量应该是 [0 1] 而 30% 的时间应该是 [1 0]。如果我们在这里改用softmax,变量将始终约为[0.3 0.7],因此它变成一个常数,因此可以被网络忽略。为了减轻训练损失损失,生成的网络很可能会产生混合声音,有点像奥巴马,有点像特朗普。更何况训练时的[0.7 0.3]在预测时从来不用,这也是域外预测的问题。
另一个例子是预测句子中“you”之后的下一个单词。假设只有两个选项:“are”和“have”。如果我们用 Gumbel-softmax 代替 softmax 来表示两个词的概率。只有在训练过程进行正确采样时,交叉熵损失才不会接近无穷大(-log0)。在这种情况下,训练过程很容易出现梯度爆炸。