I'm trying to understand how to feed data into a Keras LSTM layer, but I'm having trouble figuring out how to do it. I have a dataset made of words, and each word is embedded as a vector of 839 elements, so the shape of my dataset is (x, 839).
I want to feed my dataset into an LSTM layer, but I don't quite understand the 3D input Keras expects, shaped as (batch_size, timesteps, feature). I'd like to feed the LSTM one word at a time; how should I do that? A rough sketch of my current understanding is below.
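If I understand the 3D format correctly, feeding one word at a time would mean treating each word as a sample with a single timestep. This is only a sketch of what I have in mind; `embedded_words` is just a placeholder standing in for my (x, 839) array:

import numpy as np

# Placeholder standing in for my (x, 839) array of word embeddings.
embedded_words = np.random.rand(100, 839).astype('float32')

# One word per timestep: (x, 839) -> (x, 1, 839),
# i.e. (batch_size, timesteps, feature) with timesteps = 1.
lstm_input = embedded_words.reshape((embedded_words.shape[0], 1, 839))
print(lstm_input.shape)  # (100, 1, 839)

Is this the right way to do it, or am I misreading what Keras expects?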
Thanks in advance.
Update: I'm still getting an error about the input shape:
ValueError: Error when checking input: expected lstm_11_input to have shape (2, 839) but got array with shape (839, 1)
I'm currently using batch_input_shape=(batch_size, timesteps, feature). Here is the code:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import LSTM, Dense
from scipy.stats import rankdata, kendalltau


class KendallTauHistory(Callback):
    def __init__(self, dataset, y_true, groups):
        self.y_true = y_true
        self.dataset = dataset
        self.groups = groups

    def on_epoch_end(self, epoch, logs=None):
        predictions = self.model.predict(self.dataset)
        predictions = predictions.flatten()
        predictions = list(map(lambda element: element + np.random.uniform(0.0, 1.0) * 0.02 - 0.01, predictions))
        # For batch training
        ranked_predictions = np.array([])
        kendalls = np.array([])
        start_range = 0
        for group in self.groups:
            end_range = start_range + group[1]  # A batch is a group of words with the same group id
            batch_predictions = predictions[start_range:end_range]
            batch_labels = self.y_true[start_range:end_range]
            batch_predictions = list(map(lambda element: element + np.random.uniform(0.0, 1.0) * 0.02 - 0.01, batch_predictions))
            ranked_predictions = np.append(ranked_predictions, np.floor(rankdata(batch_predictions)))
            kendalls = np.append(kendalls, kendalltau(batch_labels, batch_predictions)[0])  # keep only the tau statistic
            start_range = end_range
        # self.y_true = self.y_true[0:len(ranked_predictions)]
        print('\nORIGINAL LABELS: {0}\n'.format(self.y_true))
        print('PREDICTED LABELS: {0}'.format(ranked_predictions))
        print("\nEpoch Kendall's tau: {0}".format(np.mean(kendalls)))


model = tf.keras.Sequential()
model.add(LSTM(units=10, batch_input_shape=(None, 2, 839)))
model.add(Dense(15, activation='sigmoid'))
model.summary()
model.compile(loss=listnet_loss, optimizer=keras.optimizers.Nadam(learning_rate=0.000005, beta_1=0.9, beta_2=0.999))

real_labels = np.array([])
losses = np.array([])

with tf.device('/GPU:0'):
    model.fit(training_dataset, training_dataset_labels, epochs=10, workers=10,
              verbose=1, callbacks=[KendallTauHistory(training_dataset, training_dataset_labels, groups_id_count)])
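For reference, this is my current understanding of the shapes this model would need. The snippet below is only a sketch: the arrays are dummy placeholders, and I've swapped in a plain 'mse' loss and 'adam' optimizer just to keep it self-contained:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense

# With batch_input_shape=(None, 2, 839) each sample should be a sequence of
# 2 timesteps (words), each a vector of 839 features, so the training array
# needs shape (num_samples, 2, 839), not (839, 1).
dummy_dataset = np.random.rand(50, 2, 839).astype('float32')
dummy_labels = np.random.rand(50, 15).astype('float32')  # matches Dense(15)

sketch_model = tf.keras.Sequential([
    LSTM(units=10, batch_input_shape=(None, 2, 839)),
    Dense(15, activation='sigmoid'),
])
sketch_model.compile(loss='mse', optimizer='adam')  # placeholders for my actual loss/optimizer
sketch_model.fit(dummy_dataset, dummy_labels, epochs=1, verbose=1)

If that is right, I suppose my error means I should reshape training_dataset to (num_samples, 2, 839) (or change the declared timesteps) before calling fit(), but I'd like to confirm I'm reading the shapes correctly.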