在 TensorFlow 2.0 中,LSTM 和 LSTMCell 对象有什么区别?

数据挖掘 Python 张量流 lstm rnn
2021-10-06 03:23:35

我正在尝试在 TensorFlow 2.0 (beta1) 中实现 RNN。查看层函数(继承自 Keras)我发现:

tf.keras.layers.LSTM

tf.keras.layers.LSTMCell

两者有什么不同?如果你看看他们的论点,他们似乎是一样的。

1个回答

你是对的 - 差异很小。基本的 LSTMCell 类实现了所需的主要功能,例如build方法,而 LSTM 类只包含一个入口点:call方法,以及一堆用于检索属性值的getter 。LSTMCell是基类,用作LSTM内部使用的单元。

所有链接都指向tensorflow.keras源代码的相关部分。

我的建议是在模型中使用标准 LSTM 类作为普通层。如果您有 GPU 可供使用,您可能希望使用经过 CUDA 优化的层版本在 GPU 上执行。根据文档:

请注意,此单元未针对 GPU 上的性能进行优化。tf.keras.layers.CuDNNLSTM在 GPU 上使用以获得更好的性能。

也有GRU一层,也有CuDNNGRU一层。


如果你想调整底层的工作方式,你可以创建一个类并从 LSTMCell 甚至基类继承:

from tensorflow.python.keras.engine.base_layer import Layer

class MyLSTM(Layer):
    pass

但是你必须为自己实现很多事情。