优先重放,重要性采样到底做了什么?

数据挖掘 强化学习
2021-09-15 03:12:21

我无法理解Prioritized Replay (page 5)中重要性采样权重 (IS) 的目的。

转换的“成本”越大,就越有可能从经验回放中采样。我的理解是,“IS”有助于在我们训练足够长的时间后顺利放弃使用优先回放。 但是我们改用什么,统一采样?

我想我无法意识到这样一个系数中的每个分量是如何影响结果的。有人可以用文字解释吗?

wi=(1N1P(i))β

然后它被用来抑制渐变,我们试图从过渡中获得。

在哪里:

  • wi 伊斯兰国”
  • N 是体验重放缓冲区的大小
  • P(i) 是选择转移的机会 i,取决于“它的成本有多胖”。
  • β 从 0.4 开始,随着每个新的 epoch 越来越接近 1。

我对这些参数的理解是否也正确?

编辑在接受答案后的某个时候,我发现了一个额外的来源,一个可能对初学者有帮助的视频 - MC Simmmulations: 3.5 Importance Sampling


编辑正如@avejidah 在对他的回答的评论中所说的“1/N 用于通过样本被采样的概率对样本进行平均”

要了解为什么它很重要,假设β固定为 1,我们有4 个样本,每个样本都有P(i)如下:

0.1  0.2   0.3     0.4

也就是说,第一个条目有 10% 被选中,第二个是 20%,依此类推。现在,反转它们,我们得到:

 10   5    3.333   2.5

通过平均1/N(在我们的例子中是1/4) 我们得到:

2.5  1.25  0.8325  0.625     ...which would add up to '5.21'

正如我们所看到的,它们比简单的反转版本更接近于零(10,5,3.333,2.5)。这意味着我们的网络的梯度不会被放大太多,从而在我们训练我们的网络时导致更少的方差。

所以,没有这个1N我们是否幸运地选择了最不可能的样本(0.1),梯度将被缩放 10 倍。更小的值会更糟,比如说0.00001机会,如果我们的经验回放有数千个条目,这是很常见的。

换句话说,1N只是为了让你的超参数(例如学习率)不需要调整,当你改变你的经验回放缓冲区的大小时。

2个回答

DQN 本质上存在不稳定性。在最初的实现中,采用了多种技术来提高稳定性:

  1. 目标网络与落后于训练模型的参数一起使用;
  2. 奖励被限制在 [-1, 1] 范围内;
  3. 渐变被裁剪到 [-1, 1] 范围内(使用 Huber Loss 或渐变裁剪);
  4. 与您的问题最相关的是,使用大型重播缓冲区来存储转换。

继续第 4 点,使用来自大型重放缓冲区的完全随机样本有助于解相关样本,因为它同样可能从过去的数十万个情节中采样转换,因为它对新的情节进行采样。但是,当优先级抽样被添加到混合中时,纯粹的随机抽样就被抛弃了:显然存在对高优先级样本的偏见。为了纠正这种偏差,与高优先级样本相对应的权重几乎没有调整,而与低优先级样本相对应的权重保持相对性不变。

直觉上这应该是有道理的。具有高优先级的样本可能会在训练中多次使用。减少这些常见样本的权重基本上告诉网络,“对这些样本进行训练,但没有过多强调;它们很快就会再次出现。” 相反,当看到一个低优先级的样本时,IS 权重基本上告诉网络,“这个样本很可能永远不会再被看到,所以完全更新。” 请记住,这些低优先级样本无论如何都具有低 TD 误差,因此可能没有太多可以从中学到的东西;但是,出于稳定性目的,它们仍然很有价值。

在实践中,beta 参数在训练期间被退火到 1。可以同时对 alpha 参数进行退火,从而使优先采样更积极,同时更强烈地校正权重。在实践中,从您链接的论文中,保持固定的 alpha (.6) 同时将 beta 从 0.4 退火到 1 似乎是基于优先级的采样的最佳选择(第 14 页)。

作为旁注,根据我自己的个人经验,简单地忽略 IS 权重(即根本不纠正)会导致网络一开始训练得很好,但随后网络似乎过度拟合,忘记了所学的内容(也称为灾难性遗忘) ,和坦克。例如,在 Atari Breakout 中,平均在前 5000 万帧左右的帧中增加,然后平均值完全下降。您链接的论文对此进行了一些讨论,并提供了一些图表。

我有个疑问。作为 PER 论文,

出于稳定性原因,我们总是将权重标准化 1/maxi wi,这样它们只会向下缩放更新

那么1/N因子不是失效了吗?例如,考虑最后一个样本,

case 1 without N : 0.25/10 = 0.25
case 2 with N=4; 0.625/2.5 = 0.25.

所以,

Wi = pow(N,-beta) * pow(Pi, -beta)
Wmax = pow(N,-beta) * pow(Pmin,-beta)

通过规范化,

Wi/Wmax will cancel out the pow(N, -beta).

如果我的理解有误,请帮助我。