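I am implementing a sequence-to-sequence model with an RNN-VAE architecture, and I use an attention mechanism. I have a problem in the decoder part. I am stuck on this error: IndexError: list index out of range, which is raised when I run this code: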
decoder_inputs = Input(shape=(len_target,))
decoder_emb = Embedding(input_dim=vocab_out_size, output_dim=embedding_dim)
decoder_lstm = LSTM(units=units, return_sequences=True, return_state=True)
decoder_lstm_out, _, _ = decoder_lstm(decoder_emb(decoder_inputs),
                                      initial_state=encoder_states)
print("enc_outputs", encoder_outputs.shape) # ==> (?,256)
print("decoder_lstm_out", decoder_lstm_out.shape)# ==> (?,12,256)
print("zzzzzz", z.shape) # ==> (?,256)
attn_layer = AttentionLayer(name='attention_layer')
attn_out, attn_states = attn_layer([z,z], decoder_lstm_out)
The error is raised on the last line, with the following traceback:
Traceback (most recent call last):
  File "malek_tuto.py", line 197, in <module>
    attn_out, attn_states = attn_layer([z,z], decoder_lstm_out)
  File "C:\Users\lightland\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 728, in __call__
    self.build(input_shapes)
  File "D:\PFE\Contribution\modele\layers\attention.py", line 24, in build
    shape=tf.TensorShape((input_shape[0][3], input_shape[0][3])),
  File "C:\Users\lightland\Anaconda3\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 615, in __getitem__
    return self._dims[key]
IndexError: list index out of range
In the AttentionLayer class, the build function is defined as:
def build(self, input_shape):
    assert isinstance(input_shape, list)
    print("hhhhhhhhhh",input_shape)
    print("jjknkjnjk")
    # Create a trainable weight variable for this layer.
    self.W_a = self.add_weight(name='W_a',
                               shape=tf.TensorShape((input_shape[0][2],
                                                     input_shape[0][2])),
                               initializer='uniform',
                               trainable=True)
    self.U_a = self.add_weight(name='U_a',
                               shape=tf.TensorShape((input_shape[1][2],
                                                     input_shape[0][2])),
                               initializer='uniform',
                               trainable=True)
    self.V_a = self.add_weight(name='V_a',
                               shape=tf.TensorShape((input_shape[0][2], 1)),
                               initializer='uniform',
                               trainable=True)
    super(AttentionLayer, self).build(input_shape)
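For reference, here is a minimal sketch of how the index lookups in build() behave for the shapes printed above. It uses plain tuples as stand-ins for the TensorShape objects, and it assumes that Keras derives input_shape only from the first argument of the call (the list [z, z]), so that build() sees two shapes of rank 2. The helper names feature_dim and try_build are hypothetical and are not part of the layer.

```python
# Plain tuples stand in for the TensorShape objects passed to build().
# None plays the role of the unknown batch dimension printed as "?".
z_shape = (None, 256)                # rank 2, like z and encoder_outputs
decoder_seq_shape = (None, 12, 256)  # rank 3, like decoder_lstm_out


def feature_dim(shape):
    """Return shape[2], the index that build() reads for its weight shapes."""
    return shape[2]


def try_build(input_shape):
    """Mimic the index lookups in build() and report whether they succeed."""
    try:
        return (feature_dim(input_shape[0]), feature_dim(input_shape[1]))
    except IndexError:
        return "IndexError"


# Assumed view of the failing call: build() receives the shapes of [z, z].
result_rank2 = try_build([z_shape, z_shape])

# Hypothetical comparison: two rank-3 (batch, time, features) shapes.
result_rank3 = try_build([decoder_seq_shape, decoder_seq_shape])

print(result_rank2)
print(result_rank3)
```

In this sketch the rank-2 shapes fail on index 2 while the rank-3 shapes do not, which suggests the layer may expect sequence tensors of shape (batch, time, features) for both list entries. That is only my reading of the indexing, not something I have confirmed.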
I would be very grateful if someone could help me. I cannot work out where the problem is or how to fix it.
Thanks in advance.