我正在尝试训练一个循环 GAN,它旨在生成地理空间运动数据(纬度、经度和时间的 3 元组序列)。您可以简单地将其视为具有 3 个特征的向量序列。
我目前在 TensorFlow 上使用 Keras。以下是我想出的两个网络。我知道超参数可能很糟糕。我只想让一般架构工作,然后再调整它们。
发电机:
generator = tf.keras.Sequential([
tf.keras.layers.GRU(256, return_sequences=True, stateful=True, input_shape=(None, 3), batch_size=1),
tf.keras.layers.TimeDistributed(tf.layers.Dense(3))
])
鉴别器:
discriminator = tf.keras.Sequential([
tf.keras.layers.GRU(128, input_shape=(None, 3)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
为了完整起见,我的组合对抗模型如下所示discriminator.trainable = False:
adversarial = tf.keras.Sequential([
generator,
discriminator
])
以下是我的训练循环。第 1 部分用于训练鉴别器。我抓取了一些真实的序列并运行一个循环以使用生成器生成一些假的序列,然后为它们创建标签并在两者上训练鉴别器都没有问题。
不过,我很难弄清楚如何训练生成器。如您所见,我需要在循环中调用生成器,以便一个接一个地生成序列的向量。我真的不能train_on_batch
用那个。
for epoch in range(EPOCHS):
for batch in range(n_batches):
# --- PART 1: Train the discriminator ----------------------------------
# Use the generator to generate a half a batch full of fake data
fake_data = []
for i in range(BATCH_SIZE // 2):
length = df_train.sample(1).squeeze().shape[0]
start = data_train[np.random.randint(data_train.shape[0])]
generated = np.array([start])
generator.reset_states()
for i in range(1, length):
input = generated[-1:]
input = np.array([input])
prediction = model.predict(input, batch_size=1)
prediction = np.squeeze(prediction, axis=0)
generated = np.concatenate([generated, prediction])
fake_data.append(generated)
fake_labels = np.zeros((BATCH_SIZE // 2, 1))
# Get half a batch of real data
real_data = df_train.iloc[batch * BATCH_SIZE:(batch + 1) * BATCH_SIZE].tolist()
real_labels = np.ones((BATCH_SIZE // 2, 1))
# Train the discriminator on half a batch of real and half a batch of fake data
d_loss_fake = discriminator.train_on_batch(fake_data, fake_labels)
d_loss_real = discriminator.train_on_batch(real_data, real_labels)
d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
# --- PART 2: Train the generator --------------------------------------
g_loss = # HERE, I AM STRUGGLING
如果您认为 RNN 不是解决此问题的方法,我愿意接受任何想法,甚至对整个方法进行更改。我决定使用 RNN,因为我需要生成序列,所以它们很自然地适合,很明显。此外,序列需要是可变长度的,因此使用某种完全连接的网络是可能的,但相当尴尬。
我真的希望有人可以提供帮助,因为我现在正试图解决这个问题。
请询问,如果您缺少任何信息。