为什么批处理大小受 RAM 限制?

数据挖掘 机器学习 训练
2022-02-25 08:48:42

更改网络的参数以最小化小批量的损失,但通常小批量的损失只是每个数据上单独损失的(加权)总和。松散地,我将其表示为

dT=1batch_sizeibatchdTi

在哪里dT是批次的净参数的更新,并且dTi仅用于一个训练示例。为什么不能dT然后“在线”计算,其中唯一需要的 RAM 是部分总和dT和任何一个dTi你在那一刻工作?

1个回答

与您描述的类似的东西在某些领域经常使用,它被称为梯度累积通俗地说,它包括在不更新权重的情况下计算几个批次的梯度,并且在 N 个批次之后,聚合梯度并应用权重更新。

这当然允许使用大于 GPU 内存大小的批量大小。

对此的限制是至少一个训练样本必须适合 GPU 内存。如果不是这种情况,可以使用其他技术,如梯度检查点。