Keras 输入形状返回错误

数据挖掘 喀拉斯
2022-02-07 13:41:20

我目前正在学习 Keras,并且对密集层的输入形状有疑问。我目前正在尝试 mnist 数据集。我知道火车图像的 input_shape 是(60000,28,28)我也知道 keras忽略第一个维度,因为它是批量大小,因此在密集模型中输入的输入形状应该是(28,28),但是当我输入形状为(784,)时,我得到一个错误,模型运行。可以有人请解释为什么会这样

(train_images, train_labels), (test_images, test_labels) = 
 mnist.load_data()

 print(train_images.shape)


 network = models.Sequential()
 network.add(layers.Dense(512, activation='relu', input_shape=(28,28)))
 network.add(layers.Dense(10, activation='softmax'))
2个回答

将输入连接到 Keras 中的 Dense 层时,您总是需要展平图片(请注意,CNN 或 RNN 并非如此)。原因是在构建密集层时,根据密集层代码,输入 dim 是您在输入 ( input_dim = input_shape[-1]) 中传递的最后一个元素。因此,尽管您传递的是 (28,28) 的输入,但 keras 认为形状只有 28。这也解释了为什么 (,784) 的输入确实有效。

您可以在此处查看密集层代码

在 Keras 中使用Sequential模型时,您始终必须为第一层(密集、卷积、LSTM 等)提供输入的形状,正如您在官方文档中看到的那样:

“模型需要知道它应该期望什么输入形状。因此,顺序模型中的第一层(并且只有第一层,因为后续层可以进行自动形状推断)需要接收有关其输入形状的信息。有有几种可能的方法来做到这一点。”

顺便说一句,在 Keras 中,您实际上不需要在密集层之前展平该层,它是自动完成的(请参阅底部官方 Keras 文档中关于密集层的说明)。

“注意:如果层的输入具有大于 2 的等级,则它在与内核的初始点积之前被展平。” -文档