如何在非常不平衡的数据集上提高神经网络的准确性?

数据挖掘 机器学习 深度学习
2022-03-02 01:10:24

我有一个数据集,其中包含有关事故的数据。该数据集包含大约 15.000 个条目,我无法获得更多。分布如下:

  • 88.6% 的数据为 1 级事故
  • 10.6% 的数据为 2 级事故
  • 0.8% 的数据为 3 级事故

如您所见,训练数据的最大部分属于一个类。我只有很少的 3 级事故示例(15.000 行中的大约 100 行),但正确分类 3 级事故是最重要的。

我在数据上训练了一个非常标准的深度神经网络,并在验证集上获得了约 93% 的准确度。我使用了一个带有 AdamOptimizer 的自定义 Tensorflow 估计器,并尽可能地调整了参数。问题是,网络仍然将大多数事故归类为 1 类事故。因此,如果我在验证集中有 25 个 3 类事故,网络会将其中 10 个错误分类为 1 类事故。我想改进它。

在这种情况下有什么方法可以提高性能吗?显而易见的选择是获取更多 3 级事故的数据,但遗憾的是这是不可能的。多次显示现有的 3 类数据是否有意义?例如,用所有数据训练 5 个 Epoch,然后只用 3 级事故训练 3 个额外的 Epoch?

或者我可以在数据预处理期间做些什么吗?现在我正在 MinMax-Scaling 输入数据以达到 [0, 1] 区间。有没有其他方法可以更多地强调异常值?(如果您假设异常值大多属于第 3 类)

我希望有人知道一些方法来提高这种情况下的准确性。

编辑:数据集主要有分类列,如:

  • 街道等级(例如高速公路或乡村道路)
  • 轻(例如“好”)
  • 天气(例如“下雨”或“晴天”)
  • ...

另外它有这些列:

  • 事故日期(仅月和日)
  • 年龄
  • 一天中的时间
  • 受伤人数
  • 车辆数量

所以一个条目可能看起来像这样:

{
 street_class: 'highway',
 light: 'daylight',
 date: '23. Jan',
 age: 59,
 injured_persons: 2,
 vehicles: 2,
 time: 1724,
 label: 1
 ...
}
1个回答

我会尝试某种数据增强,但是从您的问题来看,您并不清楚您拥有什么类型的数据,也无法提出解决方案。

尝试在您的问题中添加数据示例。