我正在尝试在 TensorFlow 2.0 (beta1) 中实现 RNN。查看层函数(继承自 Keras)我发现:
tf.keras.layers.LSTM
和
tf.keras.layers.LSTMCell
两者有什么不同?如果你看看他们的论点,他们似乎是一样的。
我正在尝试在 TensorFlow 2.0 (beta1) 中实现 RNN。查看层函数(继承自 Keras)我发现:
tf.keras.layers.LSTM
和
tf.keras.layers.LSTMCell
两者有什么不同?如果你看看他们的论点,他们似乎是一样的。
你是对的 - 差异很小。基本的 LSTMCell 类实现了所需的主要功能,例如build
方法,而 LSTM 类只包含一个入口点:call
方法,以及一堆用于检索属性值的getter 。LSTMCell是基类,用作LSTM类内部使用的单元。
所有链接都指向tensorflow.keras
源代码的相关部分。
我的建议是在模型中使用标准 LSTM 类作为普通层。如果您有 GPU 可供使用,您可能希望使用经过 CUDA 优化的层版本在 GPU 上执行。根据文档:
请注意,此单元未针对 GPU 上的性能进行优化。请
tf.keras.layers.CuDNNLSTM
在 GPU 上使用以获得更好的性能。
也有GRU
一层,也有CuDNNGRU
一层。
如果你想调整底层的工作方式,你可以创建一个类并从 LSTMCell 甚至基类继承:
from tensorflow.python.keras.engine.base_layer import Layer
class MyLSTM(Layer):
pass
但是你必须为自己实现很多事情。