In the MuZero paper's pseudocode, they have the following line of code:

```python
hidden_state = tf.scale_gradient(hidden_state, 0.5)
```
What does this do, and why is it there?
I have searched for tf.scale_gradient, and it does not exist in TensorFlow. And, unlike scalar_loss, they don't seem to have defined it anywhere in their own code.
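If I had to guess, it's a small identity-like helper built on tf.stop_gradient, something like the sketch below. This is purely my own reconstruction, not anything from their code; the name and signature are just matched to how they call it:

```python
import tensorflow as tf

def scale_gradient(tensor, scale):
  """Identity on the forward pass; multiplies the gradient by `scale` on the backward pass."""
  # Only the first term participates in backprop, so the gradient is scaled by `scale`,
  # while the forward value stays tensor * scale + tensor * (1 - scale) == tensor.
  return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)
```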
For context, here is the entire function:
```python
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
                   weight_decay: float):
  loss = 0
  for image, actions, targets in batch:
    # Initial step, from the real observation.
    value, reward, policy_logits, hidden_state = network.initial_inference(
        image)
    predictions = [(1.0, value, reward, policy_logits)]

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
      value, reward, policy_logits, hidden_state = network.recurrent_inference(
          hidden_state, action)
      predictions.append((1.0 / len(actions), value, reward, policy_logits))

      # THIS LINE HERE
      hidden_state = tf.scale_gradient(hidden_state, 0.5)

    for prediction, target in zip(predictions, targets):
      gradient_scale, value, reward, policy_logits = prediction
      target_value, target_reward, target_policy = target

      l = (
          scalar_loss(value, target_value) +
          scalar_loss(reward, target_reward) +
          tf.nn.softmax_cross_entropy_with_logits(
              logits=policy_logits, labels=target_policy))

      # AND AGAIN HERE
      loss += tf.scale_gradient(l, gradient_scale)

  for weights in network.get_weights():
    loss += weight_decay * tf.nn.l2_loss(weights)

  optimizer.minimize(loss)
```
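To sanity-check my reconstruction, I ran a toy example with TF 2.x's GradientTape (again, this assumes my scale_gradient sketch from above, not their actual helper):

```python
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
  y = scale_gradient(x, 0.5) ** 2  # forward value is unchanged: x**2 == 9.0
print(y.numpy(), tape.gradient(y, x).numpy())  # 9.0 3.0 (without scaling: 2*x == 6.0)
```

So the forward computation is untouched and only the backward signal is halved, which at least matches how the pseudocode uses it.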
What does scaling the gradient do, and why are they doing it there?