TensorFlow MNIST tutorial

data-mining machine-learning python deep-learning tensorflow
2022-03-17 14:00:06

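I have a question about the TensorFlow tutorial for training on the MNIST database: how can I create my own batches without using next_batch()? The idea is to train with batches of 50, then 100, and so on, but it has to be in order.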

Here is the code:

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
sess.run(tf.initialize_all_variables())
for i in range(20000):
    batch = mnist.train.next_batch(50)  # check here
    if i%100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g"%(i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
1 Answer

In short

If you look at what mnist.train is, you will find that it contains two numpy arrays: mnist.train.images (shape (55000, 784)) and mnist.train.labels (shape (55000, 10)). All you have to do then is iterate over them.
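A minimal sketch of that idea, assuming only NumPy: a hypothetical helper `iterate_in_order` slices both arrays in dataset order, with no shuffling and no next_batch(). The zero arrays are placeholders with the same shapes as mnist.train.images and mnist.train.labels, not the real data.

```python
import numpy as np


def iterate_in_order(images, labels, batch_size):
    """Yield (batch_x, batch_y) pairs in dataset order, without shuffling.

    The last batch is smaller than batch_size if the number of rows
    is not a multiple of batch_size.
    """
    n = images.shape[0]
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        yield images[start:end], labels[start:end]


# Placeholder arrays with the same shapes as mnist.train.images / labels
# (the real arrays would come from the tutorial's mnist object).
images = np.zeros((55000, 784), dtype=np.float32)
labels = np.zeros((55000, 10), dtype=np.float32)

batches = list(iterate_in_order(images, labels, 50))
print(len(batches))          # 1100
print(batches[0][0].shape)   # (50, 784)
```

In the question's training loop, each pair could then be fed as feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5} in place of batch[0] and batch[1].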

Example

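This example is a bit more involved so that it is robust to any batch size you choose, but it keeps a fixed number of batches as the main setting for how much you want to train on the dataset: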

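import numpy as np

# mnist is the dataset object loaded earlier in the tutorial
X_train = mnist.train.images   # shape (55000, 784)
Y_train = mnist.train.labels   # shape (55000, 10)
number_of_examples = X_train.shape[0]
number_of_features = X_train.shape[1]
batch_size = 1000
number_of_batches = 200
epoch = 0  # value of i at which the current pass over the data started

for i in range(number_of_batches):
    # j: first row of the batch; k: one past its last row
    j = (i - epoch) * batch_size % number_of_examples
    k = (i - epoch + 1) * batch_size % number_of_examples

    if k < j:
        # k wrapped around: take the remaining rows up to the end of the dataset
        batch_x = X_train[j:number_of_examples, :]
        batch_y = Y_train[j:number_of_examples, :]

        print('Shuffling data to retrain on dataset')
        # concatenate so that images and labels are shuffled together
        data = np.concatenate((X_train, Y_train), axis=1)
        np.random.shuffle(data)

        # redefine X_train and Y_train from the shuffled dataset
        X_train = data[:, :number_of_features]
        Y_train = data[:, number_of_features:]

        # the next iteration starts a new pass at row 0
        epoch = i + 1
    else:
        batch_x = X_train[j:k, :]
        batch_y = Y_train[j:k, :]

    # [...]
    train_step.run(feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})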

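j is the row of the dataset that will be the first row of the batch, and k is one past its last row (the end of the slice), so k - j = batch_size and every batch has the expected number of examples, unless you have reached the end of the dataset, in which case k < j.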
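To make the roles of j and k concrete, the index arithmetic from the loop can be pulled into a small function (`batch_bounds` is a hypothetical name used only for illustration) and evaluated for a few values of i, with 55555 rows and batch_size = 1000:

```python
def batch_bounds(i, epoch, batch_size, number_of_examples):
    """Return (j, k) as computed in the loop: j is the first row of the
    batch, k is one past its last row, before any wrap-around handling."""
    j = (i - epoch) * batch_size % number_of_examples
    k = (i - epoch + 1) * batch_size % number_of_examples
    return j, k


number_of_examples = 55555
batch_size = 1000

j, k = batch_bounds(3, 0, batch_size, number_of_examples)
print(j, k, k - j)   # 3000 4000 1000

# Last batch of the first pass: k has wrapped around, so k < j
j, k = batch_bounds(55, 0, batch_size, number_of_examples)
print(j, k)          # 55000 445
```

In the wrap-around case the loop takes rows 55000 to 55554, a batch of 555 rows, instead of slicing X_train[j:k].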

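You compute i - epoch rather than just i to make sure you always restart from the first row of the dataset once you have gone through it completely. For example, with number_of_examples = 55555, batch_size = 1000 and number_of_batches = 200, i = 56 is the first time you go through the (shuffled) dataset again, but 56 * 1000 % 55555 = 445, so you would start at row 445 instead of row 0.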
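The arithmetic in that example can be checked directly in Python; the last lines also locate the first iteration at which k wraps below j:

```python
number_of_examples = 55555
batch_size = 1000

# Without the epoch offset, the first batch of the second pass starts mid-dataset
print(56 * batch_size % number_of_examples)             # 445

# With epoch = 56 (set when the wrap happened at i = 55), it restarts at row 0
epoch = 56
print((56 - epoch) * batch_size % number_of_examples)   # 0

# First i (starting from epoch = 0) where k < j, i.e. the end of the first pass
first_wrap = next(
    i for i in range(200)
    if (i + 1) * batch_size % number_of_examples < i * batch_size % number_of_examples
)
print(first_wrap)                                        # 55
```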

You also have to concatenate images and labels before shuffling; otherwise the (image, label) pairs would be lost.
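A common alternative to concatenating, swapped in here and not what the answer's code does, is to draw a single permutation of row indices and index both arrays with it. The sketch assumes NumPy 1.17 or later for np.random.default_rng and uses small toy arrays whose labels are derived from the images, so the pairing can be checked after the shuffle:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: label row i is the one-hot encoding of (image value i) mod 10,
# so we can verify afterwards that pairs survived the shuffle.
X = np.arange(20, dtype=np.float32).reshape(20, 1)
Y = np.eye(10, dtype=np.float32)[np.arange(20) % 10]

# One permutation, applied to both arrays, keeps each (image, label) pair together.
perm = rng.permutation(X.shape[0])
X_shuffled, Y_shuffled = X[perm], Y[perm]

pairs_ok = np.all(np.argmax(Y_shuffled, axis=1) == X_shuffled[:, 0].astype(int) % 10)
print(pairs_ok)  # True
```

This avoids splitting the concatenated array at a hard-coded column, and it also works when images and labels have different dtypes.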


What do i, j, k and % do?

Run the toy code below to see how the iteration indices behave:

number_of_examples = 55555
batch_size = 1000
number_of_batches = 200
epoch = 0

for i in range(number_of_batches):

    j = (i - epoch) * batch_size % number_of_examples
    k = (i - epoch + 1) * batch_size % number_of_examples

    if (k < j):
        print('Overlap : ', j, number_of_examples)
        epoch = i + 1
    else:
        print(i, '   ', j, k)
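The toy loop only prints the indices. A variant that records how many rows each batch actually contains shows the effect of the wrap-around handling: every complete pass over the data ends with one shorter batch (here 55555 - 55000 = 555 rows).

```python
number_of_examples = 55555
batch_size = 1000
number_of_batches = 200
epoch = 0
sizes = []

for i in range(number_of_batches):
    j = (i - epoch) * batch_size % number_of_examples
    k = (i - epoch + 1) * batch_size % number_of_examples
    if k < j:
        # wrap-around batch: rows j .. end of the dataset
        sizes.append(number_of_examples - j)
        epoch = i + 1
    else:
        sizes.append(k - j)

# (iteration, batch size) for every batch that is not full-sized
short = [(i, s) for i, s in enumerate(sizes) if s != batch_size]
print(short)  # [(55, 555), (111, 555), (167, 555)]
```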