X_train, y_train from ImageDataGenerator (Keras)

data-mining python keras cnn
2022-03-07 20:15:36

Can I get X_train, y_train, X_test and y_test from data_generator? Here is my code:

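from tensorflow.keras.preprocessing.image import ImageDataGenerator

# train_data_dir, img_width, img_height and batch_size are assumed to be defined earlier
data_generator = ImageDataGenerator(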
    rescale = 1. / 255, 
    shear_range = 0.2, 
    zoom_range = 0.2, 
    horizontal_flip = True,
    vertical_flip = True,
    rotation_range = 180,
    width_shift_range = 0.2,
    height_shift_range = 0.2,
    validation_split = 0.2) 

train_generator = data_generator.flow_from_directory(
    train_data_dir, 
    target_size =(img_width, img_height), 
    batch_size = batch_size,
    shuffle = True,
    class_mode = 'categorical',
    seed = 42,
    subset='training')

validation_generator = data_generator.flow_from_directory( 
    train_data_dir, 
    target_size =(img_width, img_height), 
    batch_size = batch_size,
    shuffle = True,
    class_mode = 'categorical',
    seed = 42,
    subset='validation')
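To estimate how many images land in each subset, here is a small sketch of the arithmetic. It assumes (from my reading of the keras_preprocessing source, which may differ between versions) that validation_split is applied to each class folder separately, with the first int(validation_split * n) files of the sorted listing going to subset='validation' and the rest to subset='training'. The helper names split_counts and num_batches are hypothetical.

```python
import math


def split_counts(files_per_class, validation_split=0.2):
    """Estimate (n_train, n_val) for subset='training' / subset='validation'.

    Assumption: each class folder is split on its own; the first
    int(validation_split * n) sorted files go to validation, the rest to training.
    """
    n_train = n_val = 0
    for n in files_per_class:
        n_class_val = int(validation_split * n)
        n_val += n_class_val
        n_train += n - n_class_val
    return n_train, n_val


def num_batches(n_samples, batch_size):
    # Batches per epoch, i.e. what len(generator) reports; the last batch may be smaller
    return math.ceil(n_samples / batch_size)


print(split_counts([100, 57]))  # (126, 31)
print(num_batches(126, 32))  # 4
```

For example, with 100 and 57 images in two class folders and batch_size = 32, the training generator should report 4 batches per epoch, the last one holding only 30 images.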
4 answers

In Python 2:

X_train, y_train = train_generator.next()
X_test, y_test = validation_generator.next()

In Python 3:

X_train, y_train = next(train_generator)
X_test, y_test = next(validation_generator)

Building on the answer above, note that the following code only gives you a single batch of data:

X_train, y_train = next(train_generator)
X_test, y_test = next(validation_generator)
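To see why this returns only one batch, the hypothetical FakeDirectoryIterator below (a stand-in, not a Keras class) mimics two behaviours of a Keras DirectoryIterator: each next() call returns a single (x, y) batch of at most batch_size samples, and the iterator keeps cycling instead of stopping after the last batch. It also offers a .next() alias, matching the Python 2 style used above.

```python
import numpy as np


class FakeDirectoryIterator:
    """Hypothetical stand-in for a Keras DirectoryIterator (not the real class).

    len() is the number of batches per epoch, and every next() call returns ONE
    (x, y) batch, wrapping around to the first batch after the last one.
    """

    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size
        self.batch_index = 0

    def __len__(self):
        return -(-len(self.x) // self.batch_size)  # ceil(n / batch_size)

    def __iter__(self):
        return self

    def __next__(self):
        start = self.batch_index * self.batch_size
        stop = start + self.batch_size
        self.batch_index = (self.batch_index + 1) % len(self)
        return self.x[start:stop], self.y[start:stop]

    next = __next__  # Python 2 style gen.next() alias

    def reset(self):
        self.batch_index = 0


# 10 fake 4x4 RGB "images" with one-hot labels for 2 classes, batch_size=4
gen = FakeDirectoryIterator(np.zeros((10, 4, 4, 3)), np.eye(2)[np.arange(10) % 2], batch_size=4)
X_train, y_train = next(gen)
print(len(gen), X_train.shape, y_train.shape)  # 3 (4, 4, 4, 3) (4, 2)
```

In other words, X_train above holds batch_size images, not the whole training subset.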

To extract the full data from train_generator, use the following code:

import numpy as np
from tqdm import tqdm

# Store the data in X_train, y_train variables by iterating over the batches
train_generator.reset()
X_train, y_train = next(train_generator)
# len(train_generator) is already the number of batches per epoch, so it must not
# be divided by batch_size again; the 1st batch was fetched before the loop.
for i in tqdm(range(len(train_generator) - 1)):
    img, label = next(train_generator)
    X_train = np.append(X_train, img, axis=0)
    y_train = np.append(y_train, label, axis=0)
print(X_train.shape, y_train.shape)
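np.append copies the whole array on every call, so the loop above gets slower as the data grows. A minimal alternative sketch (collect_all_batches is a hypothetical helper) collects the batches in a list and concatenates once. Keep in mind that with the augmentation settings from the question, every pass produces a freshly augmented copy of each image, and with shuffle=True the order is random (X and y stay aligned).

```python
import numpy as np


def collect_all_batches(generator, num_batches=None):
    """Pull num_batches (x, y) batches from a Keras-style iterator and stack them.

    num_batches defaults to len(generator), which for a Keras DirectoryIterator is
    the number of batches per epoch (not the number of images).
    """
    if num_batches is None:
        num_batches = len(generator)
    if hasattr(generator, "reset"):
        generator.reset()  # start from the first batch of the epoch
    xs, ys = [], []
    for _ in range(num_batches):
        x, y = next(generator)
        xs.append(x)
        ys.append(y)
    # One concatenate at the end avoids re-copying the arrays on every batch
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


# Usage with plain NumPy batches standing in for a real generator
batches = [(np.ones((4, 8, 8, 3)), np.eye(3)[[0, 1, 2, 0]]),
           (np.ones((2, 8, 8, 3)), np.eye(3)[[1, 2]])]
X_all, y_all = collect_all_batches(iter(batches), num_batches=2)
print(X_all.shape, y_all.shape)  # (6, 8, 8, 3) (6, 3)
```

With a real generator you would call collect_all_batches(train_generator), which pulls len(train_generator) batches after a reset().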

Use

X_train, y_train = train_generator.next()
X_test, y_test = validation_generator.next()

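The y_train and y_test values are based on the class folders in train_data_dir. Note that the class indices 0, 1, 2, 3, ... are assigned to the class names in alphanumeric order of the folder names; with class_mode='categorical', each label is a one-hot row whose position is that index.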
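The default mapping can be reproduced with a short sketch. infer_class_indices is a hypothetical helper that assumes flow_from_directory's default behaviour with classes=None: each sub-folder is a class, indices follow the sorted folder names, and loose files at the top level are ignored.

```python
import os
import tempfile


def infer_class_indices(directory):
    """Sketch of the default class -> index mapping of flow_from_directory.

    Assumption: with classes=None every sub-directory is one class and indices
    follow the sorted (alphanumeric) order of the folder names.
    """
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}


# Usage: build a throwaway train_data_dir with three class folders and one stray file
with tempfile.TemporaryDirectory() as root:
    for name in ["dogs", "cats", "birds"]:
        os.makedirs(os.path.join(root, name))
    open(os.path.join(root, "notes.txt"), "w").close()
    print(infer_class_indices(root))  # {'birds': 0, 'cats': 1, 'dogs': 2}
```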

To get the class-to-index map from the generators, use:

train_generator.class_indices
validation_generator.class_indices

Make sure both mappings are the same.
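Once the mappings are known, a one-hot y batch can be turned back into class names by inverting class_indices and taking the argmax of each row. This is a sketch with hypothetical helper names (decode_labels, check_same_mapping) and a made-up two-class mapping.

```python
import numpy as np


def decode_labels(y, class_indices):
    """Turn one-hot rows (class_mode='categorical') back into class names."""
    index_to_class = {index: name for name, index in class_indices.items()}
    return [index_to_class[i] for i in np.argmax(y, axis=1)]


def check_same_mapping(train_indices, val_indices):
    # Both subsets should map every folder name to the same column index
    if train_indices != val_indices:
        raise ValueError(f"class_indices differ: {train_indices} vs {val_indices}")


# Usage with a hypothetical mapping; with real generators pass
# train_generator.class_indices and validation_generator.class_indices
class_indices = {"cats": 0, "dogs": 1}
check_same_mapping(class_indices, {"cats": 0, "dogs": 1})
y = np.array([[0.0, 1.0], [1.0, 0.0]])
print(decode_labels(y, class_indices))  # ['dogs', 'cats']
```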

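More of an indirect answer, but it may help someone: this is the script I use to sort test and train images into the corresponding (sub)folders, so that they can be used with Keras and its data generator functions (MS Windows).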

import os
from glob import glob
from shutil import copyfile

############################
# Data stored at...
odir = "C:/origdir/"
# Target dir for test-train split
tdir = "C:/myimages/"
# Test-train-split (= x * maxsamples)
trainsize = 0.8
# Define max number of samples per class for test-train
maxsamples = 2500

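# Only keep class sub-directories (stray files in odir would break os.listdir below)
paths = [d for d in glob(os.path.join(odir, "*")) if os.path.isdir(d)]

for p in paths:
    # get name of dir (portable, instead of splitting on "\\")
    classname = os.path.basename(p)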
    ###########################################
    # Gen dirs in tt
    # Check/create dir
    # TEST
    os.makedirs(os.path.join(tdir, "val", classname), exist_ok=True)
    # TRAIN
    os.makedirs(os.path.join(tdir, "train", classname), exist_ok=True)
    # ###########################################
    # COPY
    # train samples
    filelist = os.listdir(p)
    filelist = filelist[:maxsamples]
    tindex = int(trainsize*len(filelist))
    trainfiles = filelist[:tindex]
    testfiles = filelist[tindex:]
    # train
    for f in trainfiles:
        copyfile(os.path.join(p, f), os.path.join(tdir, "train", classname, f))
        # os.remove(os.path.join(p, f))  # uncomment to move instead of copy
    # test (copied into the "val" folder)
    for f in testfiles:
        copyfile(os.path.join(p, f), os.path.join(tdir, "val", classname, f))
        # os.remove(os.path.join(p, f))
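One caveat: os.listdir() returns files in arbitrary order, so taking the first 80% is not a random sample. A possible variant of the split step keeps the same tindex arithmetic but sorts and shuffles the list with a fixed seed first; split_file_list is a hypothetical helper, and the seeded shuffle is an addition, not part of the original script.

```python
import random


def split_file_list(filelist, trainsize=0.8, maxsamples=2500, seed=42):
    """Shuffle, cap at maxsamples, and cut into (train, test) lists.

    Same arithmetic as the script (tindex = int(trainsize * n)); sorting before the
    seeded shuffle makes the result independent of os.listdir() ordering.
    """
    files = sorted(filelist)
    random.Random(seed).shuffle(files)
    files = files[:maxsamples]
    tindex = int(trainsize * len(files))
    return files[:tindex], files[tindex:]


train, test = split_file_list([f"img_{i:04d}.jpg" for i in range(3000)])
print(len(train), len(test))  # 2000 500
```

Inside the script's loop this would replace the slicing of filelist, e.g. trainfiles, testfiles = split_file_list(os.listdir(p), trainsize, maxsamples).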