# Source: LCYE / GD.py (raw view)
from scipy import spatial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
from util.spectral import SpectralNorm
from util.gumbel import gumbel_softmax
import numpy as np
import math
from memory_module import MemModule

class CollaspeNet(nn.Module):
  """Collapse memory net with a re-ID branch and a spatial-mask branch.

  ``forward`` dispatches on ``G_or_D``:

  * ``'C'``: write to the memory module, classify the pooled feature and
    predict a binary spatial mask from the upsampled feature map.
  * ``'D'`` / ``'G'``: read from the memory module (modes ``'Dr'`` / ``'Gr'``)
    and add the result to the incoming feature map.

  Args:
      num_classes (int): output size of the classifier.
      pool_dim (int): channel count of the incoming feature map.
      mem_dim (int): forwarded to ``MemModule`` as ``mem_dim``.
      n_upsampling (int): number of stride-2 transposed convolutions in the
          mask branch; each one halves the channel count.
      temperature: ``-1`` disables masking (an all-ones mask is returned);
          otherwise ``int(temperature)`` is the ``k`` used for the top-k /
          gumbel selection.
      use_gumbel (bool): select the mask with ``gumbel_softmax`` instead of a
          hard top-k threshold on the log-softmax.
  """
  def __init__(self, num_classes, pool_dim, mem_dim, n_upsampling, temperature, use_gumbel,):
    super(CollaspeNet, self).__init__()
    self.temperature = temperature
    self.use_gumbel = use_gumbel
    self.pool_dim = pool_dim
    self.n_upsampling = n_upsampling
    self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=self.pool_dim)

    # re-ID branch
    self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = self._construct_fc_layer(
        [self.pool_dim], self.pool_dim, dropout_p=None
    )
    self.classifier = nn.Linear(self.pool_dim, num_classes)

    # mask branch
    self.up_layers = nn.ModuleList()
    self.gmp = nn.AdaptiveMaxPool2d((1, 1))  # or average pool
    for _ in range(self.n_upsampling):
      self.up_layers.append(nn.Sequential(
          SpectralNorm(nn.ConvTranspose2d(pool_dim, int(pool_dim / 2), kernel_size=3, stride=2,
                                          padding=1, output_padding=1, bias=False)),
          nn.LeakyReLU(0.2, True)))
      pool_dim = int(pool_dim / 2)

    self.end_1 = nn.Sequential(nn.Conv2d(pool_dim, 1, kernel_size=1, stride=1))
    self.logsoftmax = nn.LogSoftmax(dim=1)

    # The gumbel branch of forward() reads self.T; it was never defined here,
    # so that branch raised AttributeError. Created only when the branch is
    # enabled so that checkpoints of non-gumbel models keep the same keys.
    if self.use_gumbel:
      self.T = nn.Parameter(torch.Tensor([1]))

    self._init_params()

  def forward(self, spatial_feat, target_label=None, G_or_D='G'):
    """Run one of the three flows selected by ``G_or_D``.

    Returns:
        ``'C'``: tuple ``(x_feat, x_pool, mask)`` -- classifier output, the
        feature fed to the classifier, and a ``(n, 1, h, w)`` mask.
        ``'D'`` / ``'G'``: ``spatial_feat`` plus the memory read-out.

    Raises:
        ValueError: if ``G_or_D`` is not one of ``'C'``, ``'D'``, ``'G'``.
    """
    if G_or_D == 'C':
      # re-ID flow
      mem, mem_att = self.mem_rep(spatial_feat, mode='w')
      x = spatial_feat + mem
      x_pool = self.global_avgpool(x)
      x_pool = x_pool.view(x_pool.size(0), -1)
      x_pool = self.fc(x_pool)
      x_feat = self.classifier(x_pool)

      # mask flow: re-weight every channel by its global max, then upsample
      out = self.gmp(spatial_feat) * spatial_feat
      for up in self.up_layers:
        out = up(out)
      out = self.end_1(out)

      n, _, h, w = out.size()
      if self.temperature == -1:
        # Allocate on the feature map's device instead of a hard-coded .cuda().
        return x_feat, x_pool, torch.ones((n, 1, h, w), device=out.device)

      k = int(self.temperature)
      if not self.use_gumbel:
        logits = self.logsoftmax(out.view(n, -1))
        th, _ = torch.topk(logits, k=k, dim=1, largest=True)
        # Keep every location whose score reaches the k-th largest of its row.
        mask = (logits >= th[:, k - 1:k]).to(logits.dtype).view(n, 1, h, w)
      else:
        mask = gumbel_softmax(out.view(n, -1), k=k, T=self.T, hard=True, eps=1e-10).view(n, 1, h, w)
        mask = mask.to(out.device)

      return x_feat, x_pool, mask

    elif G_or_D == 'D':
      x_mem, _ = self.mem_rep(spatial_feat, mode='Dr')
      return x_mem + spatial_feat
    elif G_or_D == 'G':
      x_mem, _ = self.mem_rep(spatial_feat, target_label, mode='Gr')
      return x_mem + spatial_feat
    else:
      raise ValueError('Wrong flow selection')

  def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
    """Constructs fully connected layer

    Also sets ``self.feature_dim`` to the width of the last layer (or to
    ``input_dim`` when no layer is built).

    Args:
        fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
        input_dim (int): input dimension
        dropout_p (float): dropout probability, if None, dropout is unused

    Returns:
        nn.Sequential of Linear -> BatchNorm1d -> ReLU (-> Dropout) groups,
        or None when ``fc_dims`` is None.
    """
    if fc_dims is None:
      self.feature_dim = input_dim
      return None

    assert isinstance(
        fc_dims, (list, tuple)
    ), 'fc_dims must be either list or tuple, but got {}'.format(
        type(fc_dims)
    )

    layers = []
    for dim in fc_dims:
      layers.append(nn.Linear(input_dim, dim))
      layers.append(nn.BatchNorm1d(dim))
      layers.append(nn.ReLU(inplace=True))
      if dropout_p is not None:
        layers.append(nn.Dropout(p=dropout_p))
      input_dim = dim

    self.feature_dim = fc_dims[-1]

    return nn.Sequential(*layers)

  def _init_params(self):
    """Initialise Conv2d (Kaiming), BatchNorm (1 / 0) and Linear (N(0, 0.01)) layers."""
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(
            m.weight, mode='fan_out', nonlinearity='relu'
        )
        if m.bias is not None:
          nn.init.constant_(m.bias, 0)
      elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
          nn.init.constant_(m.bias, 0)



class Pat_Discriminator(nn.Module):
  """
  Defines a PatchGAN discriminator
  Code based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
  """
  def __init__(self, input_nc, ndf=64, n_layers=3, norm='bn'):
    """Construct a PatchGAN discriminator
    Parameters:
        input_nc (int)  -- the number of channels in input images
        ndf (int)       -- the number of filters in the first conv layer
        n_layers (int)  -- the number of stride-2 conv layers in the discriminator
        norm (str)      -- 'bn' selects BatchNorm2d, anything else InstanceNorm2d
    """
    super(Pat_Discriminator, self).__init__()

    norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
    # norm_layer is always a class here (never a functools.partial), so the
    # bias rule is simply: no conv bias when BatchNorm2d follows the conv.
    use_bias = norm_layer is not nn.BatchNorm2d

    kw, padw = 4, 1
    sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
    nf_mult = 1
    for n in range(1, n_layers):  # gradually increase the number of filters
      nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
      sequence += [
          nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
          norm_layer(ndf * nf_mult),
          nn.LeakyReLU(0.2, True),
      ]

    nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
    sequence += [
        nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
        norm_layer(ndf * nf_mult),
        nn.LeakyReLU(0.2, True),
    ]
    sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
    self.model = nn.Sequential(*sequence)

  def forward(self, x):
    """Return ``(prediction_map, ones)`` where ``ones`` is an all-ones tensor shaped like ``x``."""
    return self.model(x), torch.ones_like(x)


class MS_Discriminator(nn.Module):
  """Multi-scale discriminator built from ``num_D`` ``sub_Discriminator`` stacks.

  The layers of every sub-discriminator are re-registered on this module as
  ``D<i>_layer<j>``, together with a pair of 1x1 convolutions
  (``D<i>_channel_in`` / ``D<i>_channel_out``) that map between 512 channels
  and ``pool_dim`` channels around the optional module ``C``.
  """
  def __init__(self, input_nc, num_classes, ndf=64, n_layers=3, norm='bn', num_D=3, pool_dim=2048):
    super(MS_Discriminator, self).__init__()
    self.num_D = num_D
    self.n_layers = n_layers
    self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, count_include_pad=False)
    self.same0 = SamePadding(kernel_size=3, stride=2)  # before the avg-pool downsample
    self.same1 = SamePadding(kernel_size=4, stride=2)  # before the stride-2 convs
    self.same2 = SamePadding(kernel_size=4, stride=1)  # before the last two stride-1 convs

    for i in range(num_D):
      netD = sub_Discriminator(input_nc, num_classes, ndf, n_layers, norm)
      channel_in = nn.Conv2d(512, pool_dim, kernel_size=1, stride=1)
      channel_out = nn.Conv2d(pool_dim, 512, kernel_size=1, stride=1)
      prefix = 'D' + str(i)
      setattr(self, prefix + '_channel_in', channel_in)
      setattr(self, prefix + '_channel_out', channel_out)
      for j in range(n_layers + 2):
        setattr(self, prefix + '_layer' + str(j), getattr(netD, 'layer' + str(j)))

  def single_forward(self, model, x, idx, C=None):
    """Run ``x`` through the layers in ``model`` and return every layer's output.

    Args:
        model (list): the layers of one scale, in order.
        x: input tensor for this scale.
        idx (int): position of this scale in the forward loop; the module
            ``C`` is only applied when ``idx == 1``.
        C: optional module called as ``C(feat, G_or_D='D')`` on 512-channel
            outputs (after mapping them to ``pool_dim`` channels).

    Returns:
        list of the outputs of each layer (input excluded).
    """
    result = [x]
    first_stride1 = len(model) - 2
    for i, layer in enumerate(model):
      samepadding = self.same1 if i < first_stride1 else self.same2
      out = layer(samepadding(result[-1]))
      # NOTE(review): the adapters are looked up by loop position `idx`, while
      # forward() picks layers from D<num_D-1-i>; the two coincide for the
      # default num_D=3 -- confirm before changing num_D.
      if out.size(1) == 512 and C is not None and idx == 1:
        out_in = getattr(self, 'D' + str(idx) + '_channel_in')(out)
        out_in = C(out_in, G_or_D='D') + out_in
        out = getattr(self, 'D' + str(idx) + '_channel_out')(out_in)

      result.append(out)
    return result[1:]

  def forward(self, x, C=None):
    """Evaluate every scale, downsampling the input between scales.

    Returns:
        tuple ``(result, mask)``: ``result`` holds the last-layer output of
        each scale; ``mask`` is always ``None`` (the mask head is disabled).
    """
    num_D = self.num_D
    proposal = []
    mask = None
    input_downsampled = x
    for i in range(num_D):
      model = [getattr(self, 'D' + str(num_D - 1 - i) + '_layer' + str(j)) for j in range(self.n_layers + 2)]
      # proposal: [[D2L0, ..., D2L4], [D1L0, ..., D1L4], [D0L0, ..., D0L4]]
      proposal.append(self.single_forward(model, input_downsampled, i, C))
      if i != (num_D - 1):
        input_downsampled = self.downsample(self.same0(input_downsampled))
    result = [scale_outputs[-1] for scale_outputs in proposal]
    return result, mask
        
# (64,128,256,512,1) 
class sub_Discriminator(nn.Module):
  """Single-scale discriminator whose stages are stored as ``layer0`` .. ``layer<n_layers+1>``.

  Stage widths double from ``ndf`` up to a cap of 512; the last stage is a
  plain convolution with ``num_classes + 1`` output channels.
  """
  def __init__(self, input_nc, num_classes, ndf=64, n_layers=3, norm='in'):
    super(sub_Discriminator, self).__init__()
    self.n_layers = n_layers

    use_bias = norm == 'in'
    norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d

    # First stage: no normalisation.
    stages = [
        [SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, bias=use_bias)), nn.LeakyReLU(0.2, True)],
    ]

    # Intermediate stride-2 stages.
    width = ndf
    for _ in range(n_layers - 1):
      prev_width, width = width, min(width * 2, 512)
      stages.append([
          SpectralNorm(nn.Conv2d(prev_width, width, kernel_size=4, stride=2, bias=use_bias)),
          norm_layer(width),
          nn.LeakyReLU(0.2, True),
      ])

    # Stride-1 stage followed by the prediction convolution.
    prev_width, width = width, min(width * 2, 512)
    stages.append([
        SpectralNorm(nn.Conv2d(prev_width, width, kernel_size=4, stride=1, bias=use_bias)),
        norm_layer(width),
        nn.LeakyReLU(0.2, True),
    ])
    stages.append([nn.Conv2d(width, num_classes + 1, kernel_size=4, stride=1)])

    for index, modules in enumerate(stages):
      setattr(self, 'layer' + str(index), nn.Sequential(*modules))

  def forward(self, input):
    """Return the list of outputs of every stage, in order."""
    outputs = []
    current = input
    for n in range(self.n_layers + 2):
      stage = getattr(self, 'layer' + str(n))
      current = stage(current)
      outputs.append(current)
    return outputs

class Mask(nn.Module):
  """Predict a binary spatial mask from multi-scale feature maps.

  Args:
      norm (str): 'bn' selects BatchNorm2d, anything else InstanceNorm2d.
      temperature: ``-1`` disables masking (an all-ones mask is returned);
          otherwise ``int(temperature)`` is the ``k`` used for the top-k /
          gumbel selection.
      use_gumbel (bool): select the mask with ``gumbel_softmax`` instead of a
          hard top-k threshold on the log-softmax.
      fused (int): ``2`` concatenates features from several scales before
          each up-layer; any other value uses a single feature per level.
  """
  def __init__(self, norm, temperature, use_gumbel, fused=1):
    super(Mask, self).__init__()
    self.temperature = temperature
    self.use_gumbel = use_gumbel
    self.fused = fused
    self.T = nn.Parameter(torch.Tensor([1]))  # passed to gumbel_softmax as T
    norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
    small_channels = [512, 512, 256, 128]
    big_channels = [512+128, 512+128+64, 128+64, 64] if self.fused == 2 else [512, 512, 128, 64]

    self.up32_16 = UpLayer(big_channels=big_channels[0], out_channels=512, small_channels=small_channels[0], norm_layer=norm_layer)
    self.up16_8 = UpLayer(big_channels=big_channels[1], out_channels=256, small_channels=small_channels[1], norm_layer=norm_layer)
    self.up8_4 = UpLayer(big_channels=big_channels[2], out_channels=128, small_channels=small_channels[2], norm_layer=norm_layer)
    self.deconv1 = nn.Sequential(*[SpectralNorm(nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)), nn.LeakyReLU(0.2, True)])
    self.deconv2 = nn.Sequential(*[SpectralNorm(nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)), nn.LeakyReLU(0.2, True)])
    self.conv2 = nn.Sequential(*[nn.Conv2d(128, 1, kernel_size=1, stride=1)])
    self.logsoftmax = nn.LogSoftmax(dim=1)

  def forward(self, x, proposal):
    """Return a ``(n, 1, h, w)`` mask, where ``(n, _, h, w)`` is the size of ``x``.

    Args:
        x: 4-D tensor; only its size and device are used.
        proposal: nested sequence indexed as ``proposal[scale][layer]``
            (presumably the per-scale layer outputs collected by
            MS_Discriminator.forward -- confirm against the caller).
    """
    n, _, h, w = x.size()
    if self.temperature == -1:
      # Allocate on the input's device instead of a hard-coded .cuda().
      return torch.ones((n, 1, h, w), device=x.device)

    fuse = self.fused == 2
    scale32 = proposal[2][3]
    scale16 = torch.cat((proposal[2][1], proposal[1][3]), 1) if fuse else proposal[1][3]
    scale8 = torch.cat((proposal[0][3], proposal[1][1], proposal[2][0]), 1) if fuse else proposal[0][3]
    scale4 = torch.cat((proposal[0][1], proposal[1][0]), 1) if fuse else proposal[0][1]

    out = self.up32_16(scale32, scale16)
    out = self.up16_8(out, scale8)
    out = self.up8_4(out, scale4)
    out = self.deconv1(out)
    out = self.deconv2(out)
    out = self.conv2(out)

    k = int(self.temperature)
    if not self.use_gumbel:
      logits = self.logsoftmax(out.view(n, -1))
      th, _ = torch.topk(logits, k=k, dim=1, largest=True)
      # Keep every location whose score reaches the k-th largest of its row.
      mask = (logits >= th[:, k - 1:k]).to(logits.dtype).view(n, 1, h, w)
    else:
      mask = gumbel_softmax(out.view(n, -1), k=k, T=self.T, hard=True, eps=1e-10).view(n, 1, h, w)
      mask = mask.to(out.device)
    return mask

class UpLayer(nn.Module):
  """Fuse a coarse feature map with a finer one.

  The finer map (``big``) is projected to ``small_channels`` by a 1x1
  convolution, the coarse map (``small``) is bilinearly resized to the finer
  map's spatial size, and their sum goes through a 3x3 convolution producing
  ``out_channels`` channels.
  """
  def __init__(self, big_channels, out_channels, small_channels, norm_layer):
    super(UpLayer, self).__init__()
    self.big_channels = big_channels
    self.out_channels = out_channels
    self.small_channels = small_channels
    self.conv1 = nn.Sequential(*[SpectralNorm(nn.Conv2d(self.big_channels, self.small_channels, kernel_size=1, stride=1)), norm_layer(self.small_channels), nn.LeakyReLU(0.2, True)])
    self.conv2 = nn.Sequential(*[SpectralNorm(nn.Conv2d(self.small_channels, self.out_channels, kernel_size=3, stride=1, padding=1)), norm_layer(self.out_channels), nn.LeakyReLU(0.2, True)])

  def forward(self, small, big):
    """Return ``conv2(conv1(big) + resize(small))``."""
    # F.upsample is deprecated; interpolate with align_corners=False is what
    # it resolved to for mode='bilinear'.
    small = F.interpolate(small, size=(big.size(2), big.size(3)), mode='bilinear', align_corners=False)
    big = self.conv1(big)
    return self.conv2(big + small)

class Generator(nn.Module):
  """ResNet-style generator split into an encoder and a decoder.

  The encoder ends with a 1x1 convolution to ``pool_dim`` channels and the
  decoder starts with the inverse 1x1 convolution, so that another module can
  operate on the ``pool_dim``-channel feature map between the two halves.

  Args:
      input_nc (int): channels of the input image.
      output_nc (int): channels of the generated image.
      ngf (int): number of filters of the first convolution.
      norm (str): 'bn' selects BatchNorm2d, anything else InstanceNorm2d.
      pool_dim (int): channels of the feature map exchanged between encoder
          and decoder.
      n_blocks (int): number of residual blocks.
  """
  def __init__(self, input_nc, output_nc, ngf, norm='bn', pool_dim=2048, n_blocks=6):
    super(Generator, self).__init__()

    n_downsampling = n_upsampling = 2
    use_bias = norm == 'in'
    norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d

    # ngf
    begin_layers = [nn.ReflectionPad2d(3), SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)), norm_layer(ngf), nn.ReLU(True)]
    # 2ngf, 4ngf
    down_layers = []
    for i in range(n_downsampling):
      mult = 2**i
      down_layers += [SpectralNorm(nn.Conv2d(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, padding=1, bias=use_bias)), norm_layer(ngf*mult*2), nn.ReLU(True)]
    # 4ngf
    bottleneck_nc = ngf * 2**n_downsampling
    res_layers = [ResnetBlock(bottleneck_nc, norm_layer, use_bias) for _ in range(n_blocks)]
    # 2ngf, ngf
    up_layers = []
    for i in range(n_upsampling):
      mult = 2**(n_upsampling - i)
      up_layers.append([SpectralNorm(nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias)), norm_layer(int(ngf*mult/2)), nn.ReLU(True)])

    end_layers = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]

    self.l1 = nn.Sequential(*begin_layers)
    self.l2 = nn.Sequential(*down_layers)
    self.l3 = nn.Sequential(*res_layers)
    self.l4_1 = nn.Sequential(*up_layers[0])
    self.l4_2 = nn.Sequential(*up_layers[1])
    self.l5 = nn.Sequential(*end_layers)

    # The adapters must match the residual trunk width (4 * ngf). This was
    # hard-coded to 128, which is only correct for ngf == 32.
    self.channel_in_G = nn.Conv2d(bottleneck_nc, pool_dim, kernel_size=1, stride=1)
    self.channel_out_G = nn.Conv2d(pool_dim, bottleneck_nc, kernel_size=1, stride=1)

  def forward_encode(self, inputs):
    """Encode an image into a ``pool_dim``-channel feature map."""
    out = self.l1(inputs)
    out = self.l2(out)
    out = self.l3(out)
    return self.channel_in_G(out)

  def forward_decode(self, out):
    """Decode a ``pool_dim``-channel feature map into an image (Tanh output)."""
    out = self.channel_out_G(out)
    out = self.l4_1(out)
    out = self.l4_2(out)
    return self.l5(out)

  def forward(self, input, En_or_De='E'):
    """Dispatch to the encoder (``'E'``) or decoder (``'D'``).

    Raises:
        ValueError: for any other value of ``En_or_De``.
    """
    if En_or_De == 'E':
      return self.forward_encode(input)
    elif En_or_De == 'D':
      return self.forward_decode(input)
    else:
      raise ValueError('Unknown flow')

class ResnetG(nn.Module):
  """Plain ResNet generator: 7x7 stem, two stride-2 downsamplings, ``n_blocks``
  residual blocks, two stride-2 transposed-conv upsamplings and a 7x7 Tanh head.

  Args:
      input_nc (int): channels of the input image.
      output_nc (int): channels of the generated image.
      ngf (int): number of filters of the stem convolution.
      norm (str): 'bn' selects BatchNorm2d, anything else InstanceNorm2d.
      n_blocks (int): number of residual blocks.
  """
  def __init__(self, input_nc, output_nc, ngf, norm='bn', n_blocks=6):
    super(ResnetG, self).__init__()

    n_downsampling = n_upsampling = 2
    use_bias = norm == 'in'
    norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d

    # Stem: ngf channels.
    stem = [
        nn.ReflectionPad2d(3),
        nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
        norm_layer(ngf),
        nn.ReLU(True),
    ]

    # Downsampling: 2*ngf, then 4*ngf channels.
    encoder = []
    for level in range(n_downsampling):
      width = ngf * 2**level
      encoder.extend([
          nn.Conv2d(width, width*2, kernel_size=3, stride=2, padding=1, bias=use_bias),
          norm_layer(width*2),
          nn.ReLU(True),
      ])

    # Residual trunk at 4*ngf channels.
    trunk_width = ngf * 2**n_downsampling
    trunk = [ResnetBlock(trunk_width, norm_layer, use_bias) for _ in range(n_blocks)]

    # Upsampling: back to 2*ngf, then ngf channels; one list per stage.
    decoder = []
    for level in range(n_upsampling):
      mult = 2**(n_upsampling - level)
      decoder.append([
          nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
          norm_layer(int(ngf*mult/2)),
          nn.ReLU(True),
      ])

    head = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]

    self.l1 = nn.Sequential(*stem)
    self.l2 = nn.Sequential(*encoder)
    self.l3 = nn.Sequential(*trunk)
    self.l4_1 = nn.Sequential(*decoder[0])
    self.l4_2 = nn.Sequential(*decoder[1])
    self.l5 = nn.Sequential(*head)

  def forward(self, inputs):
    """Apply the six stages in sequence and return the generated image."""
    out = inputs
    for stage in (self.l1, self.l2, self.l3, self.l4_1, self.l4_2, self.l5):
      out = stage(out)
    return out

# Define a resnet block
class ResnetBlock(nn.Module):
  """Residual block: ``x + conv_block(x)`` with two reflection-padded 3x3 convolutions."""
  def __init__(self, dim, norm_layer, use_bias):
    super(ResnetBlock, self).__init__()
    self.conv_block = self.build_conv_block(dim, norm_layer, use_bias)

  def build_conv_block(self, dim, norm_layer, use_bias):
    """Build pad -> conv -> norm -> ReLU -> pad -> conv -> norm as one nn.Sequential."""
    def padded_conv():
      # One reflection-padded, spectrally-normalised 3x3 conv plus its norm layer.
      return [
          nn.ReflectionPad2d(1),
          SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)),
          norm_layer(dim),
      ]

    # The ReLU sits only between the two convolutions, not after the second.
    layers = padded_conv() + [nn.ReLU(True)] + padded_conv()
    return nn.Sequential(*layers)

  def forward(self, x):
    residual = self.conv_block(x)
    return x + residual

class SamePadding(nn.Module):
  """Zero-pad a 4-D tensor so that a following convolution/pooling with the
  given kernel size and stride yields ``ceil(size / stride)`` outputs per
  spatial dimension (TensorFlow-style "SAME" padding).

  Args:
      kernel_size (int or pair): kernel size as (height, width).
      stride (int or pair): stride as (height, width).
  """
  def __init__(self, kernel_size, stride):
    super(SamePadding, self).__init__()
    self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
    self.stride = torch.nn.modules.utils._pair(stride)

  @staticmethod
  def _split_padding(in_size, kernel, stride):
    """Return (before, after) padding for one dimension; the extra cell goes after."""
    out_size = math.ceil(float(in_size) / float(stride))
    # Never pad by a negative amount (F.pad would crop the input instead).
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    before = total // 2
    return before, total - before

  def forward(self, input):
    # Dim 2 is height and dim 3 is width. F.pad expects the padding of the
    # LAST dimension first: (left, right, top, bottom). The two pairs used to
    # be swapped, so each dimension got the padding computed for the other.
    in_height, in_width = input.size(2), input.size(3)
    pad_top, pad_bottom = self._split_padding(in_height, self.kernel_size[0], self.stride[0])
    pad_left, pad_right = self._split_padding(in_width, self.kernel_size[1], self.stride[1])
    return F.pad(input, (pad_left, pad_right, pad_top, pad_bottom), 'constant', 0)

  def __repr__(self):
    return self.__class__.__name__

def weights_init(m):
  """Re-initialise ``m`` in place based on its class name.

  Classes whose name contains 'Conv' get weights drawn from N(0, 1) (when a
  ``weight`` attribute is listed by ``dir``); classes whose name contains
  'BatchNorm2d' get weights from N(1, 0.02) and a zero bias. Everything else
  is left untouched.
  """
  classname = type(m).__name__
  if 'Conv' in classname:
    if 'weight' in dir(m):
      m.weight.data.normal_(0.0, 1)
    return
  if 'BatchNorm2d' in classname:
    m.weight.data.normal_(1.0, 0.02)
    m.bias.data.fill_(0)

class GANLoss_real_fake(nn.Module):
  """Real/fake GAN loss against a constant target tensor.

  Args:
      use_lsgan (bool): use MSELoss when True, BCELoss otherwise.
      target_real_label (float): fill value of the target for real inputs.
      target_fake_label (float): fill value of the target for fake inputs.
      tensor: tensor constructor used to allocate the targets (it decides
          their dtype and device).
  """
  def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.cuda.FloatTensor):
    super(GANLoss_real_fake, self).__init__()
    self.real_label = target_real_label
    self.fake_label = target_fake_label
    # Cached targets, rebuilt whenever the prediction shape changes.
    self.real_label_var = None
    self.fake_label_var = None
    self.Tensor = tensor
    if use_lsgan: self.loss = nn.MSELoss()
    else: self.loss = nn.BCELoss()

  def get_target_tensor(self, input, target_is_real):
    """Return a (cached) constant target with the same shape as ``input``."""
    # The cache is keyed on the full shape: comparing numel() alone reused a
    # wrongly shaped target for predictions with equal element counts.
    if target_is_real:
      cached = self.real_label_var
      if cached is None or cached.size() != input.size():
        cached = self.Tensor(input.size()).fill_(self.real_label)
        self.real_label_var = cached
      return cached

    cached = self.fake_label_var
    if cached is None or cached.size() != input.size():
      cached = self.Tensor(input.size()).fill_(self.fake_label)
      self.fake_label_var = cached
    return cached

  def __call__(self, input, target_is_real):
    """Compute the loss.

    ``input`` is either a list of lists (one per scale; the last element of
    each is the prediction and the per-scale losses are summed) or a single
    list whose last element is the prediction.
    """
    if isinstance(input[0], list):
      loss = 0
      for input_i in input:
        pred = input_i[-1]
        target_tensor = self.get_target_tensor(pred, target_is_real)
        loss += self.loss(pred, target_tensor)
      return loss
    else:
      target_tensor = self.get_target_tensor(input[-1], target_is_real)
      return self.loss(input[-1], target_tensor)

class GANLoss(nn.Module):
    """GAN objective whose target is the class label plus a real/fake flag.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """
    def __init__(self, use_lsgan=True, target_real_label=0.0, target_fake_label=1.0):
        """ Initialize the GANLoss class.
        Parameters:
            use_lsgan (bool) - - True: MSE against the label/flag target;
                                 False: signed mean of the prediction (see __call__).
            target_real_label (float) - - flag value appended for a real image
            target_fake_label (float) - - flag value appended for a fake image
        Note: Do not use sigmoid as the last layer of Discriminator.
        """
        super(GANLoss, self).__init__()
        self.use_lsgan = use_lsgan
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            # NOTE(review): created but never used -- the non-lsgan branch of
            # __call__ returns +/- prediction.mean() instead.
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, label, prediction, target_is_real):
        """Create label tensors with the same size as the input.
        Parameters:
            label (tensor) - - 2-D tensor with one row per sample; its width plus one must
                               equal the number of channels of the prediction
            prediction (tensor) - - typically the prediction from a discriminator (4-D)
            target_is_real (bool) - - if the ground truth label is for real images or fake images
        Returns:
            The label with the real/fake flag appended as last channel, tiled
            over the spatial size of the prediction
        """
        s = prediction.size()
        flag = self.real_label if target_is_real else self.fake_label
        # Follow the label's device instead of assuming CUDA.
        flag = flag.to(label.device).expand(label.size(0), 1)
        target_tensor = torch.cat([label, flag], dim=-1)
        target_tensor = target_tensor.view(s[0], s[1], 1, 1)
        return target_tensor.repeat(1, 1, s[2], s[3])

    def __call__(self, prediction, label, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.
        Parameters:
            prediction - - with use_lsgan: a list whose last element is the prediction
                           tensor, or a list of such lists (one per scale, losses are
                           summed); without use_lsgan: a single tensor
            label (tensor) - - class labels, see get_target_tensor
            target_is_real (bool) - - if the ground truth label is for real images or fake images
        Returns:
            the calculated loss.
        """
        if self.use_lsgan:
            if isinstance(prediction[0], list):
                loss = 0
                for input_i in prediction:
                    pred = input_i[-1]
                    # Build the target from this scale's prediction tensor;
                    # passing the whole list raised AttributeError on .size().
                    target_tensor = self.get_target_tensor(label, pred, target_is_real)
                    loss += self.loss(pred, target_tensor)
                return loss
            pred = prediction[-1]
            target_tensor = self.get_target_tensor(label, pred, target_is_real)
            return self.loss(pred, target_tensor)

        if target_is_real:
            return -prediction.mean()
        return prediction.mean()