Transformer 如何预测未来的 n 步?

数据挖掘 张量流 火炬 变压器 序列到序列
2022-03-05 22:37:55

我几乎找不到 Transformer 的实现(既不臃肿也不令人困惑),而我用作参考的是 PyTorch 实现。但是,Pytorch 实现要求您为每一步传递输入 ( src ) 和目标 ( tgt ) 张量,而不是对输入进行一次编码并继续迭代n步以生成完整输出。我在这里错过了什么吗?

我的第一个猜测是,Transformer技术上不是 seq2seq 模型,我不明白我应该如何实现它,或者我在过去几年里一直在错误地实现 seq2seq 模型:)

1个回答

Transformer 是一个 seq2seq 模型。

训练时,您将源令牌和目标令牌都传递给 Transformer 模型,就像您对 LSTM 或 GRU 使用教师强制执行的操作一样,这是训练它们的默认方式。请注意,在 Transformer 解码器中,我们需要应用掩码来避免依赖于当前和未来令牌的预测。

推理时,我们没有目标标记(因为这是我们试图预测的)。在这种情况下,第一步中的解码器输入只是序列 [],我们将预测第一个标记。然后,我们将为下一个时间步准备输入,并将预测附加到前一个时间步输入(即 []),然后我们将获得第二个标记的预测。等等。请注意,在每个时间步,我们都在重复过去位置的计算;在实际实现中,这些状态被缓存而不是在每个时间步重新计算。

关于一些说明 Transformer 工作原理的 Python 代码,我建议使用带注释的 Transformer,这是一个很好的指导,可以真正实现。您可能run_epoch对训练函数greedy_decode和推理函数最感兴趣。

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        out = model.decode(memory, src_mask, 
                           Variable(ys), 
                           Variable(subsequent_mask(ys.size(1))
                                    .type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, 
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

greedy_decode您可以看到当前时间步的预测如何连接到输入以创建以下时间步的输入。