import torch
from torch import nn
from torch.nn import (BatchNorm2d, Conv2d, LeakyReLU, Linear, MaxPool2d,
                      Module, Sequential)

from visualisation import show_all_channels as show


# Define model
class ConvolutionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.num_channels = 64
        self.ReLU = nn.ReLU()
        self.softmax = nn.Softmax2d()
        # Shared norm layer used by the optional isNorm path in forward().
        # (It was commented out, which made forward(isNorm=True) crash.)
        self.norm = nn.BatchNorm2d(self.num_channels, affine=True,
                                   track_running_stats=True)

        # Seven identical BatchNorm -> Conv(3x3, padding='same') -> ReLU
        # stages, built in a loop rather than written out seven times.
        stages = []
        for _ in range(7):
            stages += [
                nn.BatchNorm2d(self.num_channels, affine=True,
                               track_running_stats=True),
                nn.Conv2d(self.num_channels, self.num_channels, 3,
                          padding='same'),
                nn.ReLU(),
            ]
        self.block = nn.Sequential(*stages)

        self.c1 = nn.Conv2d(3, self.num_channels, 7, stride=2,
                            padding='valid')
        self.norm1 = nn.BatchNorm2d(3, affine=True, track_running_stats=True)
        self.c2 = nn.Conv2d(self.num_channels, self.num_channels, 5,
                            padding='valid')
        self.norm2 = nn.BatchNorm2d(self.num_channels, affine=True,
                                    track_running_stats=True)
        self.c3 = nn.Conv2d(self.num_channels, self.num_channels, 3,
                            padding='valid')
        self.c4 = nn.Conv2d(self.num_channels, self.num_channels, 3,
                            stride=1, padding='valid')
        self.lin = nn.Linear(64800 * 2, 100)

    def forward(self, x, isNorm=False, isPrint=False, visualize=False):
        if isPrint:
            print()
            print(x.shape)
        if visualize:
            show(x[0], 0)

        x1 = self.ReLU(self.c1(self.norm1(x)))
        if isNorm:
            x1 = self.norm(x1)
        if isPrint:
            print(x1.shape)
        if visualize:
            show(x1[0], 1)

        x2 = self.ReLU(self.c2(self.norm2(x1)))
        if isNorm:
            x2 = self.norm(x2)
        if isPrint:
            print(x2.shape)
        if visualize:
            show(x2[0], 2)

        # self.block replaces the earlier individual ca..cf convolutions
        # (and their per-stage visualisation calls), which were commented
        # out and have been removed.
        x2 = self.block(x2)
        if visualize:
            show(x2[0], 3)

        x3 = self.ReLU(self.c3(x2))
        if isNorm:
            x3 = self.norm(x3)
        if isPrint:
            print(x3.shape)
        if visualize:
            show(x3[0], 9)

        x4 = self.ReLU(self.c4(x3))
        if isNorm:
            x4 = self.norm(x4)
        if isPrint:
            print(x4.shape)
        if visualize:
            show(x4[0], 10)

        x_hat = self.flatten(x4)
        if isPrint:
            print(x_hat.shape)
            print('---')
        logits = self.lin(x_hat)
        return logits
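
def _check_convnet_shapes():
    """Hypothetical smoke test, not part of the original source. The
    hard-coded nn.Linear(64800 * 2, 100) implies 64 channels at 45x45
    after the valid convolutions, which works out to roughly a 112x112
    RGB input; that resolution is an assumption inferred from the layer
    arithmetic, not stated anywhere in this file."""
    net = ConvolutionNetwork()
    dummy = torch.randn(2, 3, 112, 112)  # assumed input resolution
    # 112 -> 53 (c1, k7 s2) -> 49 (c2, k5) -> 47 (c3) -> 45 (c4),
    # so flatten yields 64 * 45 * 45 = 129600 = 64800 * 2 features.
    logits = net(dummy)
    assert logits.shape == (2, 100)
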
class BasicBlockIR(Module):
    """BasicBlock for IRNet."""

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        # Identity-style shortcut: plain downsampling when the channel
        # count is unchanged, otherwise a 1x1 projection.
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            LeakyReLU(),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Backbone(Module):
    def __init__(self, input_size=3, h_size=64, mode='ir'):
        """
        Args:
            input_size: number of input channels
            h_size: base channel width of the backbone
            mode: 'ir' or 'irse'
        """
        super(Backbone, self).__init__()
        self.input_layer = Sequential(
            Conv2d(input_size, h_size, (3, 3), 1, 1, bias=False),
            BatchNorm2d(h_size),
            # LeakyReLU's first argument is negative_slope, not a channel
            # count; the original LeakyReLU(h_size) set the slope to 64,
            # almost certainly unintended (IRNet uses PReLU(depth) here).
            LeakyReLU()
        )
        stride = 2
        self.blocks = Sequential(
            BasicBlockIR(in_channel=h_size, depth=h_size, stride=stride),
            BasicBlockIR(in_channel=h_size, depth=h_size, stride=stride),
            BasicBlockIR(in_channel=h_size, depth=h_size * 2, stride=stride),
            BasicBlockIR(in_channel=h_size * 2, depth=h_size * 2, stride=stride),
            BasicBlockIR(in_channel=h_size * 2, depth=h_size * 4, stride=stride),
            BasicBlockIR(in_channel=h_size * 4, depth=h_size * 4, stride=stride),
            BasicBlockIR(in_channel=h_size * 4, depth=h_size * 8, stride=stride),
            BasicBlockIR(in_channel=h_size * 8, depth=h_size * 8, stride=stride)
        )
        output_channel = h_size * 8
        self.output_norm = BatchNorm2d(output_channel)
        self.output_layer = Linear(output_channel, 100)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.blocks(x)
        x = self.output_norm(x)
        # Flatten the (assumed 1x1) spatial dimensions. torch.flatten is
        # used instead of the original torch.squeeze, which would also
        # drop the batch dimension for a batch of size 1.
        x = torch.flatten(x, 1)
        x = self.output_layer(x)
        return x
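
def _check_backbone_shapes():
    """Hypothetical smoke test, not part of the original source. Eight
    stride-2 blocks halve the spatial size eight times, so a 256x256
    input reaches 1x1 before the flatten; 256x256 is an assumed
    resolution chosen to make the arithmetic land exactly on 1x1."""
    net = Backbone(input_size=3, h_size=64)
    dummy = torch.randn(2, 3, 256, 256)  # assumed input resolution
    # 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 spatially,
    # with channels growing 64 -> 128 -> 256 -> 512 along the way.
    out = net(dummy)
    assert out.shape == (2, 100)
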