Nesterov 加速梯度的顺序正确吗?

数据挖掘 梯度下降 反向传播
2022-02-15 23:06:24

我已经实现了 Nesterov Accelerated Gradient (NAG)(链接,部分:“Nesterov Accelerated Gradient”),但不确定它是否可行。

vt=γvt1+ηθJ(θγvt1)
θ=θvt

我尝试了以下操作(考虑一个仅包含 1 个训练案例的小批量):

  1. 运行前向传递,并得到错误
  2. 将缩放的先前动量 ( momentum_tMinOne * coeff) 应用于权重
  3. 使用每次传递的错误运行反向传递。
  4. 设置momentum_tMinOne = currentGradient,但不要将其应用于权重
  5. 重复前传,从 1 开始。

我是否需要记住类似的东西momentum_tMinTwo,或者仅仅保持以前的动力就足够了?

我是否正确,这样我们不需要为动量保留两个矩阵,而只需要momentum_tMinOne矩阵?

1个回答

标准动量具有以下步骤:

  1. 马上,重新计算新的动量:

    μt+1=μt(decayScalar)+(learnRate)

  2. 通过这个新动量调整权重

    θt+1:=θtμt+1


涅斯捷罗夫动量是这样的:

  1. 大跃进:任意校正权重μ到目前为止,我们拥有:

    θt+1:=θtμ(decayScalar)

  2. 计算梯度从新的权重θt+1

  3. 通过这个梯度校正这些权重(现在没有任何动量):

    θt+2:=θt+1(learnRate)

  4. 最后,重新计算动量如下:

    μ:=θt+2θt

因此,动量在最后更新。它从“大跳跃前的权重”变成一个向量,指向“新鲜梯度校正后的权重”。

参考:Geoffrey Hinton Lecture 6C Corsera


重新安排:

为了避免将梯度计算停留在优化器函数的中间(步骤 2 和 3),我们可以重新安排如下:

  1. 计算我们目前拥有的权重的梯度。

  2. 通过梯度校正这些权重(现在没有任何动量),如下所示:

    θt+1:=θt(learnRate)

  3. 更新动量:

    μ:=θt+1θcached
    θcached:=θt+1

  4. 大跳跃

    θt+2:=θt+1μ(decayScalar)

请注意,这样第 2、3、4 步都在我们的优化器中。我们可以在优化器之外(在步骤 1 中)计算梯度,使我们的代码更具可读性:)

    size_t _numApplyCalled = 0;
      
    //Nesterov Accelerated Gradient.
    //Placed at the end of a backprop, should be followed by a usual forward-propagation
    // https://datascience.stackexchange.com/a/26395/43077
    //
    void apply( float *toChange,  float *newGrad,  size_t count ){

        float learnRate = get(OptimizerVar::LEARN_RATE);//scalar
        float momentumCoeff = get(OptimizerVar::MOMENTUM_1);//scalar

        const bool isFirstEver_apply = _numApplyCalled == 0;

        for (int i=0; i<_arraySize; ++i){
                //correction by gradient alone:
                toChange[i]  -= newGrad[i]*learnRate;

                // determine momentum:
                if (isFirstEver_apply){//nothing was cached yet.
                    _momentumVals[i] = 0.0f;
                }
                else {
                    _momentumVals[i] = toChange[i] - _weightsCached[i];
                }
                //caching, AFTER momentum calc, but BEFORE the jump:
                _weightsCached[i] = toChange[i];
            
                //jump:
                toChange[i] -= _momentumVals[i] * momentumCoeff;
        }//end for

        ++_numApplyCalled;//increments by 1
    }