计算参数 keras 的数量

数据挖掘 机器学习 喀拉斯 美国有线电视新闻网
2022-03-04 19:11:16

我正在按照相同 链接上的 keras 教程在 keras 中实现 1D CNN 。建立模型后,当我执行 model.summary() 时,我得到以下输出。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 1000)              0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 1000, 100)         17410600  
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 996, 128)          64128     
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 199, 128)          0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 195, 128)          82048     
_________________________________________________________________
max_pooling1d_2 (MaxPooling1 (None, 39, 128)           0         
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 35, 128)           82048     
_________________________________________________________________
max_pooling1d_3 (MaxPooling1 (None, 1, 128)            0         
_________________________________________________________________
global_max_pooling1d_1 (Glob (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               16512     
_________________________________________________________________
dense_2 (Dense)              (None, 20)                2580      
=================================================================
Total params: 17,657,916
Trainable params: 247,316
Non-trainable params: 17,410,600
_________________________________________________________________
None

conv1d_1 的参数总数为 64128。但是由于 conv1d_1 是使用 filters = 128, kernel_size = 5, padding = 'valid' (这意味着没有填充)初始化的,所以参数的数量不应该是

=> kernel_size * kernel_size * num_filters + num_filters * 偏差

=> 5 * 5 * 128 + 128 * 1

=> 26 * 128

=> 3328

2个回答

实际上,您使用 1D 卷积。假设嵌入层输出的维度为 100,内核大小为 5,过滤器数量为 128,则您有 100x5x128 = 64000 个权重。加上这 128 个偏差,你得到 64128 个参数。

请注意,使用了参数共享,因此每个过滤器在深度上只有一组权重和偏差。

(5×100)+1 单个过滤器的参数数量,用于每个过滤器的偏置项,我们有 128 个这样的过滤器,因此总参数+1=501×128=64128