计算一个时期内的批次数

数据挖掘 张量流 训练 时代
2022-02-15 04:55:06

我一直在阅读斯坦福大学深度学习课程的代码示例,我看到他们已经计算了num_steps = (params.train_size + params.batch_size - 1) // params.batch_size[github 链接]

为什么num_steps = params.train_size // params.batch_size不是呢?

1个回答

python中的双斜线代表“地板”除法(向下舍入到最接近的整数),所以如果结果不是整数,它总是会错过最后一批,它小于批大小。

例如:

给定一个包含 10,000 个样本和 15 批大小的数据集:

# Original calculation:
# (10000 + 15 - 1) // 15 = floor(667.6) = 667

# Your calculation:
# 10000 // 15 = floor(666.667) = 666

您的公式计算完整批次的数量,即它们的总批次数。