我有一个包含 3 个类的数据集,其中包含以下项目:
- 第 1 类:900 个元素
- 第 2 类:15000 个元素
- 第 3 类:800 个元素
我需要预测第 1 类和第 3 类,这表明与规范有重大偏差。2 类是我不关心的默认“正常”情况。
我会在这里使用什么样的损失函数?我正在考虑使用 CrossEntropyLoss,但由于存在类不平衡,我想这需要加权吗?这在实践中如何运作?像这样(使用 PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
还是应该把重量倒过来?即1 /重量?
这是开始的正确方法还是我可以使用其他/更好的方法?
谢谢