简单 RNN 中的参数数量

数据挖掘 rnn
2022-02-20 04:18:56

拜托,我被卡住了,我无法理解简单 RNN 的参数数量,这里是示例和模型摘要。这个例子很简单:

x = np.linspace(0,50,501)
y= np.sin(x)
df= pd.DataFrame(data=y, index=x, columns=['Sinus'])

然后我会构建一个简单的 RNN 来预测这个正弦波,

test_percent = 0.1
test_point= np.round(len(df)*test_percent)
test_ind = int(len(df)-test_point)
train = df.iloc[:test_ind]
test = df.iloc[test_ind:]

from sklearn.preprocessing import MinMaxScaler
scaled_train = scaler.transform(train)
scaled_test = scaler.transform(test)

from tensorflow.keras.preprocessing.sequence import TimeseriesGenerator
generator = TimeseriesGenerator(scaled_train, scaled_train, length=50, batch_size=1)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN

n_features=1
model= Sequential()
model.add(SimpleRNN(units=50, input_shape=(50,1)))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')
model.summary()

模型总结

模型摘要中的数字2600我看不懂!因为根据 RNNS 的这种直觉 RNN 模型

在我的例子中,U 的维度是 (50,1),所以 50 个权重,然后对于 V,也是 50 个权重,对于 W,也是 50 个权重,对我来说,只有 150 个参数,为什么这个数字在摘要中:2600 个参数?

1个回答

Keras SimpleRNN 是一个完全连接的 RNN,因此实际上每个单元都与所有其他单元连接。所以方程变为:

  • (input_feature +1) x 单位 + 单位 x 单位
  • 2x50 +2500
  • +1 来自偏见

在此处输入图像描述