保持批量大小为 2 的幂有什么好处?

数据挖掘 机器学习 训练
2021-09-22 20:42:32

在机器学习中训练模型时,为什么有时将批量大小保持为 2 的幂是有利的?我认为最好使用最适合您的 GPU 内存/RAM 的大小。

这个答案声称对于某些包,2 的幂作为批量大小更好。有人可以为此提供详细解释/链接到详细解释吗?这适用于所有优化算法(梯度下降、反向传播等)还是仅其中一些?

4个回答

这是虚拟处理器 (VP) 与 GPU 的物理处理器 (PP) 对齐的问题。由于 PP 的数量通常是 2 的幂,因此使用与 2 的幂不同的 VP 数量会导致性能不佳。
您可以将 VP 到 PP 的映射视为一堆大小为 PP 数量的切片。
假设你有 16 PP。
您可以在它们上映射 16 个 VP:1 个 VP 映射到 1 个 PP。
您可以在它们上映射 32 VP:2 片 16 VP,1 PP 将负责 2 VP。
等等。在执行过程中,每个 PP 将执行他负责的第一个 VP 的作业,然后执行第二个 VP 的作业等。
如果使用 17 个 VP,每个 PP 将执行其第一个 PP 的作业,那么 1 个 PP 将执行执行第 17 项的工作其他的什么都不做(具体如下)。
这是由于GPU 使用的 SIMD 范式(在 70 年代称为矢量)。这通常被称为数据并行:所有的 PP 在同一时间做同样的事情,但在不同的数据上。这里
更准确地说,在有 17 个 VP 的示例中,一旦第一个切片的工作完成(所有 PP 都在做他们第一个 VP 的工作),所有的 PP 将做同样的工作(第二个 VP),但只有一个有一些要处理的数据
跟学习没关系。这只是编程的东西。

总体思路是将您的 mini-batch 完全安装在 CPU/GPU 中。由于所有 CPU/GPU 都具有 2 的幂的存储容量,因此建议将 mini-batch 大小保持为 2 的幂。

我刚刚在 coco 上运行了一个快速实验训练yolov4-csp,批处理大小为 8 和 9,发现每张图像,批处理大小 9 比 8 稍微高效一些。所以至少在现代 GPU 上使用 pytorch 和相对较小的批处理(2080Ti ) 似乎不使用 2 的幂作为批量大小不会对性能产生负面影响。

测量:

batch=9
4009/13143 batches
real    20m51.557s
per batch time = (20*60 + 51.557)/4009 = 0.312186829633325 seconds
per image time = 0.312186829633325 / 9 = 0.0347

batch=8
5037/14786 batches
real    23m51.666s
per batch time = (23*60 + 51.666)/5037 = 0.28422989874925547 seconds
per image time = 0.28422989874925547 / 8 = 0.0355 seconds

查看此实验数据,了解每个样本的平均预测速度与批量大小的关系。它非常强调了 jcm69 的公认答案的要点。看起来这个特定模型(及其输入)在批量大小为 32 的倍数时效果最佳 - 请注意位于主点线下方的稀疏点线。对于其他模型-GPU 组合,这可能会有所不同,但对于任何组合来说,2 的幂都是安全的选择。

在此处输入图像描述

不幸的是,ezekiel 的基准测试并不是很能说明问题,因为 9 的批处理大小可能会分配两倍的内存。因此,批量大小为 9 比批量大小为 8 更快是意料之中的。