我正在浏览这个 imagenet 示例。
并且,在第 88 行,使用了模块 DistributedDataParallel。当我在docs中搜索相同内容时,我什么也没找到。但是,我找到了DataParallel
.
所以,想知道 DataParallel 和 DistributedDataParallel 模块有什么区别。
我正在浏览这个 imagenet 示例。
并且,在第 88 行,使用了模块 DistributedDataParallel。当我在docs中搜索相同内容时,我什么也没找到。但是,我找到了DataParallel
.
所以,想知道 DataParallel 和 DistributedDataParallel 模块有什么区别。
由于分布式 GPU 功能才出现几天[在 Pytorch 的 v2.0 发行版中],因此仍然没有关于此的文档。所以,我必须通过源代码的文档字符串来找出差异。因此,DistributedDataParallel
模块的文档字符串如下:
在模块级别实现分布式数据并行。此容器通过在批处理维度中分块将输入拆分到指定的设备,从而并行化给定模块的应用程序。该模块在每台机器和每台设备上复制,每个这样的副本处理输入的一部分。在向后传递期间,来自每个节点的梯度被平均。批量大小应大于本地使用的 GPU 数量。它还应该是 GPU 数量的整数倍,以便每个块的大小相同(以便每个 GPU 处理相同数量的样本)。
的文档字符串dataparallel
如下:
在模块级别实现数据并行。此容器通过在批处理维度中分块将输入拆分到指定的设备,从而并行化给定模块的应用程序。在前向传递中,模块在每个设备上复制,每个副本处理一部分输入。在向后传递期间,来自每个副本的梯度被汇总到原始模块中。批量大小应大于使用的 GPU 数量。它还应该是 GPU 数量的整数倍,以便每个块的大小相同(以便每个 GPU 处理相同数量的样本)。
Pytorch 论坛上的这个回复也有助于理解两者之间的区别,
DataParallel
更容易调试,因为您的训练脚本包含在一个进程中。DataParallel
也可能导致 GPU 利用率不佳,因为一个主 GPU 必须保存所有 GPU 的模型、组合损失和组合梯度。
有关更详细的说明,请参见此处。