为什么令牌的嵌入乘以D--√D(注意不除以 D 的平方根)在变压器中?

机器算法验证 机器学习 神经网络 自然语言 变压器
2022-03-13 17:03:52

为什么 PyTorch 中的转换器教程有乘以 sqrt 的输入数?我知道在多头自注意力中有一个除以 sqrt(D),但是为什么与编码器的输出有类似的东西呢?特别是因为原始论文似乎没有提到它。

特别是(https://pytorch.org/tutorials/beginner/translation_transformer.html):

src = self.encoder(src) * math.sqrt(self.ninp)

或者这个(https://pytorch.org/tutorials/beginner/transformer_tutorial.html):

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

请注意,我知道注意力层有这个等式:

α=Attention(Q,K,V)=SoftMax(QKD)V

他们在论文的一个空白处争论为什么会这样(关于方差总和为 1)。

这与该评论有关吗?它是如何相关的?这在原始论文中提到了吗?

交叉贴:

2个回答

我们相乘是因为我们使用学习嵌入将输入标记和输出标记转换为d_model嵌入层中的维度向量。参考Section 3.4论文Embeddings and Softmax您还会看到用于源和目标令牌嵌入对象。TokenEmbeddingSeq2SeqTransformer

请注意,词嵌入中的权重被初始化为零均值和单位方差。此外,嵌入在添加到位置编码sqrt(d) 之前会被乘以。位置编码也与嵌入在同一尺度上。

我的假设是,作者尝试用不同的数字重新缩放嵌入(正如他们当然用注意力所做的那样),并且这种特殊的重新缩放恰好起作用,因为它使嵌入比位置编码大得多(最初)。位置编码是必要的,但它们可能不应该像单词本身那样“响亮”。