我将如何改进我的 CNN 模型(Keras)?

数据挖掘 机器学习 神经网络 深度学习 张量流 美国有线电视新闻网
2022-02-19 12:15:24

最近,我阅读了一篇关于使用面部图像进行年龄检测的研究论文。所以现在正因为如此,我正试图通过将 CNN 应用于面部图像数据集(以及他们各自的年龄)来预测他们的年龄(例如 0-10, 11-20、21-30...)。

用于训练和测试

training.shape (50000, 28, 28)
testing.shape (2938, 28, 28)

我试图保持图像小,因为它们可以运行得更快以及使用灰度。对于实际的图层本身,我尽量保持简单,现在,

model = Sequential()
model.add(Conv2D(64, kernel_size=3, activation='relu', input_shape=(28,28,1)))
model.add(Conv2D(32, kernel_size=3, activation='relu'))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))


Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 26, 26, 64)        640       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 32)        18464     
_________________________________________________________________
flatten_1 (Flatten)          (None, 18432)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                184330    
=================================================================
Total params: 203,434
Trainable params: 203,434
Non-trainable params: 0
_________________________________________________________________

编译如下

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

到目前为止,运行 100 个 epoch 后的最佳准确度是 37.16。这不是很好,但最近我可以访问我的一所学校的 gpu,所以我想修复我做错的任何事情并改进我的模型。在改进模型方面有什么可以推荐的吗?这可能是我第一次尝试这样做。

2个回答

您应该注意到的第一件事是您几乎破坏了输入信号。看一看28×28一张脸的图像?你能看见什么?青少年和中年人有什么区别吗?关键是网络应该使用没有高贝叶斯误差的数据进行训练,这意味着您作为专家可以区分输入并正确标记它们。增加输入的大小。通过这样做,如果你使用当前的机制,你可能在密集层和卷积层之间有很多可训练的参数。因此,尝试在其中使用更多的卷积层和一些池化层。此外,尝试添加更密集的层,每个层都有更多的神经元。

我想到了两件事,你可以快速尝试。

  1. 不要缩小图像,28x28 太小(对于类似 MNIST 的数据集很好,但对于人脸则不行)。此外,通过将它们设为灰度,您会丢失很多有价值的信息,请尝试使用彩色图像。
  2. 使用预训练的 CNN,keras 提供了其中的一些,我通常会使用 VGG16,因为它是一个简单的可重用网络。我的建议是冻结除最后一层之外的所有层,并查看您获得的性能(作为基准)。然后考虑解冻其他层以提高性能。

请尝试一下这些选项,根据 GPU,这对于 CNN 来说几乎是必须的,请注意 kaggle.com 现在正在为 jupyter notebooks 提供免费 GPU(它在 beta atm 上,但似乎可以完成这项工作很好)。