如何正确使用 Flatten 层?

人工智能 机器学习 训练 图像识别 建筑学
2021-11-11 04:08:21

语境

我正在尝试创建能够识别类似印刷数字的网络。类似于 MNIST,但仅适用于标准打印字体。

图像的大小为 40x40,我想将它们放入前馈网络中,因为 ConvNet 对于这项任务来说似乎太强大了。

问题

我应该如何在这个任务中使用 Flatten 层?

代码

我现在的网络:

X, test_X, y, test_y = train_test_split(X, y, test_size=0.25, random_state=42)

self.model = Sequential()
self.model.add(Flatten())
self.model.add(Dense(64, activation='relu', input_shape=X.shape[1:]))
self.model.add(Dense(no_classes, activation='softmax'))
self.model.compile(loss="categorical_crossentropy",
                   optimizer="rmsprop",
                   metrics=['accuracy'])

self.history = self.model.fit(X, y, batch_size=256, epochs=20, validation_data=(test_X, test_y))
print(self.model.summary())

示例图像

在此处输入图像描述 在此处输入图像描述 在此处输入图像描述

当前结果

在此处输入图像描述 在此处输入图像描述

1个回答

Flatten 层用于将 ND 张量折叠成一维张量。在您的情况下,输入似乎是28×28图像,因此 Flatten 会将其转换为具有形状的张量1×768. 请注意,不会丢失任何信息。展平层通常用于具有尺寸的卷积层N×M×C(在哪里N,M是特征图大小和C是通道数),并且希望与 Dense 层或仅接受 1D 输入的其他层完全连接。当网络旨在使用不同的技术从最终卷积层输出特征向量以用于图像分类目的时,也可以使用 Flatten。