使用 Keras 为深度学习重塑数据

数据挖掘 Python 神经网络 深度学习 喀拉斯
2021-10-06 02:36:19

我是 Keras 的初学者,我从 MNIST 示例开始了解该库的实际工作原理。Keras 示例文件夹中的 MNIST 问题的代码片段如下所示:

import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten  
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils

batch_size = 128
nb_classes = 10
nb_epoch = 12

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
nb_pool = 2
# convolution kernel size
nb_conv = 3

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
..........

我无法理解这里的重塑功能。它在做什么以及我们为什么要应用它?

2个回答

mnist.load_data()提供具有结构的 MNIST 数字,(nb_samples, 28, 28)即每个示例具有 2 个维度,表示 28x28 的灰度图像。

然而,Keras 中的 Convolution2D 层被设计为每个示例使用 3 个维度。它们具有 4 维输入和输出。这涵盖了彩色图像(nb_samples, nb_channels, width, height),但更重要的是,它涵盖了网络的更深层,其中每个示例都变成了一组特征图,即(nb_samples, nb_features, width, height).

MNIST 数字输入的灰度图像需要不同的 CNN 层设计(或层构造函数的参数以接受不同的形状),或者该设计可以简单地使用标准 CNN,并且您必须将示例明确表示为 1 通道图片。Keras 团队选择了后一种方式,需要重新塑造。

只是对接受的答案的一个小修正,输入形状索引命名如下:(n_images,x_shape,y_shape,channels)