批量归一化如何计算训练后的总体统计信息?

机器算法验证 机器学习 神经网络 深度学习 卷积神经网络 批量标准化
2022-04-07 00:45:55

我正在阅读批量标准化 (BN) 论文(1),它说:

为此,一旦网络经过训练,我们使用人口而不是 mini 的标准化\ hat { -批次,统计。

x^=xE[x]Var[x]+ϵ

我的问题是,它如何计算这个人口统计数据以及在什么训练集(测试、验证、训练)上?我以为我知道这意味着什么,但一段时间后,我意识到我不确定它是如何计算的。我假设它试图估计真实的均值和方差,尽管我不确定它是如何做到的。我可能会根据整个数据集计算均值和方差,并使用这些时刻进行推理。

然而,让我怀疑我错的是他们在同一节后面关于无偏方差估计的讨论:

我们使用无偏方差估计期望超过训练的mini-batches是它们的样本方差。Var[x]=mm1EB[σB2]mσB2

由于我们谈论的是人口统计数据,因此对这篇论文的评论感觉就像(对我而言)不知从何而来,并且不确定他们在说什么。他们只是(随机)澄清他们在训练期间使用无偏估计还是使用无偏估计来计算总体统计?


1:Ioffe S. 和 Szegedy C. (2015),
“批量标准化:通过减少内部协变量偏移来加速深度网络训练”,
第 32 届机器学习国际会议论文集,法国里尔,2015 年
。机器学习研究杂志: W&CP 第 37 卷

2个回答

通常,人口统计数据来自训练集。如果您包含测试集,则在测试时,您将获得技术上不应该访问的信息(有关整个数据集的信息)。出于同样的原因,验证集不应用于计算这些统计数据。

请记住,由于批标准化不仅在输入层,随着网络学习和更改其参数(因此,其在每一层的输出),总体的统计数据将因时代而异。

因此,计算这些统计数据的常用方法是在训练期间保持(指数衰减或移动)平均值。这将消除由于小批量训练引起的随机变化,并保持最新的学习状态。您可以在批处理规范的火炬代码中看到一个示例:https ://github.com/torch/nn/blob/master/lib/THNN/generic/BatchNormalization.c#L22

该论文提到他们使用移动平均值而不是仅保留最后计算的统计数据:

相反,使用移动平均线,我们可以在模型训练时跟踪模型的准确性。

对于您的第二个问题,他们说他们使用该无偏估计来估计总体方差(以供将来推断)。

在推理中,即使您的批量大小不是一个,批量规范也基于整个训练集(使推理稳定)。