I am combining Adversarial Training for Semantic Segmentation with Adversarial Learning for Semi-Supervised Semantic Segmentation.
The idea is as follows: the discriminator takes a probability map over the 21 classes of the PASCAL VOC dataset (21x321x321) as input and produces a confidence map of size 2x321x321 (a real/fake decision for every pixel). For an input image (3x321x321), the segmentation network (the generator) produces the "fake" probability map (21x321x321). The "real" probability map comes from the ground-truth segmentation label (one-hot encoded to 21x321x321). The generator and discriminator losses are easier to understand from the code.
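For reference, here is a minimal sketch (not taken from my actual code; the helper name mask_to_onehot is just for illustration, and it assumes the mask contains only valid class ids, no ignore label) of how a ground-truth mask can be turned into that one-hot "real" map:

import torch

def mask_to_onehot(mask, num_classes=21):
    # mask: LongTensor of shape (N, H, W) with class ids in [0, num_classes)
    n, h, w = mask.size()
    onehot = torch.zeros(n, num_classes, h, w)
    # place a 1 at the channel given by the class id of each pixel
    onehot.scatter_(1, mask.view(n, 1, h, w), 1.0)
    return onehot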
In my current implementation, the discriminator loss very quickly goes to 0, which is generally a failure mode for GAN training (as mentioned here). I am new to PyTorch (and to adversarial training), so I am not sure whether the problem lies in my network architecture, my hyperparameters, or simply my training scheme. I would really appreciate some pointers.
This is what my training looks like:
[1] [0] LD:1.402573823928833 LG:3.0447137355804443
[1] [1] LD:0.8725658655166626 LG:2.544170618057251
[1] [2] LD:0.6969347596168518 LG:2.1177046298980713
[1] [3] LD:0.611475944519043 LG:1.7778557538986206
[1] [4] LD:0.49319764971733093 LG:2.1366050243377686
[1] [5] LD:0.30195319652557373 LG:1.7873120307922363
[1] [6] LD:0.14412544667720795 LG:1.045764684677124
[1] [7] LD:0.04816107824444771 LG:1.5864180326461792
[1] [8] LD:0.012304163537919521 LG:1.370680332183838
[1] [9] LD:0.0035684951581060886 LG:1.3428194522857666
[1] [10] LD:0.0011156484251841903 LG:1.145486831665039
[1] [11] LD:0.00045744137605652213 LG:1.371126651763916
[1] [12] LD:0.0001588731538504362 LG:1.378540277481079
[1] [13] LD:4.844377326662652e-05 LG:1.504058837890625
LD:1.5183023606368806e-05 LG:1.584553837776184
[1] [16] LD:2.0302868506405503e-05 LG:1.4818311929702759
[1] [17] LD:1.0679158549464773e-05 LG:1.2976796627044678
[1] [18] LD:1.5313835319830105e-06 LG:1.2631664276123047
[1] [22] LD:1.0100056897499599e-06 LG:1.581740140914917
[1] [23] LD:5.6876764631397236e-08 LG:1.0763123035430908
[1] [24] LD:1.475878548262699e-07 LG:1.6125952005386353
[1] [25] LD:6.919402721905499e-07 LG:1.6719598770141602
[1] [26] LD:1.3498377526843797e-08 LG:1.1914349794387817
[1] [27] LD:1.3576584301233652e-08 LG:1.1994632482528687
[1] [28] LD:3.087819067104647e-08 LG:1.2909866571426392
[1] [29] LD:3.416153049329296e-07 LG:2.143049478530884
[1] [30] LD:4.477038118011478e-08 LG:1.7709745168685913
[1] [31] LD:2.1782324832742006e-09 LG:1.2023413181304932
[1] [32] LD:1.0589346999267946e-07 LG:1.4242452383041382
My training code looks like this:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import deeplabv2  # my segmentation network module (ResNet-101 based DeepLab)
# args and trainloader are set up elsewhere

generator = deeplabv2.Res_Deeplab()
optimizer_G = optim.SGD(filter(lambda p: p.requires_grad, generator.parameters()),
                        lr=0.00025, momentum=0.9, weight_decay=0.0001, nesterov=True)

discriminator = Dis(in_channels=21)
optimizer_D = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),
                         lr=0.0001, weight_decay=0.0001)

for epoch in range(args.start_epoch, args.max_epoch + 1):
    for batch_id, (img, mask, ohmask) in enumerate(trainloader):
        # ohmask: mask (HxW) converted to a one-hot encoded probability map,
        # one channel per class (21xHxW)
        img = Variable(img.cuda())
        mask = Variable(mask.cuda(), requires_grad=False)
        ohmask = Variable(ohmask.cuda(), requires_grad=False)

        out_img_map = generator(img)
        out_img_map = nn.LogSoftmax()(out_img_map)
        ########################
        # Adversarial Training #
        ########################
        if args.mode == 'adv':
            N = out_img_map.size()[0]
            H = out_img_map.size()[2]
            W = out_img_map.size()[3]
            # Generate the real and fake labels
            target_fake = Variable(torch.zeros((N, H, W)).long().cuda(), requires_grad=False)
            target_real = Variable(torch.ones((N, H, W)).long().cuda(), requires_grad=False)

            ##########################
            # Discriminator Training #
            ##########################
            # Train on real
            conf_map_real = nn.LogSoftmax()(discriminator(ohmask.float()))
            optimizer_D.zero_grad()
            LD_real = nn.NLLLoss2d()(conf_map_real, target_real)
            LD_real.backward()
            # Train on fake (detached from the generator graph)
            conf_map_fake = nn.LogSoftmax()(discriminator(Variable(out_img_map.data)))
            LD_fake = nn.NLLLoss2d()(conf_map_fake, target_fake)
            LD_fake.backward()
            # Update discriminator weights
            optimizer_D.step()

            ######################
            # Generator Training #
            ######################
            conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map))
            LG_ce = nn.NLLLoss2d()(out_img_map, mask)
            LG_adv = args.lam_adv * nn.NLLLoss2d()(conf_map_fake, target_real)
            LG_seg = LG_ce + args.lam_adv * LG_adv
            optimizer_G.zero_grad()
            LG_ce.backward(retain_variables=True)
            LG_adv.backward()
            optimizer_G.step()

            print("[{}][{}] LD: {} LG: {}".format(epoch, batch_id,
                                                  (LD_real + LD_fake).data[0], LG_seg.data[0]))
I use ResNet-101 as the segmentation network (my generator), and my discriminator is as follows:
import torch.nn as nn
import torch.nn.functional as F

class Dis(nn.Module):
    """
    Discriminator network for the adversarial training.
    """
    def __init__(self, in_channels, negative_slope=0.2):
        super(Dis, self).__init__()
        self._in_channels = in_channels
        self._negative_slope = negative_slope
        self.conv1 = nn.Conv2d(in_channels=self._in_channels, out_channels=64, kernel_size=4, stride=2, padding=2)
        self.relu1 = nn.LeakyReLU(self._negative_slope, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=2)
        self.relu2 = nn.LeakyReLU(self._negative_slope, inplace=True)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=2)
        self.relu3 = nn.LeakyReLU(self._negative_slope, inplace=True)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=2)
        self.relu4 = nn.LeakyReLU(self._negative_slope, inplace=True)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=4, stride=2, padding=2)

    def forward(self, x):
        x = self.conv1(x)  # -,-,161,161
        x = self.relu1(x)
        x = self.conv2(x)  # -,-,81,81
        x = self.relu2(x)
        x = self.conv3(x)  # -,-,41,41
        x = self.relu3(x)
        x = self.conv4(x)  # -,-,21,21
        x = self.relu4(x)
        x = self.conv5(x)  # -,-,11,11
        # upsample back to the input resolution, trimming one row/column
        # after each 2x upsampling step
        x = F.upsample_bilinear(x, scale_factor=2)
        x = x[:, :, :-1, :-1]  # -,-,21,21
        x = F.upsample_bilinear(x, scale_factor=2)
        x = x[:, :, :-1, :-1]  # -,-,41,41
        x = F.upsample_bilinear(x, scale_factor=2)
        x = x[:, :, :-1, :-1]  # -,-,81,81
        x = F.upsample_bilinear(x, scale_factor=2)
        x = x[:, :, :-1, :-1]  # -,-,161,161
        x = F.upsample_bilinear(x, scale_factor=2)
        x = x[:, :, :-1, :-1]  # -,-,321,321
        return x
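As a quick sanity check (a separate snippet I would run, not part of my training code), the discriminator should map a 21x321x321 probability map to a 2x321x321 confidence map:

import torch
from torch.autograd import Variable

disc = Dis(in_channels=21)
dummy = Variable(torch.randn(1, 21, 321, 321))  # one fake probability map
print(disc(dummy).size())  # expected: (1, 2, 321, 321)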
Thanks in advance!