TensorFlow / Keras:LSTM 层中的 stateful = True 是什么?

数据挖掘 深度学习 张量流 lstm rnn 格鲁
2021-09-26 19:45:58

你能详细说明这个论点吗?我发现文档中的简短解释令人不满意:

有状态的:布尔值(默认为 False)。如果为 True,则批次中索引 i 处每个样本的最后状态将用作下一批中索引 i 的样本的初始状态。

还有,什么时候stateful = True选?有哪些实际使用案例?

1个回答

该标志用于截断随时间的反向传播:梯度通过 LSTM 的隐藏状态在批次中的时间维度上传播,然后在下一批中,最后的隐藏状态用作输入状态长短期记忆法。

这允许 LSTM 在训练时使用更长的上下文,同时限制梯度计算的回退步数。

我知道这很常见的两种情况:

  • 语言建模(LM)。
  • 时间序列建模。

训练集是一个序列列表,可能来自一些文档 (LM) 或完整的时间序列。在数据准备期间,将创建批次,以便批次中的每个序列都是前一批中相同位置的序列的延续。这允许在计算预测时具有文档级/长时间序列上下文。

在这些情况下,您的数据长于批处理中的序列长度维度。这可能是由于可用 GPU 内存的限制(因此限制了最大批量大小)或由于任何其他原因而设计的。

更新:请注意,该stateful标志会影响训练和推理时间。如果禁用它,则必须确保在推理时每个预测都获得先前的隐藏状态。为此,您可以使用创建一个新模型stateful=True并从训练模型中复制参数,或者手动model.set_weights()传递它由于这种不便,有些人简单地设置always 并通过调用来强制模型在训练期间不使用存储的隐藏状态stateful = Truemodel.reset_states()