如何在 Keras 中构建一个循环神经网络,其中每个输入都先经过一层?

数据挖掘 喀拉斯 rnn
2021-10-05 19:50:52

我正在尝试在 Keras 中构建一个如下所示的神经网络:

在此处输入图像描述

在哪里 x1, x2, ... 是经过相同变换的输入向量 f. f本身就是一个必须学习其参数的层。序列长度n 在实例之间是可变的。

我在这里无法理解两件事:

  1. 输入应该是什么样的?
    我正在考虑一个形状为 (number_of_x_inputs, x_dimension) 的二维张量,其中 x_dimension 是单个向量的长度x. 这样的二维张量可以具有可变形状吗?我知道张量可以具有用于批处理的可变形状,但我不知道这是否对我有帮助。

  2. 在将每个输入向量馈送到 RNN 层之前,如何通过相同的转换传递它?
    有没有办法扩展例如 GRU 以便f在通过实际的 GRU 单元之前添加层?

1个回答
  1. 输入应该是什么样的?

您认为 2D 张量是正确的,但通常我们会为批次添加一个维度。您确实可以拥有可变长度的 number_of_x_inputs,但要在批处理期间进行训练,单个批次中的所有输入都需要具有相同的形状。(将批量大小设置为 1 可以解决这个问题。)在推理期间,您可以拥有任何您想要的长度。请参见下面的代码示例。

  1. 在将每个输入向量馈送到 RNN 层之前,如何通过相同的转换传递它?

使用TimeDistributed. 下面的示例传递所有向量xi通过相同的前馈网络(Dense(5, ...)),但您应该能够将其换成f您的想法。

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed

x_dimension = 16
num_classes = 2

model = Sequential()

model.add(TimeDistributed(Dense(5, activation='relu'),
          input_shape=(None, x_dimension)))
model.add(LSTM(32, return_sequences=True))
model.add(LSTM(8))
model.add(Dense(num_classes, activation='softmax'))

print(model.summary(90))

这将打印以下模型:

Layer (type)                            Output Shape                        Param #
==========================================================================================
time_distributed_1 (TimeDistributed)    (None, None, 5)                     85
__________________________________________________________________________________________
lstm_1 (LSTM)                           (None, None, 32)                    4864
__________________________________________________________________________________________
lstm_2 (LSTM)                           (None, 8)                           1312
__________________________________________________________________________________________
dense_2 (Dense)                         (None, 2)                           18
==========================================================================================
Total params: 6,279
Trainable params: 6,279
Non-trainable params: 0
__________________________________________________________________________________________