Batch Normalization 如何以及为什么在训练时使用移动平均值来跟踪模型的准确性?

机器算法验证 机器学习 神经网络 卷积神经网络 批量标准化
2022-01-19 01:10:40

我正在阅读批量归一化 (BN) 论文(1),但不明白需要使用移动平均线来跟踪模型的准确性,即使我接受这是正确的做法,我也不明白他们到底在做什么。

据我了解(这可能是错误的),该论文提到一旦模型完成训练,它就会使用总体统计数据而不是小批量统计数据。在讨论了一些无偏估计之后(这对我来说似乎是切线的,我不明白为什么它会谈论这个),他们会说:

相反,我们使用移动平均线来跟踪模型训练时的准确性。

这是让我感到困惑的部分。为什么他们使用移动平均线来估计模型的准确性以及在哪些数据集上?

通常人们会做什么来估计他们的模型的泛化,他们只是跟踪他们的模型的验证错误(并可能提前停止他们的梯度下降以进行正则化)。然而,批量标准化似乎正在做一些完全不同的事情。有人可以澄清什么以及为什么它在做不同的事情吗?


1 : Ioffe S. 和 Szegedy C. (2015),
“批标准化:通过减少内部协变量偏移来加速深度网络训练”,
第 32 届机器学习国际会议论文集,法国里尔,2015 年
。机器学习研究杂志: W&CP 第 37 卷

3个回答

使用 batch_normalization 时,我们首先要了解的是,它在Training 和 Testing中以两种不同的方式工作。

  1. 在训练中,我们需要计算小批量平均值以标准化批次

  2. 在推理中,我们只应用预先计算的小批量统计信息

所以第二件事是如何计算这个小批量静态

移动平均线来了

running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var

他们正在谈论批量标准化,他们已经描述了训练过程,但没有描述推理。

这是使用样本均值等对隐藏单元进行归一化的过程。

在本节中,他们解释了在您刚刚进行预测时(即训练完成后)在推理阶段要做什么。

但是,在停止验证中,您将验证集的预测与训练交织在一起,以估计您的验证错误。

因此,在此过程中,您没有总体平均值(在您训练时平均值仍在变化),因此您使用运行平均值来计算批范数参数来计算验证集的性能。

正是在这个意义上

相反,我们使用移动平均线来跟踪模型训练时的准确性。

与字面上使用运行方式作为神经网络性能的指标无关。

在您引用的论文中,建议的测试时间行为是使用大量训练图像而不是使用运行平均值来计算每个特征的样本均值和方差。

这段代码

 running_mean = momentum * running_mean + (1 - momentum) * sample_mean
 running_var = momentum * running_var + (1 - momentum) * sample_var

代表了一种测试时间的替代方法,它不需要论文中所需的额外估计步骤。对于替代移动平均线,我们只是使用基于动量参数的指数衰减模型来更新均值和方差。