缩放渐变有什么作用?

数据挖掘 机器学习 机器学习模型 深思熟虑
2021-10-07 01:38:47

MuZero 论文伪代码中,它们具有以下代码行:

hidden_state = tf.scale_gradient(hidden_state, 0.5)

这是做什么的?为什么会在那里?

我已经搜索过tf.scale_gradient,它在 tensorflow 中不存在。而且,与 不同scalar_loss的是,他们似乎没有在自己的代码中定义它。

对于上下文,这是整个函数:

def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
                   weight_decay: float):
  loss = 0
  for image, actions, targets in batch:
    # Initial step, from the real observation.
    value, reward, policy_logits, hidden_state = network.initial_inference(
        image)
    predictions = [(1.0, value, reward, policy_logits)]

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
      value, reward, policy_logits, hidden_state = network.recurrent_inference(
          hidden_state, action)
      predictions.append((1.0 / len(actions), value, reward, policy_logits))

      # THIS LINE HERE
      hidden_state = tf.scale_gradient(hidden_state, 0.5)

    for prediction, target in zip(predictions, targets):
      gradient_scale, value, reward, policy_logits = prediction
      target_value, target_reward, target_policy = target

      l = (
          scalar_loss(value, target_value) +
          scalar_loss(reward, target_reward) +
          tf.nn.softmax_cross_entropy_with_logits(
              logits=policy_logits, labels=target_policy))

      # AND AGAIN HERE
      loss += tf.scale_gradient(l, gradient_scale)

  for weights in network.get_weights():
    loss += weight_decay * tf.nn.l2_loss(weights)

  optimizer.minimize(loss)

缩放渐变有什么作用,为什么他们在那里这样做?

2个回答

本文的作者在这里 - 我错过了这显然不是 TensorFlow 函数,它相当于 Sonnet 的scale_gradient或以下函数:

def scale_gradient(tensor, scale):
  """Scales the gradient for the backward pass."""
  return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)

鉴于它的伪代码?(因为它不在 TF 2.0 中)我会使用梯度裁剪批量归一化(“激活函数的缩放”)