I solved this by using torch.where to give the irrelevant entries a very large negative value, so that they vanish after the exp. I also made use of the log trick of the cross-entropy loss (CEL).
loss = torch.zeros(1).type(dtype)
# assemble a minibatch of states, target scores, and target policies from the sampled indices
states = torch.cat([states_[_] for _ in idc[:self.batch_size]]).type(dtype)
scores = torch.tensor([scores_[_] for _ in idc[:self.batch_size]]).type(dtype).view(-1, 1)
policies = torch.cat([policies_[_] for _ in idc[:self.batch_size]]).type(dtype)
# forward pass: the network returns value estimates and raw policy logits
values, logits = self.network(states)
value_loss = mse(values, scores)
# normalize the target policies to a probability distribution (L1 norm over the policy dimensions)
p = torch.nn.functional.normalize(policies, dim=[1,2,3], p=1)
# mask irrelevant entries: `inf` holds a very large negative value, so these entries vanish after exp
logit_exp = torch.exp(torch.where(p > 0, logits, inf))
# log of the partition sum taken over the valid entries only
s = torch.log(torch.sum(logit_exp, dim=[1,2,3])).view(-1, 1, 1, 1)
# log-softmax over the valid entries; s is broadcast to the shape of logits
log_q = logits - torch.mul(ones, s)
# cross entropy between the target policy p and the predicted policy q, averaged over the batch
policy_loss = torch.sum(torch.mul(p, -log_q))/self.batch_size
loss = value_loss + policy_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
idc = idc[self.batch_size:]
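An equivalent and numerically more stable way to get the same masked log-softmax is torch.logsumexp, which avoids forming the exponentials explicitly. Below is a minimal sketch; the function name masked_policy_loss and the neg_fill constant are illustrative, and it assumes the policy tensors have shape (batch, C, H, W) with p already normalized as above.

import torch

def masked_policy_loss(logits, p, neg_fill=-1e9):
    # push the logits of entries with zero target probability to a very large
    # negative value, mirroring the torch.where trick above
    masked_logits = torch.where(p > 0, logits, torch.full_like(logits, neg_fill))
    # logsumexp computes log(sum(exp(...))) without overflow
    log_z = torch.logsumexp(masked_logits, dim=(1, 2, 3), keepdim=True)
    # log-softmax restricted to the valid entries
    log_q = logits - log_z
    # cross entropy with the target policy, averaged over the batch
    return torch.sum(p * -log_q) / logits.shape[0]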