如何在生成网络中抑制先前的结果?

数据挖掘 神经网络 张量流 生成模型
2021-09-24 01:37:22

想想一个生成动物名字的生成网络。

它已经接受了一组动物名称的训练,所以这应该不会太难。

但是假设我想生成 10 个动物名称。首先我运行网络并生成 ZEBRA。接下来它会产生 HORSE。接下来它再次产生 HORSE。这不好。

我需要一种方法来抑制以前的答案。

既然我知道产生 ZEBRA 和 HORSE 这两个词的权重模式,我是否可以利用这些知识来调整网络,以便能够给我一个新的独特结果。

(我不是说我想要一次全部 10 个答案!我的意思是我想一次又一次地运行网络,但每次都会得到不同的结果。)

换句话说,我想抑制给我先前答案的最大值。

(对于一个可以不重复列出 20-30 个动物名称的人来说,这项任务似乎很容易。我不确定这是因为以前的动物名称存储在短期记忆中,还是以前动物的突触被某种方式抑制了)。

1个回答

我需要一种方法来抑制以前的答案。

问题的答案取决于您使用的特定网络。它是什么样的生成模型?甘?自动编码器?Seq2Seq RNN?

话虽如此,如果你的网络一直输出相同的结果,这通常被称为“模式崩溃”,这是 GAN 的典型特征。如果您的网络是 GAN,则可能值得探索有效帮助您处理模式崩溃的Wasserstein 损失。

通常,在生成网络的输出中实现多样性(和非重复结果)的标准路径是增加您的训练数据集。增强是指以不同的方式表示相同的数据。在动物图片的背景下,想象它提供了斑马或马的不同图片,从不同的角度或在不同的光线条件下等。

还要确保您的训练数据集是平衡的,即所有类在训练集中均等地表示,否则结果将有偏差。

如果您使用生成动物字符串名称的 seq2seq 网络,您通常会逐个字符地对这些名称进行采样。在这种情况下有不同的采样策略,例如贪婪采样(总是最可能的字符)、多项式采样(光束搜索),它们也可以具有温度缩放。这些策略产生不同的结果,值得探索。增强此类网络输出多样性的一个非常重要的指标是采样序列的似然性(或其负对数似然性)。您希望将这种可能性包含在损失函数中,以确保您“推动”网络以相等的概率对其范围内的所有内容进行采样。这在本出版物本出版物中得到了很好的解释,它使用类似于您的动物名称任务的分子的字符串表示。

如果您将自动编码器用作生成网络,则可以将其与优化算法(例如粒子群优化器)结合使用,以在其潜在空间内导航并自动采样不同的解决方案。这里的关键是优化算法的成本函数,因为在其中你需要惩罚重复的答案。使用强化学习和类似的惩罚评分功能可以实现类似的性能。

我不确定这是因为以前的动物名称存储在短期记忆中,还是以前动物的突触被某种方式抑制了

如果您使用的是无状态LSTM 单元,则会在独立预测之间重置内存。如果你使用有状态的 LSTM 单元,之前的预测会影响下一个预测。

我意识到我的回答涵盖了许多不同的方法,但是它可能为您提供不同的探索方向。