当数据集很小时,如何训练神经网络进行图像分类?

人工智能 深度学习 卷积神经网络 图像识别 数据集 迁移学习
2021-11-13 05:25:00

我需要训练一个卷积神经网络来对蛇图像进行分类。问题是我只有少量可用于某些蛇类型的图像。

那么,使用小数据集训练神经网络进行图像分类的最佳方法是什么?

2个回答

使用微调

您可以简单地在ImageNet上使用预训练模型,因为该数据集有多个蛇类。

然后,您可以使用自己的小数据集和输出微调模型。请参阅此以进一步了解: Fine Tuning in Keras

(如果你不使用 Keras,网上还有其他使用其他机器学习框架的教程)

这个想法只是删除最后一层(如果您使用使用 ImageNet 预训练的模型,则为 1000 个输出)并添加您选择的具有随机权重和自定义输出数量(您的类数)的层。

然后你重新训练你的网络,通常我们只重新训练最后一层(因为第一层有更多的一般特征)。

除了使用其他答案中描述的迁移学习外,您还应该考虑使用连体网络。这种类型的网络用于当一个人没有提出很多他想要区分的对象的例子的情况下。一般的想法是,不要“告诉”网络“这是一条眼镜蛇”,而是提供如下信息:“这是一条眼镜蛇,那是一条响尾蛇,学习区别”。

有一个专门针对您的问题的主题,称为一次性学习。

看看这个教程: https ://hackernoon.com/one-shot-learning-with-siamese-networks-in-pytorch-8ddaab10340e