什么损失函数用于不平衡类(使用 PyTorch)?

数据挖掘 神经网络 火炬
2021-10-12 21:00:03

我有一个包含 3 个类的数据集,其中包含以下项目:

  • 第 1 类:900 个元素
  • 第 2 类:15000 个元素
  • 第 3 类:800 个元素

我需要预测第 1 类和第 3 类,这表明与规范有重大偏差。2 类是我不关心的默认“正常”情况。

我会在这里使用什么样的损失函数?我正在考虑使用 CrossEntropyLoss,但由于存在类不平衡,我想这需要加权吗?这在实践中如何运作?像这样(使用 PyTorch)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

还是应该把重量倒过来?即1 /重量?

这是开始的正确方法还是我可以使用其他/更好的方法?

谢谢

1个回答

我会在这里使用什么样的损失函数?

交叉熵是分类任务的首选损失函数,无论是平衡的还是不平衡的。当尚未从领域知识构建偏好时,它是首选。

我想这需要加权吗?这在实践中如何运作?

是的。班级重量C 是最大类的大小除以类的大小 C.

例如,如果第 1 类有 900 个,第 2 类有 15000 个,第 3 类有 800 个样本,那么它们的权重分别为 16.67、1.0 和 18.75。

您还可以使用最小的类作为提名者,分别给出 0.889、0.053 和 1.0。这只是重新缩放,相对权重是相同的。

这是开始的正确方法还是我可以使用其他/更好的方法?

是的,这是正确的方法。

编辑

感谢@Muppet,我们还可以使用类过采样,相当于使用类权重这是通过WeightedRandomSampler在 PyTorch 中使用相同的上述权重来完成的。