当预测已经有概率时,Pytorch 会进行交叉熵损失

数据挖掘 神经网络 损失函数 可能性 火炬 软最大
2021-09-25 03:22:04

因此,通常可以使用 PyTorch 中的交叉熵损失函数或通过将 logsoftmax 与负对数似然函数相结合来应用分类交叉熵,如下所示:

m  = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
pred   = torch.tensor([[-1,0,3,0,9,0,-7,0,5]], requires_grad=True, dtype=torch.float)
target = torch.tensor([4])
output = loss(m(pred), target)
print(output)

事情是。如果输出的数据已经处于具有变量 pred 已经具有概率的概率状态,该怎么办。数据显示如下:

pred = torch.tensor([[.25,0,0,0,.5,0,0,.25,0]], requires_grad=True, dtype=torch.float)

那么如何在 PyTorch 中完成交叉熵呢?

1个回答

您可以自己很容易地实现分类交叉熵。它计算为

cross-entropy=1ni=0nj=0myijlogy^ij

在哪里 n 是您批次中的样品数量, m 是类的数量, yi 例如是 one-hot 目标 i, y^i 是预测的概率分布,并且 yij 指的是 j- 这个数组的第一个元素。

在 PyTorch 中:

def categorical_cross_entropy(y_pred, y_true):
    y_pred = torch.clamp(y_pred, 1e-9, 1 - 1e-9)
    return -(y_true * torch.log(y_pred)).sum(dim=1).mean()

然后,您可以categorical_cross_entropyNLLLoss在模型训练中一样使用。我们有这torch.clamp条线的原因是为了确保我们没有零元素,这将导致torch.log产生nanor inf

您必须在代码中做出的一个区别是,此版本需要一个单热目标而不是整数目标。您可以像这样轻松转换当前的目标列表:

one_hot_targets = torch.eye(NUM_CLASSES)[targets]

其中targetsa 是torch.tensor整数值,NUM_CLASSES是您拥有的输出类的数量。