输入 CNN 网络之前的批处理数据

数据挖掘 Python 神经网络 张量流 美国有线电视新闻网
2022-02-25 03:57:57

我正在做一个使用 CNN 模型对 CT 扫描图像进行分类的项目,图像大小很大,我想使用批处理的想法将其输入网络,尝试使用以下代码:

    # train_data size = 5460
    num_epochs = 14
    batch_size = 390
    batch = 0
    print("Starting training...")
    for epoch in range(num_epochs):
        train_batch = train_data[batch:batch_size]
        batch += batch_size
        batch_size += batch_size
        ep_loss = 0
        for data in train_batch:
            X = data[0]
            Y = data[1]
            _, c = sess.run([optimizer, cost], feed_dict={x_img: X, y_label: Y})

我的问题是:

1-这是进行批处理的正确方法吗?还是有更好的方法?

2-使用上面的代码,我正在使用“AdamOptimizer”,对于批处理的想法,这是一种很好的优化技术,还是我应该使用另一种?

1个回答

使用批处理来训练神经网络是一种很好的做法。正如Yann LeCun 所说

小批量训练对您的健康不利。更重要的是,这对您的测试错误不利。朋友不要让朋友使用大于 32 的 minibatch。

虽然,它不会帮助您处理大图像。这就是卷积的用途。

如果您使用的是 Keras,那么批处理实现完全适合您。如果您使用 Tensorflow,那么您可以使用 tf.data api 或自己实现。num_iterations = training_size // batch_size您提供的代码似乎适用于您的特定训练规模,但您可能希望使其适应各种训练规模(即使用在结束纪元之前也包括在内...

尽管您可能希望调整超参数以获得更好的结果,但我使用 Adam Optimizer 取得了不错的结果。