如何在pytorch中使用交叉熵损失进行二进制预测?

数据挖掘 深度学习
2021-09-17 08:51:27

在 pytorch 文档中,它说交叉熵损失:

输入必须是大小为 (minibatch, C) 的张量

这是否意味着对于二进制 (0,1) 预测,输入必须转换为第二维等于 (1-p) 的 (N,2) 张量?

因此,例如,如果我预测目标为 1(真)的类的值为 0.75,我是否必须将两个值(0.75;0.25)堆叠在一起作为输入?

3个回答

实际上没有必要这样做。PyTorch 具有代表二元交叉熵损失的 BCELoss。请在此处查看原始文档这是一个简单的例子:

m = nn.Sigmoid() # initialize sigmoid layer
loss = nn.BCELoss() # initialize loss function
input = torch.randn(3, requires_grad=True) # give some random input
target = torch.empty(3).random_(2) # create some ground truth values
output = loss(m(input), target) # forward pass
output.backward() # backward pass

在下面给出的示例中,3 是批量大小,2 将是给定示例中每个类的概率。

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 2, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(2)
output = loss(input, target)

如果您的输入可以属于两个不同类别之一,您的网络必须有两个最终神经节点然后目标张量必须是形状(小批量)并且只包含零和一。

通常,您的网络必须输出形状为 (minibatch, C) 的张量,其中 C 是您的数据可以分类的类别数。目标张量必须是形状 (minibatch) 并且仅包含 long 类型的数字,即集合 {0, ..., C-1} 的元素。

也许是一个现实世界的例子:你的网络充满了狗、猫和猪的图片。因此它有 3 个最终节点,它们(通常)表示输入图像显示狗(0 类)、猫(1 类)或猪(2 类)的概率。假设您一次将 10 张图像放入网络(小批量 = 10)。那么你的目标张量可能是:

torch.LongTensor([0, 2, 1, 0, 1, 0, 2, 2, 1, 0, 0, 1])

这可以解释为您的批次的第一张图片显示一只狗,第二张图片显示一只猪,第三张图片显示一只猫,......