REINFORCE 算法的更新 - 逐步或逐集?

人工智能 强化学习 加强
2021-11-15 12:13:16

REINFORCE 是一种蒙特卡洛策略梯度算法,它通过生成情节来更新策略网络的权重(参数)。这是 Sutton 书中的伪代码(与 Silver 的 RL 注释中的方程式相同):

加强

当我尝试用自己的问题来实现这一点时,我发现了一些奇怪的东西。以下是Pytorch 官方 GitHub 上的实现:

def finish_episode():
    R = 0
    policy_loss = []
    returns = []
    for r in policy.rewards[::-1]:
        R = r + args.gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    for log_prob, R in zip(policy.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_log_probs[:]

感觉上面这两个还是有区别的。在 Sutton 的伪代码中,算法更新θ一步 t,而第二个代码(PyTorch 的那个)累积损失并更新θ加上总和,即在每一集之后我试图搜索 REINFORCE 的其他实现,我发现大多数实现都遵循第二种形式,在每个生成的剧集之后更新。

为了检查两者是否给出相同的结果,我将第二个代码更改为

def finish_episode():
    R = 0
    policy_loss = []
    returns = []
    for r in policy.rewards[::-1]:
        R = r + args.gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)
    for log_prob, R in zip(policy.saved_log_probs, returns):
        optimizer.zero_grad()
        loss = -log_prob * R
        loss.backward()
        optimizer.step()

...

并运行它,它会给出不同的结果(如果我的代码没有问题)。所以它们是不一样的,我认为最后一个更接近于REINFORCE的原始伪代码。我现在缺少什么?可以因为结果大致相同吗?(我不确定这个说法)

但是,从某种意义上说,我认为 Pytorch 的实现是 REINFORCE 的正确版本。在萨顿的伪代码中,情节是首先生成的,所以我认为θ不应在每一步更新,应在计算总损失后更新。如果θ在每一步都更新,那么这样θ可能和原版不一样θ用于生成剧集的。

1个回答

嗨 Seewoo Lee,欢迎来到我们的社区!

您观察的本质是 Sutton 版本的 REINFORCE 考虑了所有轨迹来计算回报,而在 pytorch 版本中只考虑未来,因此反向计算未来奖励并忽略以前的奖励. 结果是未来的行动不会因为早期的错误而受到惩罚。OpenAI 的人将此称为“reward-to-go”,但就我个人而言,我发现它类似于 Sutton 书中的 Monte Carlo On Policy Control without Exploring Starts 或 First Visit。

您可以在Spinning Up RL:第 3 部分:策略梯度简介 - 不要让过去分散您的注意力中找到更多关于 REINFORCE 和 Policy Gradient 的信息

另外,需要注意的是,即使在 Sutton 的版本中,整个轨迹都是展开的,即情节完成,然后权重得到更新。否则,它不再是蒙特卡洛方法,而是成为 TD 方法。另外,由于采样是不可微的操作,因此不能对单个点进行更改,而是通过收集大量轨迹来估计梯度。