分裂成多头——多头自注意力

数据挖掘 张量流 伯特 变压器 注意机制
2022-02-20 08:03:20

所以,我对Attention 是你所需要的一切有疑问:

tensorflow官方文档中transformers的实现

每个多头注意力块获得三个输入;Q(查询),K(键),V(值)。这些通过线性(密集)层并分成多个头。

但是,该论文提到:

与使用 dmodel 维度的键、值和查询执行单个注意函数不同,我们发现将查询、键和值分别线性投影到 dk、dk 和 dv 维度上的不同学习线性投影是有益的。然后,在每个查询、键和值的投影版本上,我们并行执行注意功能,产生 dv 维输出值。

没有提到拆分Q、K 和 V 以获得正面。相反,论文说它们通过“h”个不同的密集层,以将 d-model 维向量分别转换为“h”个不同的 dk、dk 和 dv 维向量。所以基本上,据我所知,伪代码应该是这样的:

Q,K & V are d-model dimensional vectors.

for i in range(h):
   Qi = Dense(dk)(Q)
   Ki = Dense(dk)(K)
   Vi = Dense(dv)(V)
   Ai = Attention(Qi, Ki, Vi)

A0, A1, A2, ..., Ah are then concatenated.

这是正确的吗?还是我在这里遗漏了什么?

2个回答

原则上,伪代码是正确的,但不是它是如何实现的。投影和点积注意力可以有效地使用矩阵乘法同时对所有头部进行。

您可以只使用一层密集层进行查询,一层用于键,一层用于值,而不是在头上循环。例如,对于键,你会做一个密集的维度层hdk然后重塑它。

假设您有批量大小b, 序列长度l和模型尺寸dm

  • 密集层的输入是形状b×l×dm.

  • 密集层的输出是有形状的b×l×hdk(或者dv分别)。

  • 然后你可以重塑查询和键的形状b×l×h×dk.

现在,如果您排列维度,使查询具有形状b×h×l×dk钥匙有形状b×h×dk×l,您可以在最后两个维度进行批量矩阵乘法,并以形状的注意力能量结束b×h×l×l. 然后,如果你在最后一个维度做 softmax,你会得到每个 head 和每个查询的注意力分布。

通过与现在具有转置维度的投影值进行批量矩阵乘法b×h×l×dv,你得到加权平均值。所以最后,所有头的“连接”实际上只是另一个张量重塑。

@Jindřich 是完全正确的。这本身不是一个答案,而是一些指向他提到的每个项目的注释转换器中的实现的指针:

密集层的输入是形状b×l×dm

  • torch.from_numpy(np.random.randint(1, V, size=(batch, l)))b×l
  • self.lut = nn.Embedding(vocab, d_model); self.lut(x)b×l×dm

密集层的输出具有相同的形状b×l×dm=b×l×hdk

  • self.linears = clones(nn.Linear(d_model, d_model), 4): 这些都是WQ,WK,WV,WO分别,它们的输出是b×l×dm
  • l(x).view(nbatches, -1, self.h, self.d_k): 将输出转换为b×l×h×dk

然后你可以重塑查询和键的形状b×h×l×dk

  • l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2): 将输出转换为b×h×l×dk,为 K、Q 和 V 完成。

现在,如果你改变尺寸......

  • scores = torch.matmul(query, key.transpose(-2, -1))[b×h×l×dk]×[b×h×dk×l]=[b×h×l×l]
  • scores = scores.masked_fill(mask == 0, -1e9):面具是b×1×1×l对于编码器层,以及b×1×l×l对于解码器层(实际上l1但让我们忽略这一点)。两者都可以广播到分数张量,编码器掩码对于所有头部和所有位置都是相同的,而解码器掩码对于每个位置都是不同的,因为它隐藏了后续位置。

然后,如果你在最后一个维度做 softmax ......

  • p_attn = F.softmax(scores, dim = -1)b×h×l×l,但现在每个向量总和为 1

通过进行批量矩阵乘法......你得到加权平均值

  • torch.matmul(p_attn, value)[b×h×l×l]×[b×h×l×dk]=[b×h×l×dk],这是我们想要的加权平均值

所以最后,所有头的“连接”实际上只是另一个张量重塑

  • x.transpose(1, 2).contiguous()[b×l×h×dk]
  • x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)[b×l×dm]