语境
我正在尝试创建能够识别类似印刷数字的网络。类似于 MNIST,但仅适用于标准打印字体。
图像的大小为 40x40,我想将它们放入前馈网络中,因为 ConvNet 对于这项任务来说似乎太强大了。
问题
我应该如何在这个任务中使用 Flatten 层?
代码
我现在的网络:
X, test_X, y, test_y = train_test_split(X, y, test_size=0.25, random_state=42)
self.model = Sequential()
self.model.add(Flatten())
self.model.add(Dense(64, activation='relu', input_shape=X.shape[1:]))
self.model.add(Dense(no_classes, activation='softmax'))
self.model.compile(loss="categorical_crossentropy",
optimizer="rmsprop",
metrics=['accuracy'])
self.history = self.model.fit(X, y, batch_size=256, epochs=20, validation_data=(test_X, test_y))
print(self.model.summary())
示例图像
当前结果