python中的MLP批量迭代

数据挖掘 Python 神经网络 scikit-学习
2022-02-25 10:06:26

我在 sklearn 中使用 MLPRegressor 来训练一个具有大约 1000 个输入和一个连续输出变量的网络。本质上,问题是图像分类(1000 像素)之一,其中像素的分布与连续变量输出相关。

我可以生成一个任意大的训练集,但是在将数组加载到 numpy. 我选择了在有限(5k)训练集下给出合理结果的超参数,但我试图找出用两个数量级以上的数据训练网络的最佳方法。有没有办法将小批量传递给我自己,以便我可以管理内存和 IO?如果我可以设置一种方法来确保我以随机方式读取每个训练步骤的批次,那么训练和重新训练网络是否有用?

我猜这是一个常见问题,但我正在努力寻找一个明智的解决方案。

1个回答

Scikit-learn MLPRegressor有一种.partial_fit批量训练方法,可以克服这个内存问题。