如何在 Keras 神经网络上定义权重

数据挖掘 Python 神经网络 喀拉斯
2022-02-19 04:23:28

我有一个用 Keras 编写的具有 (8,5,5,5,32) 个神经元的神经网络模型,如下所示:

# Sequential
model = Sequential()

# Neural network
model.add(Dense(5, input_dim=len(X[0]), activation='sigmoid' ))
model.add(Dense(5, activation='sigmoid' ))
model.add(Dense(5, activation='sigmoid' ))
model.add(Dense(len(y[0]), activation='sigmoid' ))


# Compile model
# sgd = optimizers.SGD(lr=0.01, decay=0.0, momentum=0.0, nesterov=False)
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['acc'])

# Fit model
history = model.fit(X, y, nb_epoch=200, validation_split=0.2, batch_size=30)

当我在训练后访问权重时,我得到:

model.get_weights()

它返回一个长数组,对应于每层神经元之间的权重。我不明白的是较小的数组表示(与每层的神经元数量相同的权重数量)在较大的数组(与每层的神经元数量相同的子阵列数量)之间。

这是神经元所经历的偏差吗?

以下是“model.get_weights()”的结果:

[array([[ 0.31680015,  0.22357693, -0.63079047, -0.04600599, -0.26949674],
        [ 0.0525099 , -0.41120723,  0.28259486, -0.37071031,  0.42028651],
        [ 0.35435981, -0.35501873, -0.05099263,  0.3633016 , -0.64845532],
        [ 0.60027206,  0.40594664,  0.29894602, -0.13255124,  0.52797431],
        [-0.32299024,  0.54219592,  0.34114835, -0.59672344, -0.47126439],
        [-0.51338726,  0.64451784, -0.35283062,  0.47248691, -0.31077194],
        [-0.59289241, -0.207461  ,  0.00371859, -0.52090681,  0.10946763],
        [-0.37216368, -0.23905358, -0.38580573,  0.0401655 , -0.34231418]], dtype=float32),
 array([-0.03448975,  0.04991768,  0.11635038, -0.04274927,  0.0325128 ], dtype=float32),   # WHAT IS THIS ARRAY REPRESENTING?
 array([[ 0.10943446, -0.24749899,  0.58269709,  0.54208171, -0.05888808],
        [ 0.02320727,  0.08465887, -0.79114383, -0.19608408,  0.55898732],
        [-0.81141329,  0.19124934, -0.69268334, -0.44021448,  0.72605485],
        [ 0.32895803,  0.08196118,  0.53820646,  0.6348688 , -0.06715827],
        [-0.0850288 ,  0.5077976 ,  0.36972848,  0.44874495,  0.36402631]], dtype=float32),
 array([-0.30626452,  0.15301916, -0.17855364,  0.12410269,  0.2502442 ], dtype=float32),
 array([[-0.31872895, -0.46534333, -0.4664084 , -0.23720025,  0.30465502],
        [-0.37690881, -0.00396255,  0.38115206,  1.20845091,  0.69348788],
        [ 0.15064301, -0.29923961,  0.13108611, -0.29579154, -0.34181508],
        [ 0.62893951,  0.49498206,  0.02549251,  0.6561147 , -0.52280194],
        [ 1.07029617,  0.66126752,  0.50944209,  0.58811921, -0.04030331]], dtype=float32),
 array([ 0.57747394,  0.59574115,  0.68391681,  0.78335029,  0.26046163], dtype=float32),
 array([[ -1.09961331e+00,  -7.06328213e-01,  -7.93772519e-01,
          -5.50112486e-01,  -5.05448937e-01,  -4.78618711e-01,
          -3.32313687e-01,  -5.46549559e-01,   3.88661295e-01,
           3.64094585e-01,  -1.93489313e-01,  -7.61669278e-02,
           1.23761639e-01,   4.93125141e-01,   4.78168607e-01,
          -7.07402304e-02,  -2.54306406e-01,  -2.37895012e-01,
          -1.36467636e-01,  -7.16407061e-01,   1.32701367e-01,
          -4.23079096e-02,  -4.71717492e-03,   2.56372184e-01,
           1.89603701e-01,   2.20276624e-01,  -2.55215704e-01,
          -1.04997739e-01,   2.81909049e-01,  -7.00806752e-02,
           2.99933195e-01,   3.84294897e-01],
        [ -9.60830688e-01,  -6.18087530e-01,  -8.46309483e-01,
          -6.05162561e-01,  -2.39057451e-01,   3.03133931e-02,
          -2.07459703e-01,  -6.84834659e-01,   3.47823203e-01,
          -6.44357502e-02,   2.55657077e-01,  -2.37801671e-01,
           2.57411227e-02,  -1.01771923e-02,  -6.76048512e-04,
           3.06296106e-02,   2.05646217e-01,   8.02281871e-02,
          -4.01538044e-01,  -5.49115181e-01,   3.00252885e-01,
           3.31445992e-01,  -1.75046876e-01,  -3.36513370e-01,
           1.65666446e-01,   1.74015135e-01,  -2.15066984e-01,
           3.79294544e-01,   1.67991996e-01,   2.39770293e-01,
          -5.49201332e-02,  -1.67401493e-01],
        [ -1.24525476e+00,  -5.97414970e-01,  -1.36500984e-01,
          -5.60880482e-01,  -4.26550537e-01,   1.46522000e-01,
          -6.26730978e-01,  -8.33723724e-01,  -2.20034972e-01,
           3.54697943e-01,   2.86612272e-01,   2.60758907e-01,
          -5.10771237e-02,   1.91444799e-01,   3.32518548e-01,
           1.51452944e-01,  -2.18744278e-01,  -2.07690187e-02,
           9.60563496e-02,  -2.26809219e-01,   8.80904198e-02,
           2.33646557e-01,   2.45599806e-01,   2.53560930e-01,
           1.55982673e-01,   6.49829209e-01,  -1.26019821e-01,
           5.47675073e-01,   3.28564644e-01,   8.67465809e-02,
          -1.40921310e-01,  -2.35581279e-01],
        [ -1.51756608e+00,  -1.11596704e+00,  -4.97624874e-01,
          -4.95555460e-01,  -4.83801186e-01,  -2.23367065e-01,
          -1.08115244e+00,  -9.68795598e-01,   4.32208836e-01,
           1.16957083e-01,  -1.02919623e-01,  -8.19303747e-03,
           4.21310306e-01,   1.09493546e-01,   1.54512182e-01,
          -1.46762207e-01,   1.58293337e-01,  -3.95552874e-01,
          -2.07770184e-01,  -1.90177709e-01,   2.07072627e-02,
           6.61122620e-01,   5.44478893e-01,  -1.46910429e-01,
           4.22070086e-01,   2.49319345e-01,   6.19665794e-02,
           9.74300727e-02,   3.37298632e-01,   2.90907085e-01,
           8.78930092e-02,   1.25872776e-01],
        [ -7.40014732e-01,  -7.02967405e-01,  -1.42469540e-01,
           1.66655079e-01,  -1.59682855e-01,  -2.07361296e-01,
          -9.04432237e-02,  -1.22986667e-01,  -3.28961462e-01,
           9.21192244e-02,  -2.13514805e-01,  -1.59033865e-01,
          -2.79709876e-01,  -1.64602488e-01,   1.96248814e-01,
          -1.98676869e-01,   2.80951142e-01,  -4.50290412e-01,
           2.34707281e-01,  -3.13370705e-01,  -2.24865358e-02,
           2.63352484e-01,   4.90205318e-01,   1.96813330e-01,
           4.13736820e-01,   7.11815134e-02,  -1.92510381e-01,
           5.77223562e-02,  -4.25750390e-02,  -3.79479416e-02,
           6.17611647e-01,   3.11740283e-02]], dtype=float32),
 array([-1.66324592, -1.4556272 , -1.10267663, -0.52273899, -0.59953606,
        -0.52498704, -1.15079761, -1.19500864,  0.13858946, -0.01602219,
         0.02289196,  0.2923668 ,  0.14240141,  0.19787838,  0.15867165,
         0.24876554,  0.21616565, -0.35774839, -0.40798596, -0.71499836,
         0.36547631,  0.48771939,  0.40948448,  0.01851768,  0.56893039,
         0.93250728, -0.32663229,  0.34419572,  0.15536141,  0.21145843,
         0.53275597,  0.0364211 ], dtype=float32)]
1个回答

你是对的!您在权重之间看到的数组是起始层中节点的偏差。请注意该数组中有 5 个元素,这与您的第一个隐藏层中的 5 个节点一致。