如何在pytorch中实现平方铰链损失

数据挖掘 美国有线电视新闻网 损失函数 火炬 火炬 铰链损失
2022-02-14 16:01:45

有没有人对如何实现这种损失以将其与卷积神经网络一起使用有任何建议?另外,我应该如何编码我的训练数据的标签?我们之前使用了一种带有 bce 损失的热编码,我在徘徊是否应该为铰链损失保持这种方式,因为标签本身没有用于损失的公式中,除了指示哪个是真正的类别. 顺便说一下,数据集是 CIFAR100。先感谢您!

编辑:我实现了这个损失的一个版本,问题是在第一个 epoch 之后损失总是为零,所以训练不会更进一步。这是代码:

class MultiClassSquaredHingeLoss(nn.Module):
    def __init__(self):
        super(MultiClassSquaredHingeLoss, self).__init__()

    def forward(self, output, y): #output: batchsize*n_class
        n_class = y.size(1)
        #margin = 1 
        margin = 1
        #isolate the score for the true class
        y_out = torch.sum(torch.mul(output, y)).cuda()
        output_y = torch.mul(torch.ones(n_class).cuda(), y_out).cuda()
        #create an opposite to the one hot encoded tensor
        anti_y = torch.ones(n_class).cuda() - y.cuda()
        
        loss = output.cuda() - output_y.cuda() + margin
        loss = loss.cuda()
        #remove the element of the loss corresponding to the true class
        loss = torch.mul(loss.cuda(), anti_y.cuda()).cuda()
        #max(0,_)
        loss = torch.max(loss.cuda(), torch.zeros(n_class).cuda())
        #squared hinge loss
        loss = torch.pow(loss, 2).cuda()
        #sum up
        loss = torch.sum(loss).cuda()
        loss = loss / n_class        
        
        return loss
1个回答

一种选择是使用现有的torch.nn.MultiMarginLoss对于平方损失,设置p=2