使用 Keras 对 CNN 进行数据重塑

数据挖掘 Python 喀拉斯 重塑
2022-03-08 17:55:40

我是 Keras 的初学者。我已经在 Keras 中加载了 MNIST 数据集并检查了它的维度。代码是

from keras.datasets import mnist

# load data into train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
print("Shape: ", X_train[0].shape)

输出是

(60000, 28, 28, 1)
(60000, 10, 2, 2, 2, 2)
(10000, 28, 28, 1)
(10000, 10, 2, 2)
Shape:  (28, 28, 1)

由于 X_train 和 X_test 已经在形状中(#sample, width, height, #channel)。我们还需要重塑吗?为什么?我正在关注的教程使用以下重塑代码:

X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

我的第二个问题是为什么.astype('float32')在代码中使用 is ?

最后,我无法理解print(y_train.shape)and的输出print(y_test.shape)

请建议。我已经阅读了Reshaping of data for deep learning using Keras但是我的疑虑仍然不清楚。

2个回答

回答 1 整形的原因是为了保证模型的输入数据是正确的形状。但是你可以说使用 reshape 是一种努力的复制。

回答2 转换为float的原因,以便稍后我们可以在0-1范围内标准化图像而不会丢失信息。

我不明白那些形状。使用您的代码,我得到:

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,) 形状: (28, 28)

这更有意义。例如,您的输出不应有 6 个张量维度:“(60000, 10, 2, 2, 2, 2)”。