Transformer 中解码器掩码(三角形掩码)的用途是什么?

人工智能 自然语言处理 变压器 注意力
2021-11-06 04:58:35

我正在尝试使用本教程实现变压器模型。在 Transformer 模型的解码器块中,将掩码传递给“在解码器接收的输入中填充和掩码未来的令牌”。这个掩码被添加到注意力权重中。

import tensorflow as tf

def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask

现在我的问题是,这一步(在注意力权重中添加掩码)如何等同于逐个显示要建模的单词?我根本无法直观地理解它的作用。大多数教程甚至都不会提到这一步,因为它非常明显。请帮我理解。谢谢。

3个回答

本教程中介绍的 Transformer 模型是一个自回归 Transformer。这意味着对下一个令牌的预测仅取决于它之前的令牌。

因此,为了预测下一个令牌,您必须确保只参加前一个令牌。(如果不是,这将是作弊,因为模型已经知道接下来会发生什么)。

所以注意掩码会像这样
[0, 1, 1, 1, 1]
[0, 0, 1, 1, 1]
[0, 0, 0, 1, 1]
[0, 0, 0, 0, 1 ]
[0, 0, 0, 0, 0]

例如:如果您正在将英语翻译成西班牙语
输入:您好吗?
目标:<开始> Como estas ? <end>
然后解码器会预测类似这样的
<start> (它将作为初始令牌提供给解码器)
<start> Como
<start> Como estas
<start> Como estas ?
<开始> Como estas ? <结束>

现在将此逐步预测序列与上面给出的注意力掩码进行比较,现在对您来说很有意义

我们在训练模型时将目标输入提供给变压器解码器。因此,模型很容易“窥视”并了解下一个单词是什么。为了确保不会发生这种情况,我们在 Query 和 Key 之间的点积之后应用了一个附加掩码。在原始论文“Attention is all you need”中,三角形矩阵在下三角形中有 0,在上三角形中有 -10e9(您可以看到最近的示例中使用的负无穷大)。因此,当将掩码添加到注意力分数时,上三角形的注意力分数会非常低。当这个矩阵通过 softmax 函数时,这些非常低的值变得接近于 0,这本质上意味着不关注时间步 t 之后的单词。放入矩阵格式,

[8.1, 0.04, 5.2, 4.2]
[0.5, 9.2, 2.33, 0.7]
[0.2, 0.4, 6.11, 1.0]
[3.1, 2.1. 2.19, 8.1]

让上面的矩阵A是查询和键之间的点积的结果。A[0][0]包含查询的第一个词对键的第一个词的注意力分数,包含查询的第一个词对键的第二A[0][1]词的注意力分数,依此类推。如您所见,在添加掩码并在 上执行 softmax 之后A,结果将是,

[8.1, 0.0, 0.0, 0.0]
[0.5, 9.2, 0.0, 0.0]
[0.2, 0.4, 6.11, 0.0]
[3.1, 2.1. 2.19, 8.1]

这迫使变形金刚只关注它之前的单词。您可以查看 CS224n 中提供的 Transformer 讲座以获取完整的详细信息。

在使用其注意力机制时,需要使用掩码来防止解码器在训​​练期间“窥视”地面实况。

编码器:

  • 运行时或训练:

    编码器总是在一次迭代中发生,因为它将单独处理所有嵌入,但并行处理。这有助于我们节省时间。


解码器:

  • 运行:

    在这里,解码器将在几次非并行迭代中运行,在每次迭代中生成一个“输出”嵌入。然后可以将其输出用作下一次迭代的输入。

  • 训练:

    在这里,解码器可以在一次迭代中完成所有这些,因为它只是从我们那里接收“基本事实”。因为我们事先知道这些“真相”嵌入,所以可以将它们作为行存储到矩阵中,然后可以将它们提交给解码器进行单独处理,但可以并行处理。

    正如您在训练期间所看到的,解码器的实际预测不用于构建目标序列(就像 LSTM 那样)。取而代之的是,这里使用的是一种称为“教师强迫”的标准程序。

    正如其他人所说,在使用其注意力机制时,需要使用掩码来防止解码器在训​​练期间“窥视”地面实况。

提醒一下,在转换器中,嵌入在输入期间永远不会连接。相反,每个单词分别但同时流经编码器和解码器。

另外,请注意掩码包含负无穷大而不是零这是由于 Softmax 在 Attention 中的工作方式。

我们总是首先运行编码器,它总是需要 1 次迭代。然后编码器耐心地坐在一边,因为解码器根据需要使用它的值。