具有连续标签的条件 GAN 的鉴别器

数据挖掘 深度学习 回归 生成模型 嵌入 标签
2022-02-19 23:28:42

好的,假设我们有带有非离散标签(例如亮度或大小或其他东西)的标记良好的图像,我们希望根据它生成图像。如果使用离散标签完成,则可以这样做:

def forward(self, inputs, label):
    self.batch = inputs.size(0)
    h = self.res1(inputs)
    h = self.attn(h)

        ...

    h = self.res5(h)
    h = torch.sum((F.leaky_relu(h,0.2)).view(self.batch,-1,4*4), dim=2)
    outputs = self.fc(h)

    if label is not None:
        embed = self.embedding(label)
        outputs += torch.sum(embed*h,dim=1,keepdim=True)

嵌入可以匹配任何要添加到隐藏层的形状,并将嵌入与潜在相加,这迫使鉴别器识别它正在鉴别的类,以便对鉴别做出更好的判断。这很酷,但这种方法使用的是离散方法嵌入。连续标签怎么样?除了一些半监督方法外,我真的找不到这样做的方法,而我有确切的标签。有人可以帮助我吗?

1个回答

我认为答案是,您必须创建大量类别,或者与我提供的代码完全不同。 这篇论文很好地解决了这个问题。