In short
If you look at what mnist.train is, you will find two numpy arrays in it:
mnist.train.images (shape (55000, 784)) and mnist.train.labels (shape (55000, 10)). Then all you have to do is iterate over these.
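The simplest form of that iteration can be sketched as follows (using small random stand-in arrays in place of the real MNIST data, so the snippet is self-contained):

```python
import numpy as np

# Stand-ins for mnist.train.images / mnist.train.labels
X_train = np.random.rand(5500, 784)                   # one row per example
Y_train = np.eye(10)[np.random.randint(0, 10, 5500)]  # one-hot labels

batch_size = 1000
for start in range(0, X_train.shape[0], batch_size):
    batch_x = X_train[start:start + batch_size]  # slice of examples
    batch_y = Y_train[start:start + batch_size]  # matching slice of labels
    # feed batch_x / batch_y to your train step here
```

Note that with this naive slicing the last batch is simply shorter when batch_size does not divide the number of examples; the example below handles that case by wrapping around and reshuffling instead.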
Example
This example is a bit elaborate so that it is robust to whatever sizes you choose, while keeping the idea of fixing a number of batches as the main knob for how long you want to train on the dataset:
import numpy as np

X_train = mnist.train.images
Y_train = mnist.train.labels
number_of_examples = X_train.shape[0]
batch_size = 1000
number_of_batches = 200
epoch = 0
for i in range(number_of_batches):
    j = (i - epoch) * batch_size % number_of_examples
    k = (i - epoch + 1) * batch_size % number_of_examples
    if k < j:  # wrapped past the end of the dataset
        # last (possibly shorter) batch of this epoch
        batch_x = X_train[j:number_of_examples, :]
        batch_y = Y_train[j:number_of_examples, :]
        print('Shuffling data to retrain on dataset')
        # concatenate so that (example, label) pairs stay together
        data = np.concatenate((X_train, Y_train), axis=1)
        np.random.shuffle(data)
        # redefine X_train and Y_train from the shuffled dataset
        X_train = data[:, :784]
        Y_train = data[:, 784:794]
        epoch = i + 1
    else:
        batch_x = X_train[j:k, :]
        batch_y = Y_train[j:k, :]
    # [...]
    train_step.run(feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
j is the row of the dataset that will be the first row of the batch, and k is the last one, so that k - j = batch_size examples per batch, as expected. The exception is when you have reached the end of the dataset, in which case k < j (the modulo wrapped around).
You compute i - epoch, rather than just i, to make sure you always restart from the first row of the dataset once you have gone through it completely. For example, with number_of_examples = 55555, batch_size = 1000 and
number_of_batches = 200: i = 56 is the first time you go through the (shuffled) dataset again, but 56 * 1000 % 55555 = 445, so you would start at row 445 instead of 0.
You also have to concatenate before shuffling, otherwise the (example, label) pairs would be lost.
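An equivalent way to keep the pairs aligned, without concatenating, is to shuffle a single shared permutation of row indices and apply it to both arrays; a small sketch with random stand-in arrays:

```python
import numpy as np

# Stand-ins for the image and label arrays
X_train = np.random.rand(100, 784)
Y_train = np.eye(10)[np.random.randint(0, 10, 100)]

perm = np.random.permutation(X_train.shape[0])   # one shared row order
X_train, Y_train = X_train[perm], Y_train[perm]  # rows stay paired
```

This avoids building the wide intermediate array and works even when the two arrays have different dtypes.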
i, j, k % what?
Run the toy code below to see how the iteration indices behave:
number_of_examples = 55555
batch_size = 1000
number_of_batches = 200
epoch = 0
for i in range(number_of_batches):
    j = (i - epoch) * batch_size % number_of_examples
    k = (i - epoch + 1) * batch_size % number_of_examples
    if k < j:
        print('Overlap : ', j, number_of_examples)
        epoch = i + 1
    else:
        print(i, ' ', j, k)