为什么 Pytorch 中的 batchnorm1d 使用以下示例(2 行代码)计算 0?

数据挖掘 火炬 批量标准化
2022-02-28 09:27:27

这是代码

import torch
import torch.nn as nn
x = torch.Tensor([[1, 2, 3], [1, 2, 3]])
print(x)
batchnorm = nn.BatchNorm1d(3, eps=0, momentum=0)
print(batchnorm(x))

这是打印的内容

tensor([[1., 2., 3.],
        [1., 2., 3.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]], grad_fn=<NativeBatchNormBackward>)

我期待的是以下内容:

使用手工计算,让x=(1,2,3), 然后E(x)=(1+2+3)/3=2Var(x)=(12+22+32)/3(2)2=0.9999...,所以最终输出看起来像y(1,2,3)2/1=(1,0,1)

所以,我期望批处理规范的输出是

tensor([[-1., 0., 1.],
        [-1., 0., 1.]])

有人可以解释我哪里出错了吗?

1个回答

BatchNorm 沿 dim = 0 工作。您可能想LayerNorm改用。