tf.nn.dynamic_rnn() 的输出是什么?

机器算法验证 深度学习 lstm 张量流 循环神经网络 格鲁
2022-03-08 19:20:21

我不确定我从官方文档中了解的内容,其中说:

返回:一对(输出,状态),其中:

outputs:RNN 输出张量。

如果time_major == False(默认),这将是一个张量形状: [batch_size, max_time, cell.output_size]

如果time_major == True,这将是一个张量形状:[max_time, batch_size, cell.output_size]

请注意,如果cell.output_size是整数或 TensorShape 对象的(可能嵌套的)元组,则输出将是与 cell.output_size 具有相同结构的元组,其中包含具有与 中的形状数据对应的形状的张量cell.output_size

state: 最终状态。如果 cell.state_size 是一个 int,这将是 shape [batch_size, cell.state_size]如果它是一个 TensorShape,这将是 shape [batch_size] + cell.state_size如果它是整数或 TensorShape 的(可能是嵌套的)元组,这将是一个具有相应形状的元组。如果单元格是 LSTMCells,则状态将是一个元组,其中包含每个单元格的 LSTMStateTuple。

]是否output[-1总是(在所有三种单元类型中,即 RNN、GRU、LSTM)等于状态(返回元组的第二个元素)?我想各地的文献在使用“隐藏状态”一词时都过于自由了。三个单元格中的隐藏状态是否得分出来(为什么它被称为隐藏超出了我,它会出现LSTM中的单元格状态应该称为隐藏状态,因为它没有暴露)?

2个回答

是的,单元输出等于隐藏状态。在 LSTM 的情况下,它是元组的短期部分( 的第二个元素LSTMStateTuple),如下图所示:

长短期记忆体

但是对于tf.nn.dynamic_rnn,当序列较短(参数)时,返回的状态可能不同。sequence_length看看这个例子:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

这里输入批次包含 4 个序列,其中一个很短并用零填充。运行后你应该是这样的:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

...这确实表明state == output[1]对于完整序列和state == output[0]短序列。也是output[1]这个序列的零向量。LSTM 和 GRU 单元也是如此。

所以state是一个方便的张量,它保存最后一个实际的RNN 状态,忽略零。output张量包含所有单元格的输出,因此它不会忽略零。这就是他们两个都退货的原因。

https://stackoverflow.com/questions/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930的可能副本

无论如何,让我们继续回答。

此代码片段可能有助于了解该dynamic_rnn层真正返回的内容

=> (outputs, final_output_state)的元组。

因此,对于最大序列长度 为 T 个时间步长的输入,输出的形状为[Batch_size, T, num_inputs](给定time_major=False;默认值),并且它包含每个时间步的输出状态h1, h2.....hT

final_output_state具有形状[Batch_size,num_inputs],具有每个批次序列的最终单元状态cT和输出状态hT

但是由于dynamic_rnn正在使用,我的猜测是您的序列长度因每批而异。

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

最终断言将失败,因为第二个序列的最终状态是在第 6 个时间步,即。索引 5 和 [6:9] 的其余输出在第二个时间步中都是 0