如何使用 keras 预测数学进展

数据挖掘 机器学习 神经网络 喀拉斯 张量流 rnn
2022-02-26 17:02:45

我为多对多循环网络尝试以下模型:

                              [170]     [682]     [2730]
                                |         |         |
       o  -------  o  --------  o  -----  o  -----  o
     / | \      /  |  \      /  |  \
   [2, 2, 3], [5, 10, 11], [2, 42, 43]
     {t=1}       {t=2}       {t=3}      {t=4}     {t=5}  

模型应如下所示:x1* (x2+x3),其中 [ x1=constant ]。[x3=x2+1]。[ 上一次 {t-1} 的 x2=x1* (x2+x3) ]

此模型应在时间 t=4 时返回 [ x1* (x2+x3) ] 的结果;t=5;t=6。

好吧,由于这个测试模型没有给我正确的结果,我决定在 Keras 中尝试它,结果也是错误的。

我将代码留在 Keras 中,看看是否有人可以帮助我获得正确的结果并解释模型为什么不起作用。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.models import Sequential

# Data train
X = np.array([ [ [3,3,4], [3,21,22], [3,129,130] ],
               [ [2,2,3], [2,10,11], [2, 42, 43] ],
               [ [2,2,3], [2,10,11], [2, 42, 43] ],
               [ [4,4,5], [4,36,37], [4,292,293] ],
               [ [5,5,6], [5,55,56], [5,555,556] ]
              ],    dtype=int)

Y = np.array([ [[777 ], [4665 ], [27993 ] ],
               [[170 ], [682  ], [2730  ] ],
               [[170 ], [682  ], [2730  ] ],
               [[2340], [18724], [149796] ],
               [[5555], [55555], [555555] ]
              ], dtype=int)

# Model
def get_compiled_model():
  timesteps = 3 #times
  data_dim = 3  #features
  N = 3         #salidas
  model = Sequential()

  model.add(tf.keras.layers.LSTM(1, input_shape=(timesteps, data_dim)))
  model.add(tf.keras.layers.RepeatVector(N))
  model.add(tf.keras.layers.LSTM(1, return_sequences=True))  
  model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1)))
  model.add(tf.keras.layers.Activation('linear'))
  
  opt = keras.optimizers.Adam(learning_rate=0.9)
  model.compile(optimizer=opt, loss='MAE')
  return model

# main
model = get_compiled_model()
model.fit(X, Y, epochs=4000)

result = model.predict([[ [2,2,3], [2,10,11], [2,42,43] ]])
print(result)

事实是,我在 Keras 中使用了几种类型的层(密集层、SimpleRNN、LSTM)尝试了这个模型,但它们都没有奏效。

我留下损失曲线的图像

0个回答
没有发现任何回复~