注意力机制究竟是什么?

机器算法验证 时间序列 深度学习 lstm 循环神经网络 注意力
2022-02-09 02:29:47

在过去的几年中,注意力机制已被用于各种深度学习论文中。Open AI 研究负责人 Ilya Sutskever 热情地称赞了他们: https ://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Purdue 大学的 Eugenio Culurciello 声称应该放弃 RNN 和 LSTM,取而代之的是纯粹的基于注意力的神经网络:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

这似乎有些夸张,但不可否认的是,纯粹基于注意力的模型在序列建模任务中做得相当好:我们都知道谷歌的那篇名副其实的论文,Attention is all you need

然而,究竟什么基于注意力的模型?我还没有找到对这些模型的明确解释。假设我想在给定历史值的情况下预测多元时间序列的新值。很清楚如何使用具有 LSTM 单元的 RNN 来做到这一点。我如何对基于注意力的模型做同样的事情?

1个回答

注意力是一种将一组向量聚合成一个向量的方法,通常通过查找向量通常,要么是模型的输入,要么是先前时间步长的隐藏状态,或者是向下一层的隐藏状态(在堆叠 LSTM 的情况下)。viuvi

结果通常称为上下文向量,因为它包含与当前时间步长相关的上下文。c

然后这个额外的上下文向量也被输入到 RNN/LSTM 中(它可以简单地与原始输入连接)。因此,上下文可用于帮助预测。c

最简单的方法是计算概率向量其中是所有先前的串联。一个常见的查找向量是当前隐藏状态p=softmax(VTu)c=ipiviVviuht

这有很多变化,您可以根据需要使事情变得复杂。例如,代替使用作为 logits,可以选择,其中是任意神经网络。viTuf(vi,u)f

序列到序列模型的常见注意机制使用,其中是编码器的隐藏状态,是当前隐藏解码器的状态。和两个都是参数。p=softmax(qTtanh(W1vi+W2ht))vhtqW

一些论文展示了注意力概念的不同变化:

指针网络使用对参考输入的注意力来解决组合优化问题。

循环实体网络在阅读文本时为不同的实体(人/对象)维护单独的记忆状态,并使用注意力更新正确的记忆状态。

变压器模型也广泛使用注意力。他们对注意力的表述稍微更一般,还涉及到关键向量:注意力权重实际上是在关键和查找之间计算,然后用构建上下文。kipvi


这是一种注意力形式的快速实现,尽管除了它通过了一些简单的测试之外,我不能保证正确性。

基本 RNN:

def rnn(inputs_split):
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
        hidden_states.append(hidden)
    return hidden_states[-1]

注意,我们在计算新的隐藏状态之前只添加了几行代码:

        if len(hidden_states) > 1:
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            context = tf.zeros_like(last_state)

        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

完整代码