训练 VAE 时 KL 散度损失为零

数据挖掘 神经网络 深度学习 喀拉斯 自动编码器
2022-02-15 06:56:11

我正在尝试训练一个变分自动编码器来执行天文图像的无监督分类(它们的大小为 63x63 像素)。我正在使用具有 2 个卷积层和一个密集层的编码器,以及用于解码器的类似结构。我正在执行渐变的 Xavier 初始化。我正在使用学习率为 1e-4 的 Adam 优化器。

我观察到 KL 散度从非常小的值开始(大约 1e-4 的数量级)并在训练的几个 epoch 后突然消失,而我的重建损失正常减少(我使用 MSE 作为重建损失)。可能是什么原因?我应该单独执行损失的缩放吗?

这是我的模型。

initializer = glorot_normal()

def sampling(inputs):
  z_mean, z_log_var = inputs
  epsilon = k.random_normal(shape=(k.shape(z_mean)[0], 2), mean=0., stddev=0.1)
  return z_mean + k.exp(z_log_var) * epsilon

input_images = Input(shape = (63,63,1))
conv1 = Conv2D(16, (3,3), activation = 'relu')(input_images)
conv2 = Conv2D(8, (3,3), activation = 'relu')(conv1)
flattened = Flatten()(conv2)
x = Dense(4, activation = 'relu')(flattened)

z_mean = Dense(2, name = "z_mean")(x)
z_log_var = Dense(2, name = "z_log_var")(x)
z = Lambda(sampling, output_shape = (2,))([z_mean, z_log_var])

encoder = Model(input_images, [z_mean, z_log_var, z], name = "encoder")

latent_inputs = Input(shape = (2,))
x = Dense(59*59*8, activation = 'relu')(latent_inputs)
x = Reshape((59,59,8))(x)
conv4 = Conv2DTranspose(8, (3,3), activation = 'relu')(x)
decoded = Conv2DTranspose(1, (3,3), activation = 'softmax')(conv4)

decoder = Model(latent_inputs, decoded, name = "decoder")

z_mean, z_log_var, z = encoder(input_images)
vae_decoder_output = decoder(z)
vae = Model(input_images, vae_decoder_output, name = "VAE")
vae.summary()

这是我试图实现的损失函数。

recon = MSE(input_images, vae_decoder_output)
recon = k.mean(recon)
kl_loss = 1 + z_log_var - k.square(z_mean) - k.exp(z_log_var)
kl_loss = k.sum(kl_loss, axis = -1)
kl_loss *= -0.5
vae_loss = k.mean(kl_loss*10**3+recon)
vae.add_loss(vae_loss)
1个回答

看看鲍曼的论文“从连续空间生成句子”。在第 3.1 节中解释了为什么 LSTM_VAE 倾向于这种行为:

“这种有问题的学习趋势由于 lstm 解码器对隐藏状态的细微变化的敏感性而更加复杂,例如后验采样过程引入的变化。这导致模型最初学会忽略 ~z 并追求低垂的果实,解释使用更容易优化的解码器的数据。一旦发生这种情况,解码器会忽略编码器,并且几乎没有梯度信号在两者之间传递,从而在 kl 成本项为零时产生不希望的稳定平衡。我们提出了两种技术来缓解这种情况问题。”

……