import torch
from torch import nn
from torch.nn import (BatchNorm2d, Conv2d, LeakyReLU, Linear, MaxPool2d,
                      Module, Sequential)

from visualisation import show_all_channels as show

# Define model
class ConvolutionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.num_channels = 64
        self.ReLU = nn.ReLU()
        self.softmax = nn.Softmax2d()
        # Shared inter-stage BatchNorm; forward() references self.norm when
        # called with isNorm=True, so it must be defined here.
        self.norm = nn.BatchNorm2d(self.num_channels, affine=True, track_running_stats=True)
        # Seven identical (BatchNorm -> 3x3 same-padding Conv -> ReLU) stages.
        stages = []
        for _ in range(7):
            stages += [
                nn.BatchNorm2d(self.num_channels, affine=True, track_running_stats=True),
                nn.Conv2d(self.num_channels, self.num_channels, 3, padding='same'),
                nn.ReLU(),
            ]
        self.block = nn.Sequential(*stages)
        # Stem: 7x7/2 then 5x5 valid-padding convolutions, each preceded by
        # its own BatchNorm.
        self.c1 = nn.Conv2d(3, self.num_channels, 7, stride=2, padding='valid')
        self.norm1 = nn.BatchNorm2d(3, affine=True, track_running_stats=True)
        self.c2 = nn.Conv2d(self.num_channels, self.num_channels, 5, padding='valid')
        self.norm2 = nn.BatchNorm2d(self.num_channels, affine=True, track_running_stats=True)
        self.c3 = nn.Conv2d(self.num_channels, self.num_channels, 3, padding='valid')
        self.c4 = nn.Conv2d(self.num_channels, self.num_channels, 3, stride=1, padding='valid')
        # 64800*2 = 129600 = 64 channels * 45 * 45 spatial; this flattened
        # size follows from the valid-padding convs above for 112x112 inputs.
        self.lin = nn.Linear(64800*2, 100)

    def forward(self, x, isNorm=False, isPrint=False, visualize=False):
        if isPrint: print()
        if isPrint: print(x.shape)
        if visualize: show(x[0], 0)
        x1 = self.ReLU(self.c1(self.norm1(x)))
        if isNorm: x1 = self.norm(x1)
        if isPrint: print(x1.shape)
        if visualize: show(x1[0], 1)
        x2 = self.ReLU(self.c2(self.norm2(x1)))
        if isNorm: x2 = self.norm(x2)
        if isPrint: print(x2.shape)
        if visualize: show(x2[0], 2)
        x2 = self.block(x2)
        if visualize: show(x2[0], 3)
        x3 = self.ReLU(self.c3(x2))
        if isNorm: x3 = self.norm(x3)
        if isPrint: print(x3.shape)
        if visualize: show(x3[0], 9)
        x4 = self.ReLU(self.c4(x3))
        if isNorm: x4 = self.norm(x4)
        if isPrint: print(x4.shape)
        if visualize: show(x4[0], 10)
        x_hat = self.flatten(x4)
        if isPrint: print(x_hat.shape)
        if isPrint: print('---')
        logits = self.lin(x_hat)
        return logits
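
# Example forward pass (a sketch; the 112x112 input size is an assumption
# chosen so the flattened feature count is 64 * 45 * 45 = 129600, matching
# self.lin above):
#
#   model = ConvolutionNetwork()
#   logits = model(torch.rand(4, 3, 112, 112), isPrint=True)
#   # logits.shape == torch.Size([4, 100])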


class BasicBlockIR(Module):
    """BasicBlock for IRNet: a BatchNorm/Conv residual branch plus a
    MaxPool (same-width) or strided 1x1-conv (projection) shortcut."""
    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            # Channel counts match: the shortcut only needs the spatial stride.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            # Channel counts differ: project with a strided 1x1 convolution.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            LeakyReLU(),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
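
# Example: with stride=2 both branches halve the spatial dims, so they can
# be summed (shapes follow from the definitions above):
#
#   block = BasicBlockIR(in_channel=64, depth=128, stride=2)
#   y = block(torch.rand(1, 64, 56, 56))  # y.shape == (1, 128, 28, 28)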


class Backbone(Module):
    def __init__(self, input_size=3, h_size=64, mode='ir'):
        """ Args:
            input_size: number of input channels
            h_size: base channel width of the backbone
            mode: 'ir' or 'irse' (not used in this implementation)
        """
        super(Backbone, self).__init__()
        self.input_layer = Sequential(
            Conv2d(input_size, h_size, (3, 3), 1, 1, bias=False),
            BatchNorm2d(h_size),
            # LeakyReLU(h_size) would set negative_slope=64 (likely a
            # leftover from PReLU(h_size)); use the default slope instead.
            LeakyReLU()
        )
        stride = 2
        # Eight stride-2 IR blocks: the spatial dims halve at every block
        # while the channel width doubles every second block (64 -> 512).
        self.blocks = Sequential(
            BasicBlockIR(in_channel=h_size, depth=h_size, stride=stride),
            BasicBlockIR(in_channel=h_size, depth=h_size, stride=stride),
            BasicBlockIR(in_channel=h_size, depth=h_size*2, stride=stride),
            BasicBlockIR(in_channel=h_size*2, depth=h_size*2, stride=stride),
            BasicBlockIR(in_channel=h_size*2, depth=h_size*4, stride=stride),
            BasicBlockIR(in_channel=h_size*4, depth=h_size*4, stride=stride),
            BasicBlockIR(in_channel=h_size*4, depth=h_size*8, stride=stride),
            BasicBlockIR(in_channel=h_size*8, depth=h_size*8, stride=stride)
        )
        output_channel = h_size*8
        self.output_norm = BatchNorm2d(output_channel)
        self.output_layer = Linear(output_channel, 100)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.blocks(x)
        x = self.output_norm(x)
        # The eight stride-2 blocks reduce e.g. 112x112 inputs to 1x1, so
        # flatten the trailing spatial dims. flatten(1) is used instead of
        # torch.squeeze(x), which would also drop the batch dim when the
        # batch size is 1.
        x = torch.flatten(x, 1)
        x = self.output_layer(x)
        return x
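

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the training code). 112x112 is an
    # assumed input size: it makes ConvolutionNetwork's flattened features
    # come out to 64 * 45 * 45 = 129600 (matching its Linear layer) and lets
    # Backbone's eight stride-2 blocks reduce the spatial dims to 1x1.
    x = torch.rand(4, 3, 112, 112)
    assert ConvolutionNetwork()(x).shape == (4, 100)
    assert Backbone()(x).shape == (4, 100)
    print("shape checks passed")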