梯度计算

数据挖掘 火炬
2022-03-09 03:55:43

我是数据科学的初学者。我正在尝试使用自定义 autograd 函数来理解这个用于梯度计算的 PyTorch 代码:

class MyReLU(torch.autograd.Function):

@staticmethod
def forward(ctx, x):

    ctx.save_for_backward(x)
    return x.clamp(min=0)

def backward(ctx, grad_output):

    x, = ctx.saved_tensors
    grad_x = grad_output.clone()
    grad_x[x < 0] = 0
    return grad_x

但是,我不明白这一行:grad_x[x < 0] = 0谁能解释这部分?

1个回答

您找到的示例是计算ReLU函数的梯度,其梯度为

ReLU(x)={1if x>00if x<0

因此,当 时x<0,您将梯度设为 0 grad_x[x < 0] = 0