# CodeExamples / Face Recognition / models.py
import torch
from torch import nn
from torch.nn import MaxPool2d
from torch.nn import Sequential
from torch.nn import Conv2d, Linear
from torch.nn import BatchNorm2d
from torch.nn import ReLU, LeakyReLU
from torch.nn import Module

import torchvision as tv
from torchvision.transforms import functional as f

from visualisation import show_all_channels as show

# Define model
class ConvolutionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.num_channels = 64
        self.ReLU = nn.ReLU()
        self.softmax = nn.Softmax2d()
        # BatchNorm applied to the intermediate activations when forward() is called with isNorm=True
        self.norm = nn.BatchNorm2d(self.num_channels, affine=True, track_running_stats=True)
    
        # seven identical BatchNorm -> 3x3 'same' Conv -> ReLU stages
        stages = []
        for _ in range(7):
            stages += [
                nn.BatchNorm2d(self.num_channels, affine=True, track_running_stats=True),
                nn.Conv2d(self.num_channels, self.num_channels, 3, padding='same'),
                nn.ReLU(),
            ]
        self.block = nn.Sequential(*stages)
        
        self.c1 = nn.Conv2d(3, self.num_channels, 7, stride=2, padding='valid')
        self.norm1 = nn.BatchNorm2d(3, affine=True, track_running_stats=True)
        self.c2 = nn.Conv2d(self.num_channels, self.num_channels, 5, padding='valid')
        self.norm2 = nn.BatchNorm2d(self.num_channels, affine=True, track_running_stats=True)
        # (the earlier per-layer convolutions ca..cf were folded into self.block above)
        self.c3 = nn.Conv2d(self.num_channels, self.num_channels, 3, padding='valid')
        self.c4 = nn.Conv2d(self.num_channels, self.num_channels, 3, stride=1, padding='valid')
        # 64800*2 = 64 channels * 45 * 45, the flattened feature size produced by a 3x112x112 input
        self.lin = nn.Linear(64800*2, 100)
    
    def forward(self, x, isNorm=False, isPrint=False, visualize=False):
        if isPrint: print()
        if isPrint: print(x.shape)
        if visualize: show(x[0], 0)

        x1 = self.ReLU(self.c1(self.norm1(x)))
        if isNorm: x1 = self.norm(x1)
        if isPrint: print(x1.shape)
        if visualize: show(x1[0], 1)

        x2 = self.ReLU(self.c2(self.norm2(x1)))
        if isNorm: x2 = self.norm(x2)
        if isPrint: print(x2.shape)
        if visualize: show(x2[0], 2)

        x2 = self.block(x2)
        if visualize: show(x2[0], 3)

        x3 = self.ReLU(self.c3(x2))
        if isNorm: x3 = self.norm(x3)
        if isPrint: print(x3.shape)
        if visualize: show(x3[0], 9)

        x4 = self.ReLU(self.c4(x3))
        if isNorm: x4 = self.norm(x4)
        if isPrint: print(x4.shape)
        if visualize: show(x4[0], 10)
        
        x_hat = self.flatten(x4)
        if isPrint: print(x_hat.shape)
        if isPrint: print('---')
        logits = self.lin(x_hat)   
        return logits     
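

# A minimal smoke test for ConvolutionNetwork. This is a sketch, not part of the
# original training code: the 3x112x112 input size is an assumption, chosen because
# it is the size that matches the 64800*2 in_features of self.lin above.
def _demo_convolution_network():
    model = ConvolutionNetwork().eval()        # eval mode so BatchNorm uses running stats
    dummy = torch.randn(2, 3, 112, 112)        # two fake 112x112 RGB crops
    with torch.no_grad():
        logits = model(dummy, isPrint=True)    # prints each intermediate shape
    assert logits.shape == (2, 100)
    return logits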



class BasicBlockIR(Module):
    """ BasicBlock for IRNet
    """
    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            LeakyReLU(),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut
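

# A small illustrative check (not from the original repo): with stride=2, both the
# shortcut branch (MaxPool2d when in_channel == depth, a strided 1x1 conv otherwise)
# and the residual branch halve the spatial size, so the two outputs can be summed.
# The batch size and 56x56 input here are arbitrary example values.
def _demo_basic_block():
    block = BasicBlockIR(in_channel=64, depth=128, stride=2).eval()
    x = torch.randn(2, 64, 56, 56)
    with torch.no_grad():
        y = block(x)
    assert y.shape == (2, 128, 28, 28)   # channels widened, spatial size halved
    return y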


class Backbone(Module):
    def __init__(self, input_size=3, h_size=64, mode='ir'):
        """ Args:
            input_size: number of input channels
            h_size: base channel width; the blocks below widen it up to h_size*8
            mode: 'ir' or 'irse' (only plain 'ir' blocks are built here)
        """
        super(Backbone, self).__init__()

        self.input_layer = Sequential(
            Conv2d(input_size, h_size, (3, 3), 1, 1, bias=False),
            BatchNorm2d(h_size),
            LeakyReLU()
        )
        
        stride = 2
        # eight stride-2 blocks: e.g. a 112x112 input is reduced to a 1x1 feature map
        self.blocks = Sequential(
            BasicBlockIR(in_channel=h_size, depth=h_size, stride=stride),
            BasicBlockIR(in_channel=h_size, depth=h_size, stride=stride),
            BasicBlockIR(in_channel=h_size, depth=h_size*2, stride=stride),
            BasicBlockIR(in_channel=h_size*2, depth=h_size*2, stride=stride),
            BasicBlockIR(in_channel=h_size*2, depth=h_size*4, stride=stride),
            BasicBlockIR(in_channel=h_size*4, depth=h_size*4, stride=stride),
            BasicBlockIR(in_channel=h_size*4, depth=h_size*8, stride=stride),
            BasicBlockIR(in_channel=h_size*8, depth=h_size*8, stride=stride)
        )

        output_channel = h_size*8
        self.output_norm = BatchNorm2d(output_channel)        
        self.output_layer = Linear(output_channel, 100)      


    def forward(self, x):
        x = self.input_layer(x)
        x = self.blocks(x)

        x = self.output_norm(x)
        # collapse the remaining 1x1 spatial dimensions before the linear layer
        # (note: torch.squeeze would also drop a batch dimension of 1)
        x = torch.squeeze(x)
        x = self.output_layer(x)
        # optionally L2-normalise the embedding:
        #   norm = torch.norm(x, 2, 1, True)
        #   x = torch.div(x, norm)
        return x
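

# A hedged end-to-end check (assumes 112x112 inputs; that size is not stated anywhere
# in this file, but it is a common face-recognition crop and it matches the linear
# layer of ConvolutionNetwork above).
if __name__ == "__main__":
    _demo_convolution_network()
    _demo_basic_block()

    backbone = Backbone(input_size=3, h_size=64).eval()
    faces = torch.randn(2, 3, 112, 112)
    with torch.no_grad():
        out = backbone(faces)
    assert out.shape == (2, 100)
    print("all smoke tests passed")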