# gan/models.py
import torch
import torch.nn as nn

from gan.spectral_normalization import SpectralNorm
# from spectral_normalization import SpectralNorm  # local-import alternative

class Discriminator(torch.nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()
        
        # Hint: apply spectral normalization to the convolutional layers.
        # The input to SpectralNorm should be the conv nn module itself.
        ####################################
        #          YOUR CODE HERE          #
        ####################################
        self.features = nn.Sequential(
            # Each 4x4, stride-2 conv halves the spatial resolution.
            SpectralNorm(nn.Conv2d(input_channels, 128, 4, stride=2, padding=1)),
            nn.LeakyReLU(negative_slope=0.02),
            SpectralNorm(nn.Conv2d(128, 256, 4, stride=2, padding=1)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.02),
            SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=1)),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.02),
            SpectralNorm(nn.Conv2d(512, 1024, 4, stride=2, padding=1)),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(negative_slope=0.02),
            # Final stride-1 conv maps the features to a single-channel score map.
            SpectralNorm(nn.Conv2d(1024, 1, 4, stride=1, padding=1)),
        )
        ##########       END      ##########

    def forward(self, x):
        ####################################
        #          YOUR CODE HERE          #
        ####################################
        x = self.features(x)
        ##########       END      ##########
        return x
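
# Shape note (an assumption, not part of the template): with 32x32 RGB inputs
# (CIFAR-sized images), the four stride-2 convs above reduce the resolution
# 32 -> 16 -> 8 -> 4 -> 2, and the final stride-1 conv collapses the 2x2 map
# to a single (N, 1, 1, 1) realness score per image; larger inputs yield a
# larger score map (e.g. 64x64 inputs give (N, 1, 3, 3)).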


class Generator(torch.nn.Module):
    def __init__(self, noise_dim, output_channels=3):
        super(Generator, self).__init__()    
        self.noise_dim = noise_dim
        ####################################
        #          YOUR CODE HERE          #
        ####################################
        self.features = nn.Sequential(
            # Assumes the latent noise arrives as (N, noise_dim, 1, 1); a 4x4,
            # stride-1 transposed conv expands it to a 4x4 feature map.
            nn.ConvTranspose2d(noise_dim, 1024, 4, stride=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # Each 4x4, stride-2 transposed conv doubles the spatial resolution
            # (BatchNorm + ReLU between layers, per course staff guidance).
            nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, output_channels, 4, stride=2, padding=1),
            nn.Tanh(),
        )
        ##########       END      ##########
    
    def forward(self, x):
        ####################################
        #          YOUR CODE HERE          #
        ####################################
        x = self.features(x)
        ##########       END      ##########
        return x
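

# A minimal smoke-test sketch, not part of the assignment template. It assumes
# a DCGAN-style latent code of shape (N, noise_dim, 1, 1) and 32x32 RGB images;
# adjust the shapes if your data pipeline differs. It also relies on the
# course-provided gan.spectral_normalization.SpectralNorm wrapper being importable.
if __name__ == "__main__":
    noise_dim = 100
    batch_size = 4

    generator = Generator(noise_dim)
    discriminator = Discriminator(input_channels=3)

    # Sample latent noise, generate fake images, and score them.
    z = torch.randn(batch_size, noise_dim, 1, 1)
    fake_images = generator(z)            # expected shape: (4, 3, 32, 32)
    scores = discriminator(fake_images)   # expected shape: (4, 1, 1, 1)

    print("generator output:", tuple(fake_images.shape))
    print("discriminator output:", tuple(scores.shape))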