双向RNN可以使用可变序列长度吗?

数据挖掘 神经网络 rnn 顺序
2021-09-28 16:03:01

双向 RNN 由两个 RNN 组成,一个用于前向序列方向,另一个用于后向序列方向,其结果在每个时间步连接。这种配置会限制模型始终使用固定的序列长度吗?还是它仍然可以作为单向 RNN 工作,可以应用于任何序列长度?

之所以提出这个问题,是因为双向架构在每个时间步合并了前向和后向 RNN 的输出。因此,如果序列长度为 4,前向和后向 RNN 的输出将以这种方式合并:第 1 个前向和第 4 个后向,第 2 个前向和第 3 个后向,……第 4 个前向和第一个后向。但是,如果使用不同的序列长度,则将修改此合并顺序:

假设网络使用序列长度 4 进行训练,但在测试时使用的序列长度为 5。合并将是:第 1 前进与第 5 后退,第 2 前进与第 4 后退……第 5 前进与第 1 后退。这种合并顺序的变化会对双向 RNN 性能产生负面影响吗?

2个回答

简短的回答是否定的,双向架构仍将采用可变序列长度。要了解原因,您应该了解填充的工作原理。

例如,假设您正在 tensorflow 中针对多个主题的可变长度时间序列数据实现双向 LSTM-RNN。输入是一个具有形状的 3D 数组:[n_subjects, [n_features, [n_timesteps...] ...] ...]因此,为了确保数组具有一致的尺寸,您可以将其他对象的特征填充到对象的长度,并使用最长时间测量的特征。

假设主题 1 具有一个values = [22,20,19,21,33,22,44,21,19,26,27]测量值为 的特征times = [0,1,2,3,4,5,6,7,8,9,10]主题 2 有一个特征,values = [21,12,22,30,13,42,20]测量值为times = [0,1,2,3,4,5,6]您可以通过扩展数组来填充主题 2 的功能,以便padded_values = [21,12,22,30,13,42,20,0,0,0,0]at times = [0,1,2,3,4,5,6,7,8,9,10],然后对每个后续主题执行相同的操作。

这意味着每个主题的时间步数可以是可变的,并且您引用的合并与该特定主题的维度一起发生。

下面是一个模型的双向 LSTM-RNN 架构示例,该模型使用在可变时间长度内测量的生物特征来预测不同受试者的睡眠阶段。

在此处输入图像描述

是的,TF 2.0 有一种方法,即使用Ragged Tensors,如下所示:

# Task: predict whether each sentence is a question or not.
sentences = tf.constant(
    ['What makes you think she is a witch?',
     'She turned me into a newt.',
     'A newt?',
     'Well, I got better.'])
is_question = tf.constant([True, False, True, False])

# Preprocess the input strings.
hash_buckets = 1000
words = tf.strings.split(sentences, ' ')
hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)

# Build the Keras model.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),
    tf.keras.layers.Embedding(hash_buckets, 16),
    tf.keras.layers.LSTM(32, use_bias=False),
    tf.keras.layers.Dense(32),
    tf.keras.layers.Activation(tf.nn.relu),
    tf.keras.layers.Dense(1)
])

keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
keras_model.fit(hashed_words, is_question, epochs=5)
print(keras_model.predict(hashed_words))

来源:https ://www.tensorflow.org/guide/ragged_tensor