pytorch中的批处理CrossEntropyLoss

数据挖掘 逻辑回归 损失函数 火炬
2022-02-16 19:20:04

我想知道如何使用 pytorch 内置插件来实现这一点。我有一个称为策略的 3 维 uint 输入。大多数条目为零,如果我要对其进行 L1 标准化,我将有一个(目标)概率分布。

我还得到了一个线性层的输出,称为“logit”,其形状与“policy”相同。我必须通过采用 softmax 将其转换为概率分布,但仅限于 policy 为 non-zero 的条目

然后损失为 -sum(log(logit_masked_softmax) * policy_normalized))

我已经使用布尔索引使用 nn.functional 模块手动实现了这一点。问题是我想分批执行此操作,其中 4 维张量代表 3 维输入的批次。我相信必须有一种内置的方法来实现这一点,而且它可能也更快,数值更稳定。

1个回答

我通过使用 torch.where 给不相关的条目一个非常大的负值来解决这个问题,以便它们在 exp 之后消失。我还利用了 CEL 技巧的日志。

loss = torch.zeros(1).type(dtype)

states = torch.cat([states_[_] for _ in idc[:self.batch_size]]).type(dtype)
scores = torch.tensor([scores_[_] for _ in idc[:self.batch_size]]).type(dtype).view(-1, 1)
policies = torch.cat([policies_[_] for _ in idc[:self.batch_size]]).type(dtype)

values, logits = self.network(states)
value_loss = mse(values, scores)

p = torch.nn.functional.normalize(policies, dim=[1,2,3], p=1)
logit_exp = torch.exp(torch.where(p > 0, logits, inf))
s = torch.log(torch.sum(logit_exp, dim=[1,2,3])).view(-1, 1, 1, 1)
log_q = logits - torch.mul(ones, s)
policy_loss = torch.sum(torch.mul(p, -log_q))/self.batch_size

loss = value_loss + policy_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
idc = idc[self.batch_size:]