WGAN-GP slow critic training time

data-mining tensorflow generative-models
2022-02-21 21:33:15

I am implementing WGAN-GP in TensorFlow 2.0, but every training iteration of the critic is very slow (about 4 seconds on my CPU, and about 9 seconds on a Colab GPU).

Is WGAN-GP usually this slow, or is there a flaw in my code?

Here is my code for training the critic:

import numpy as np
import tensorflow as tf

# Method of the WGAN-GP model class: self.gen and self.critic are the
# generator and critic networks, self.z_dim is the latent dimension.
def train_critic(self, X_real, batch_size, gp_loss_factor, optimizer):
    # Labels: +1 for real samples, -1 for fake (see wasserstein() below)
    y_real = np.ones((batch_size, 1))

    # Get batch of generated images
    noise = np.random.normal(0, 1, (batch_size, self.z_dim))
    X_fake = self.gen.predict(noise)
    y_fake = -np.ones((batch_size, 1))

    X = np.vstack((X_real, X_fake))
    y = np.concatenate((y_real, y_fake))

    # Interpolate images: one alpha per sample, broadcast over H, W, C
    alpha = np.random.uniform(size=(batch_size, 1, 1, 1))
    X_interpolated = alpha * X_real + (1 - alpha) * X_fake
    X_interpolated = tf.constant(X_interpolated, dtype=tf.float32)

    # Perform weight update
    with tf.GradientTape() as outer_tape:
        # Calculate gradient penalty loss
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(X_interpolated)
            y_interpolated = self.critic(X_interpolated)
        gradients = inner_tape.gradient(y_interpolated, X_interpolated)
        # The small constant keeps sqrt differentiable when the norm is 0
        norm = tf.sqrt(
            1e-8 + tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        gp_loss = gp_loss_factor * tf.reduce_mean((norm - 1.) ** 2)

        # Calculate Wasserstein loss
        y_pred = self.critic(X)
        wasserstein_loss = wasserstein(y, y_pred)

        # Add two losses
        loss = tf.add_n([wasserstein_loss, gp_loss] + self.critic.losses)
    gradients = outer_tape.gradient(loss, self.critic.trainable_variables)

    optimizer.apply_gradients(zip(gradients, self.critic.trainable_variables))

    return wasserstein_loss, gp_loss

def wasserstein(y_true, y_pred):
    return tf.reduce_mean(y_true * y_pred)

1 Answer

This part computes second-order derivatives, so it is expected to be slow and expensive.

# Perform weight update
with tf.GradientTape() as outer_tape:
    # Calculate gradient penalty loss
    with tf.GradientTape() as inner_tape:
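
Nesting tapes like this is precisely TensorFlow's mechanism for higher-order derivatives: the outer tape records the inner tape's backward pass so that it can be differentiated again. A minimal standalone illustration of that mechanism (TF 2.x; not taken from the question's code):

import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        y = x ** 3
    # Computed inside outer's context, so outer records the backward pass
    dy_dx = inner.gradient(y, x)    # 3 * x**2 -> 27.0

# Differentiating the recorded gradient gives the second derivative
d2y_dx2 = outer.gradient(dy_dx, x)  # 6 * x -> 18.0

A WGAN-GP critic step has to differentiate through a backward pass in the same way, which is why it costs noticeably more than a plain GAN critic step.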

Try

# Perform weight update
with tf.GradientTape() as outer_tape, tf.GradientTape() as inner_tape:
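
Applied to the train_critic method from the question, the merged form would look roughly like the sketch below. One caveat (based on TF 2.x tape semantics, not stated above): calling inner_tape.gradient() while that tape is still recording raises a RuntimeError on a non-persistent tape, so this sketch assumes the inner tape is created with persistent=True and released afterwards.

# persistent=True is assumed here because .gradient() is called
# while inner_tape is still recording
with tf.GradientTape() as outer_tape, \
        tf.GradientTape(persistent=True) as inner_tape:
    # Gradient penalty: outer_tape records the inner gradient computation,
    # so the second-order derivatives needed for d(gp_loss)/d(weights)
    # remain available.
    inner_tape.watch(X_interpolated)
    y_interpolated = self.critic(X_interpolated)
    gradients = inner_tape.gradient(y_interpolated, X_interpolated)
    norm = tf.sqrt(
        1e-8 + tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    gp_loss = gp_loss_factor * tf.reduce_mean((norm - 1.) ** 2)

    # Wasserstein loss on the combined real/fake batch
    y_pred = self.critic(X)
    wasserstein_loss = wasserstein(y, y_pred)

    loss = tf.add_n([wasserstein_loss, gp_loss] + self.critic.losses)

grads = outer_tape.gradient(loss, self.critic.trainable_variables)
optimizer.apply_gradients(zip(grads, self.critic.trainable_variables))
del inner_tape  # free the resources held by the persistent tape

Note that the second-order work is inherent to the gradient penalty, so restructuring the tapes can only help so much. Per-step times of several seconds often come from eager-mode overhead instead; wrapping the training step in tf.function and keeping the data as tensors (rather than round-tripping through NumPy and predict()) is a common way to reduce them.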