# mvq/models/conv/conv.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn as nn

def get_act_fn(name):
    """Return the activation module class registered under ``name``.

    The class itself is returned, not an instance; callers instantiate it
    with ``get_act_fn(name)()``.

    Args:
        name: One of ``'relu'``, ``'gelu'``, ``'prelu'``, ``'tanh'`` or
            ``'leakyrelu'`` (exact, case-sensitive match).

    Returns:
        The matching ``torch.nn`` activation class.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    act_fns = {
        'relu': torch.nn.ReLU,
        'gelu': torch.nn.GELU,
        'prelu': torch.nn.PReLU,
        'tanh': torch.nn.Tanh,
        'leakyrelu': torch.nn.LeakyReLU,
    }
    try:
        return act_fns[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which are just
        # another kind of unsupported name. A bare ``raise`` here would fail
        # with an unrelated "No active exception to reraise" error.
        raise ValueError(
            f"Unknown activation {name!r}; expected one of {sorted(act_fns)}"
        ) from None

class ResBlock(nn.Module):
    """Residual block: activation -> conv -> activation -> conv, plus the input.

    Args:
        in_channel: Channel count of the input; the second convolution maps
            back to this count so the result can be added to the input.
        channel: Channel count between the two convolutions.
        kernels: Indexable with at least two entries; ``kernels[0]`` is the
            kernel size of the first convolution (built with ``padding=1``)
            and ``kernels[1]`` that of the second (built without padding).
        act: Activation name understood by ``get_act_fn``.
    """

    def __init__(self, in_channel, channel, kernels, act='relu'):
        super().__init__()

        act_cls = get_act_fn(act)

        # Layers are created in the same order they run, so the indices
        # inside the Sequential container stay 0..3.
        layers = []
        layers.append(act_cls())
        layers.append(nn.Conv2d(in_channel, channel, kernels[0], padding=1))
        layers.append(act_cls())
        layers.append(nn.Conv2d(channel, in_channel, kernels[1]))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack on ``x`` and add ``x`` to the result in place."""
        residual = self.conv(x)
        # In-place accumulation into the conv output; the shapes must match.
        # NOTE(review): presumably kernels is (3, 1)-like so the spatial
        # size is preserved -- confirm against the configs that build this.
        residual.add_(x)
        return residual

class Encoder(torch.nn.Module):
    """Convolutional encoder: strided conv stem, residual blocks, 1x1 projection.

    The ``stride`` argument selects one of four stem layouts. Every stem
    convolution except the last uses ``stride=2`` and is followed by an
    activation; the last one uses the default stride. All stem convolutions
    use ``padding=1``:

    * ``stride == 8``: four stride-2 convolutions, then one plain convolution
    * ``stride == 6``: three stride-2 convolutions, then one plain convolution
    * ``stride == 4``: two stride-2 convolutions, then one plain convolution
    * ``stride == 2``: one stride-2 convolution, then one plain convolution

    Args:
        in_channel: Channel count of the input.
        out_channel: Channel count produced by the final 1x1 convolution.
        channel: Channel count at the end of the stem and through the
            residual blocks; earlier stem layers use ``channel // 2``.
        n_res_block: Number of ``ResBlock`` modules appended after the stem.
        n_res_channel: Inner channel count passed to each ``ResBlock``.
        stride: Layout selector; must be 2, 4, 6 or 8 (see above).
        kernels: Indexable of kernel sizes, one per stem convolution
            (``kernels[i]`` is used by the i-th stem convolution).
        res_kernels: Kernel sizes forwarded to each ``ResBlock``.
        act: Activation name understood by ``get_act_fn``.

    Raises:
        ValueError: If ``stride`` is not one of the supported values.
    """

    def __init__(self, in_channel, out_channel, channel, n_res_block,
        n_res_channel, stride, kernels, res_kernels, act='leakyrelu'):
        super().__init__()
        self.out_channel = out_channel

        half = channel // 2
        # Output width of each stem convolution, per supported stride value.
        stem_widths = {
            8: (half, half, half, channel, channel),
            6: (half, half, channel, channel),
            4: (half, channel, channel),
            2: (half, channel),
        }
        try:
            widths = stem_widths[stride]
        except (KeyError, TypeError):
            # Fail with a clear message instead of hitting an unbound
            # ``blocks`` variable further down.
            raise ValueError(
                f"Unsupported stride {stride!r}; expected one of "
                f"{sorted(stem_widths)}"
            ) from None

        act_cls = get_act_fn(act)

        blocks = []
        prev = in_channel
        last = len(widths) - 1
        for idx, width in enumerate(widths):
            if idx < last:
                # Downsampling stage: stride-2 convolution plus activation.
                blocks.append(
                    nn.Conv2d(prev, width, kernels[idx], stride=2, padding=1)
                )
                blocks.append(act_cls())
            else:
                # Final stem convolution keeps the default stride and has no
                # activation of its own.
                blocks.append(nn.Conv2d(prev, width, kernels[idx], padding=1))
            prev = width

        blocks.extend(
            ResBlock(channel, n_res_channel, res_kernels, act=act)
            for _ in range(n_res_block)
        )

        blocks.append(act_cls())

        # 1x1 projection to the requested output channel count.
        blocks.append(nn.Conv2d(channel, out_channel, 1))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the full encoder stack to ``x``."""
        return self.blocks(x)
    
class Decoder(nn.Module):
    """Convolutional decoder: input conv, residual blocks, transposed-conv upsampling.

    The ``stride`` argument selects how many stride-2 transposed
    convolutions (each followed by an activation) are placed in ``blocks``
    before the final stride-2 transposed convolution ``conv_out``:

    * ``stride == 8``: three, then ``conv_out``
    * ``stride == 6``: two, then ``conv_out``
    * ``stride == 4``: one, then ``conv_out``
    * ``stride == 2``: none; ``conv_out`` additionally uses
      ``output_padding=1``

    Args:
        in_channel: Channel count of the input.
        out_channel: Channel count produced by ``conv_out``.
        channel: Channel count after the input convolution and through the
            residual blocks; intermediate upsampling layers use
            ``channel // 2``.
        n_res_block: Number of ``ResBlock`` modules after the input conv.
        n_res_channel: Inner channel count passed to each ``ResBlock``.
        stride: Layout selector; must be 2, 4, 6 or 8 (see above).
        kernels: Indexable of kernel sizes; ``kernels[0]`` is used by the
            input convolution and the following entries by the transposed
            convolutions in order.
        res_kernels: Kernel sizes forwarded to each ``ResBlock``.
        act: Activation name understood by ``get_act_fn``.
        padding: Padding of the input convolution.
        recon_loss: When equal to ``'mse'`` a sigmoid is applied to the
            output; any other value returns the ``conv_out`` result as is.

    Raises:
        ValueError: If ``stride`` is not one of the supported values.
    """

    def __init__(
        self, in_channel, out_channel, channel, n_res_block,
        n_res_channel, stride, kernels, res_kernels, act='relu',
        padding=1, recon_loss='mse'
    ):
        super().__init__()

        # Number of intermediate (non-final) upsampling stages per stride.
        hidden_stages = {8: 3, 6: 2, 4: 1, 2: 0}
        try:
            n_hidden = hidden_stages[stride]
        except (KeyError, TypeError):
            # Without this check ``conv_out`` would never be created and the
            # problem would only surface as an AttributeError in forward().
            raise ValueError(
                f"Unsupported stride {stride!r}; expected one of "
                f"{sorted(hidden_stages)}"
            ) from None

        act_cls = get_act_fn(act)
        half = channel // 2

        blocks = [nn.Conv2d(in_channel, channel, kernels[0], padding=padding)]

        blocks.extend(
            ResBlock(channel, n_res_channel, res_kernels, act)
            for _ in range(n_res_block)
        )

        blocks.append(act_cls())

        prev = channel
        for idx in range(1, n_hidden + 1):
            blocks.append(
                nn.ConvTranspose2d(prev, half, kernels[idx], stride=2, padding=1)
            )
            blocks.append(act_cls())
            prev = half

        # conv_out is registered before blocks so the order of the module's
        # parameters and state-dict entries stays as it was.
        self.conv_out = nn.ConvTranspose2d(
            prev,
            out_channel,
            kernels[n_hidden + 1],
            stride=2,
            padding=1,
            # Only the layout without intermediate stages (stride == 2)
            # uses output padding.
            output_padding=1 if n_hidden == 0 else 0,
        )

        self.blocks = nn.Sequential(*blocks)

        self.recon_loss = recon_loss
        if self.recon_loss == 'mse':
            self.final = nn.Sigmoid()

    def forward(self, x):
        """Decode ``x``; applies the sigmoid only when ``recon_loss == 'mse'``."""
        out = self.conv_out(self.blocks(x))
        if self.recon_loss == 'mse':
            out = self.final(out)
        return out