pytorch中的负学习实现

数据挖掘 损失函数 火炬
2022-01-27 18:11:24

我读过一篇关于消极学习的论文:https ://arxiv.org/abs/1908.07387 。这个想法是,您不仅可以通过告诉样本的标签是什么来训练网络,还可以通过告诉它肯定不是什么来训练网络。我们称后者为“负面”标签

该论文的摘录说(顶部公式用于通常的“正”标签损失(PL),底部 - 用于“负”标签损失(NL):

论文摘录

我有一个问题,“负面”标签收集比标记每个样本要容易得多。所以很想使用它。

在pytorch中是否有这种损失函数的实现?还是我应该编写自定义损失层代码?如果是这样,我该怎么做?

1个回答

一种实现是

( (loss+loss_neg) / (float((labels>=0).sum())+float((labels_neg[:,0]>=0).sum())) ).backward()

来自NLNL-Negative-Learning-for-Noisy-Labels GitHub 存储库