Pytorch Luong 全球关注:对齐向量的形状应该是什么?

数据挖掘 lstm 火炬 注意机制
2022-02-18 04:54:10

我正在看关于注意力模型和全球注意力的 Luong 论文。我了解如何根据编码器隐藏状态和解码器隐藏状态的点积计算对齐向量。所以这一切都是有道理的。

我的问题是,对齐分数张量的维度应该是多少?如果我有批量数据,我基本上会为隐藏状态下的每个时间步计算一个分数,对吧。那么对齐向量的维度应该是 [sequence length, 1] 还是类似的?然后我会对这个对齐向量进行softmax,并将它乘以批处理中的每个元素来计算上下文向量,对吧。

同样,我的关键问题是对齐分数向量或张量的维度应该是多少。谢谢。

1个回答

[sequence length, 1]假设您使用一个句子,您自己的回答是正确的。(或者实际上,一维取决于实现。)

在实践中,数据通常是批处理的,因此它将是[batch, sequence length 1]. 这可以逐元素乘以维度的编码器状态[batch, sequence length, hidden size]并在中间维度求和以获得维度的上下文向量[batch, hidden_size]

还要注意,如果句子的长度不同,则需要处理填充,即将填充的位置设置-inf在softmax之前。