I found a solution: instead of passing the whole dataset to model.fit in one go, pass it a custom batch generator that subclasses keras.utils.Sequence (inside the generator you can write whatever logic you need to construct batches and modify/augment the training data). The relevant code, for reference:
import numpy as np
from tensorflow import keras

# A Sequence subclass must implement __len__, returning the number
# of batches in the dataset, and __getitem__, returning a tuple
# (inputs, labels) for one batch.
# Optionally, on_epoch_end() can be implemented, which, as the
# name suggests, is called at the end of each epoch. Here one
# can e.g. shuffle the input data for the next epoch.
class BatchGenerator(keras.utils.Sequence):
    def __init__(self, inputs, labels, padding, batch_size):
        self.inputs = inputs
        self.labels = labels
        self.padding = padding
        self.batch_size = batch_size

    def __len__(self):
        # Number of complete batches in the dataset.
        return int(np.floor(len(self.inputs) / self.batch_size))

    def __getitem__(self, index):
        start_index = index * self.batch_size
        end_index = start_index + self.batch_size
        # Pad every sequence in this batch to the length of its
        # longest member.
        max_length = max(len(self.inputs[i])
                         for i in range(start_index, end_index))
        out_x = np.full([self.batch_size, max_length], self.padding,
                        dtype='int32')
        out_y = np.empty([self.batch_size, 1], dtype='float32')
        for i in range(self.batch_size):
            out_y[i] = self.labels[start_index + i]
            tweet = self.inputs[start_index + i]
            out_x[i, :len(tweet)] = tweet
        return out_x, out_y
# The model.fit function can then be called like this:
training_generator = BatchGenerator(tokens_train, y_train, pad, batch_size)
model.fit(training_generator, epochs=epochs)
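If you also want the batch composition to change between epochs, the optional on_epoch_end() hook mentioned in the comments above is the place to do it. A minimal sketch (the shuffling logic here is my own illustration, not part of the original solution):

class ShufflingBatchGenerator(BatchGenerator):
    # Keras calls on_epoch_end() automatically at the end of every epoch.
    def on_epoch_end(self):
        # Shuffle inputs and labels with the same permutation so they
        # stay aligned; the next epoch then yields different batches.
        perm = np.random.permutation(len(self.inputs))
        self.inputs = [self.inputs[i] for i in perm]
        self.labels = [self.labels[i] for i in perm]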