# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.nn as nn


def get_act_fn(name):
    """Return the activation class (not an instance) for the given name."""
    if name == 'relu':
        fn = nn.ReLU
    elif name == 'gelu':
        fn = nn.GELU
    elif name == 'prelu':
        fn = nn.PReLU
    elif name == 'tanh':
        fn = nn.Tanh
    elif name == 'leakyrelu':
        fn = nn.LeakyReLU
    else:
        raise ValueError(f"Unknown activation function: {name!r}")
    return fn


class ResBlock(nn.Module):
    """Pre-activation residual block: act -> conv -> act -> conv, plus skip."""

    def __init__(self, in_channel, channel, kernels, act='relu'):
        super().__init__()
        # kernels[1] is expected to be 1 (a pointwise conv) so the output
        # spatial size matches the input for the residual addition below.
        self.conv = nn.Sequential(
            get_act_fn(act)(),
            nn.Conv2d(in_channel, channel, kernels[0], padding=1),
            get_act_fn(act)(),
            nn.Conv2d(channel, in_channel, kernels[1]),
        )

    def forward(self, x):
        out = self.conv(x)
        out += x
        return out


class Encoder(nn.Module):
    """Strided convolutional encoder followed by residual blocks."""

    def __init__(self, in_channel, out_channel, channel, n_res_block,
                 n_res_channel, stride, kernels, res_kernels,
                 act='leakyrelu'):
        super().__init__()
        self.out_channel = out_channel

        if stride == 8:
            # Four stride-2 convolutions, then a stride-1 refinement conv.
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[1], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[3], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[4], padding=1),
            ]
        elif stride == 6:
            # Three stride-2 convolutions, then a stride-1 refinement conv.
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel // 2, kernels[1], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[2], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[3], padding=1),
            ]
        elif stride == 4:
            # Two stride-2 convolutions (4x downsampling), then a stride-1 conv.
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[1], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel, channel, kernels[2], padding=1),
            ]
        elif stride == 2:
            # A single stride-2 convolution (2x downsampling).
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[1], padding=1),
            ]
        else:
            raise ValueError(f"Unsupported stride: {stride}")

        for _ in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, res_kernels, act=act))

        blocks.append(get_act_fn(act)())
        # Project to the requested output channel count with a 1x1 conv.
        blocks.append(nn.Conv2d(channel, out_channel, 1))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)


class Decoder(nn.Module):
    """Residual blocks followed by transposed-convolution upsampling."""

    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel,
        stride, kernels, res_kernels, act='relu', padding=1, recon_loss='mse'
    ):
        super().__init__()

        blocks = [nn.Conv2d(in_channel, channel, kernels[0], padding=padding)]

        for _ in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, res_kernels, act=act))

        blocks.append(get_act_fn(act)())

        if stride == 8:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[3], stride=2, padding=1),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(
                channel // 2, out_channel, kernels[4], stride=2, padding=1)
        elif stride == 6:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                get_act_fn(act)(),
                nn.ConvTranspose2d(channel // 2, channel // 2, kernels[2], stride=2, padding=1),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(
                channel // 2, out_channel, kernels[3], stride=2, padding=1)
        elif stride == 4:
            blocks.extend([
                nn.ConvTranspose2d(channel, channel // 2, kernels[1], stride=2, padding=1),
                get_act_fn(act)(),
            ])
            self.conv_out = nn.ConvTranspose2d(
                channel // 2, out_channel, kernels[2], stride=2, padding=1)
        elif stride == 2:
            self.conv_out = nn.ConvTranspose2d(
                channel, out_channel, kernels[1], stride=2, padding=1, output_padding=1)
        else:
            raise ValueError(f"Unsupported stride: {stride}")

        self.blocks = nn.Sequential(*blocks)

        # With an MSE reconstruction loss, squash the output to [0, 1].
        self.recon_loss = recon_loss
        if self.recon_loss == 'mse':
            self.final = nn.Sigmoid()

    def forward(self, x):
        out = self.conv_out(self.blocks(x))
        if self.recon_loss == 'mse':
            out = self.final(out)
        return out
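

# ---------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module).
# The kernel lists below are illustrative assumptions; the source
# does not fix them. Any sizes whose shape arithmetic works out are
# valid, e.g. kernel 4 for the stride-2 (de)convolutions and
# kernel 3 for the stride-1 ones.
# ---------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(1, 3, 32, 32)

    encoder = Encoder(
        in_channel=3, out_channel=64, channel=128,
        n_res_block=2, n_res_channel=32, stride=4,
        kernels=[4, 4, 3], res_kernels=[3, 1],  # assumed configuration
    )
    decoder = Decoder(
        in_channel=64, out_channel=3, channel=128,
        n_res_block=2, n_res_channel=32, stride=4,
        kernels=[3, 4, 4], res_kernels=[3, 1],  # assumed configuration
    )

    z = encoder(x)       # (1, 64, 8, 8): 4x spatial downsampling
    recon = decoder(z)   # (1, 3, 32, 32): upsampled back to input size
    assert recon.shape == x.shape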