# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.nn as nn

'''
This version includes batch norm between layers.
'''


def get_act_fn(name):
    """Map an activation name to its torch.nn class (not an instance)."""
    if name == 'relu':
        fn = torch.nn.ReLU
    elif name == 'gelu':
        fn = torch.nn.GELU
    elif name == 'prelu':
        fn = torch.nn.PReLU
    elif name == 'tanh':
        fn = torch.nn.Tanh
    elif name == 'leakyrelu':
        fn = torch.nn.LeakyReLU
    else:
        raise ValueError(f'Unknown activation function: {name}')
    return fn


class ResBlock(nn.Module):
    def __init__(self, in_channel, channel, kernels, act='relu'):
        super().__init__()

        # Both convolutions must preserve the spatial size so that the
        # residual add in forward() is shape-compatible (e.g. kernels=(3, 1)
        # given the padding below).
        self.conv = nn.Sequential(
            get_act_fn(act)(),
            nn.Conv2d(in_channel, channel, kernels[0], padding=1),
            nn.BatchNorm2d(channel),
            get_act_fn(act)(),
            nn.Conv2d(channel, in_channel, kernels[1]),
            nn.BatchNorm2d(in_channel),
        )

    def forward(self, x):
        out = self.conv(x)
        out += x
        return out
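
# Illustrative usage of ResBlock (a sketch added for clarity; the values
# below are assumptions, not from the original code). With kernels=(3, 1)
# the block is shape-preserving, as the residual add requires:
#
#     block = ResBlock(in_channel=64, channel=32, kernels=(3, 1))
#     out = block(torch.randn(1, 64, 16, 16))  # -> (1, 64, 16, 16)
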
class Encoder(nn.Module):
    def __init__(self, in_channel, out_channel, channel, n_res_block,
                 n_res_channel, stride, kernels, res_kernels, act='leakyrelu'):
        super().__init__()
        self.out_channel = out_channel

        # `stride` selects one of several preset downsampling stacks built
        # from stride-2 convolutions, each followed by batch norm.
        if stride == 10:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[3], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[4], stride=2, padding=1),
                nn.BatchNorm2d(channel),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[5], padding=1),
                nn.BatchNorm2d(channel),
            ]
        elif stride == 8:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[3], stride=2, padding=1),
                nn.BatchNorm2d(channel),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[4], padding=1),
                nn.BatchNorm2d(channel),
            ]
        elif stride == 6:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[2], stride=2, padding=1),
                nn.BatchNorm2d(channel),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[3], padding=1),
                nn.BatchNorm2d(channel),
            ]
        elif stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[2], padding=1),
                nn.BatchNorm2d(channel),
            ]
        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[1], padding=1),
                nn.BatchNorm2d(channel),
            ]
        elif stride == 1:
            blocks = [
                nn.Conv2d(in_channel, channel, 3, padding=1),
                nn.BatchNorm2d(channel),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, res_kernels, act=act))

        blocks.append(get_act_fn(act)())
        blocks.append(nn.Conv2d(channel, out_channel, 1))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)


class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel,
        stride, kernels, res_kernels, act='relu', padding=1, ablation_prob=0.0
    ):
        super().__init__()
        self.ablation_prob = ablation_prob
        self.in_channel = in_channel

        blocks = [nn.Conv2d(in_channel, channel, kernels[0], padding=padding)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, res_kernels, act))

        blocks.append(get_act_fn(act)())

        # `stride` mirrors the Encoder presets: each stack upsamples with
        # stride-2 transposed convolutions, and `conv_out` applies the final
        # (possibly upsampling) projection to `out_channel`.
        if stride == 10:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[3], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[4], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(channel // 2, out_channel, kernels[5], stride=2, padding=1)
        elif stride == 8:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[3], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(channel // 2, out_channel, kernels[4], stride=2, padding=1)
        elif stride == 6:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(channel // 2, out_channel, kernels[3], stride=2, padding=1)
        elif stride == 4:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(channel // 2, out_channel, kernels[2], stride=2, padding=1)
        elif stride == 2:
            self.conv_out = nn.ConvTranspose2d(channel, out_channel, kernels[1],
                                               stride=2, padding=1, output_padding=1)
        elif stride == 1:
            self.conv_out = nn.ConvTranspose2d(channel, out_channel, kernels[1])

        self.blocks = nn.Sequential(*blocks)
        self.final = nn.Sigmoid()

    def forward(self, x):
        x = self.blocks(x)
        return self.final(self.conv_out(x))
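

if __name__ == '__main__':
    # Minimal smoke test (a sketch added for illustration; every value here
    # is an assumption, not part of the original module). The kernel sizes
    # are chosen so each stride-2 layer halves or doubles the spatial size
    # exactly (kernel 4, padding 1) and each stride-1 conv preserves it
    # (kernel 3, padding 1); res_kernels=(3, 1) keeps ResBlock shape-safe.
    enc = Encoder(in_channel=3, out_channel=8, channel=32, n_res_block=2,
                  n_res_channel=16, stride=4, kernels=(4, 4, 3),
                  res_kernels=(3, 1), act='leakyrelu')
    dec = Decoder(in_channel=8, out_channel=3, channel=32, n_res_block=2,
                  n_res_channel=16, stride=4, kernels=(3, 4, 4),
                  res_kernels=(3, 1), act='relu')

    x = torch.randn(2, 3, 32, 32)
    with torch.no_grad():
        z = enc(x)  # (2, 8, 8, 8): 4x spatial downsampling
        y = dec(z)  # (2, 3, 32, 32): mirrored 4x upsampling
    print(z.shape, y.shape)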