用于顺序输入的 tf.nn.dynamic_rnn 及时反向传播(来自批处理)

数据挖掘 张量流 rnn lstm 反向传播
2022-02-13 06:41:34

我有这样的代码:

lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple = True)
c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c], "c_in")
h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h], "h_in")
rnn_state_in = (c_in, h_in)
rnn_in = tf.expand_dims(previous_layer, [0])
sequence_length = #size of my batch
rnn_state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(lstm_cell,
                                                    rnn_in,
                                                    initial_state = rnn_state_in,
                                                    sequence_length = sequence_length,
                                                    time_major = False)
lstm_c, lstm_h = lstm_state
rnn_out = tf.reshape(lstm_outputs, [-1, 256])

在这里,我使用 dynamic_rnn 来模拟批处理的时间步长。每次向前传球时,我都可以得到lstm_c, lstm_h可以存储在外面任何地方的东西。

因此,假设我已经对网络中的序列中的 N 个项目进行了前向传递,并且具有从 dynamic_rnn 提供的最终单元状态和隐藏状态。现在,要执行反向传播,我对 LSTM 的输入应该是什么?

默认情况下,反向传播是否在 dynamic_rnn 中跨时间步发生?

(比如说,时间步数 = batch_size=N)

因此,我提供以下输入是否足够:

sess.run(_train_op, feed_dict = {_state: np.vstack(batch_states),
                                            ...
                                            c_in: prev_rnn_state[0],
                                            h_in: prev_rnn_state[1]
})

(哪里prev_rnn_state是 的元组cell state, hidden state,我从上一批前向传播的 dynamic_rnn 中得到。)

或者我是否已经显式地跨时间序列展开 LSTM 层并通过提供单元状态向量和隐藏在先前时间序列中收集的向量来训练它?

0个回答
没有发现任何回复~