# gan/spectral_normalization.py
import torch
from torch import nn
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    """
    Scale v to unit L2 norm. The small eps keeps the division stable when
    ||v|| is close to zero.
    """
    return v / (torch.norm(v) + eps)
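
# Example (a minimal sketch): the vector is divided by its own L2 norm, so
#   >>> l2normalize(torch.tensor([3.0, 4.0]))
#   tensor([0.6000, 0.8000])   # [3, 4] / ||[3, 4]||_2 = [3, 4] / 5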
    


class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        """
        Reference:
        SPECTRAL NORMALIZATION FOR GENERATIVE ADVERSARIAL NETWORKS: https://arxiv.org/pdf/1802.05957.pdf
        """
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        self._make_params()

    def _update_u_v(self):
        """
        Run `power_iterations` rounds of power iteration to refresh the
        estimates u and v of the dominant singular vectors of w, estimate the
        spectral norm sigma = u^T W v, and set the module's weight to w / sigma.
        """
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        # Flatten the weight to 2D so the power iteration reduces to a pair of
        # matrix-vector products: v <- l2normalize(W^T u), u <- l2normalize(W v).
        height = w.data.shape[0]
        w_2d = w.view(height, -1)

        # Because u and v are stored on the module and reused across forward
        # passes (the "recycling" the paper describes), a single round of power
        # iteration per step is sufficient in practice.
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_2d.data), u.data))
            u.data = l2normalize(torch.mv(w_2d.data, v.data))

        # sigma = u^T W v approximates the largest singular value of W.
        # w is deliberately not detached here so that gradients flow through
        # sigma during backpropagation.
        sigma = torch.dot(u, torch.mv(w_2d, v))

        # Replace the module's weight with the spectrally normalized weight.
        setattr(self.module, self.name, w / sigma)
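
        # Optional sanity check (a sketch, not part of the assignment): after
        # a few forward passes the running estimate should approach the exact
        # spectral norm, i.e. the largest singular value of the 2D weight:
        #   torch.svd(w_2d.detach())[1][0]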
        
        
        


    def _make_params(self):
        """
        No need to change. Initialize parameters.
        v: Initialize v with a random vector (sampled from an isotropic distribution).
        u: Initialize u with a random vector (sampled from an isotropic distribution).
        w: Weight of the current layer.
        """
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Remove the original weight and register u, v, and the raw weight w_bar.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        """
        No need to change. Update weights using spectral normalization.
        """
        self._update_u_v()
        return self.module.forward(*args)
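
# A minimal usage sketch (hypothetical; the layer shapes are chosen only for
# illustration): wrapping a discriminator layer in SpectralNorm renormalizes
# its weight on every forward pass.
if __name__ == "__main__":
    layer = SpectralNorm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
    x = torch.randn(2, 3, 32, 32)
    print(layer(x).shape)  # torch.Size([2, 64, 16, 16])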