# File: mvq/models/conv/conv_bn.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn as nn

'''
This version includes batch norm between layers
'''

def get_act_fn(name):
    """Return the ``torch.nn`` activation *class* registered under ``name``.

    The caller instantiates the returned class itself (e.g. ``get_act_fn('relu')()``).

    Args:
        name: one of ``'relu'``, ``'gelu'``, ``'prelu'``, ``'tanh'``, ``'leakyrelu'``.

    Returns:
        The corresponding ``torch.nn`` module class (not an instance).

    Raises:
        ValueError: if ``name`` is not a supported activation name.
    """
    activations = {
        'relu': torch.nn.ReLU,
        'gelu': torch.nn.GELU,
        'prelu': torch.nn.PReLU,
        'tanh': torch.nn.Tanh,
        'leakyrelu': torch.nn.LeakyReLU,
    }
    try:
        return activations[name]
    except KeyError:
        # The previous bare `raise` had no active exception to re-raise and
        # produced an unhelpful RuntimeError; report the bad name instead.
        raise ValueError(
            f"Unknown activation {name!r}; expected one of {sorted(activations)}"
        ) from None

class ResBlock(nn.Module):
    """Residual block with BatchNorm: ``x + BN(conv(act(BN(conv(act(x))))))``.

    The residual branch is
    act -> conv(kernels[0], padding=1) -> BN -> act -> conv(kernels[1], no padding) -> BN.
    Its output is added to the input in place.

    Both convolutions use stride 1. The first pads by 1 and the second does
    not pad, so the branch keeps the spatial size only when
    ``kernels[0] + kernels[1] == 4`` (e.g. ``(3, 1)``). Otherwise the residual
    addition fails on a shape mismatch.

    Args:
        in_channel: channel count of the input, which is also the block's output.
        channel: hidden channel count inside the residual branch.
        kernels: kernel sizes of the two convolutions (indexed 0 and 1).
        act: activation name resolved through ``get_act_fn``.
    """

    def __init__(self, in_channel, channel, kernels, act='relu'):
        super().__init__()

        make_act = get_act_fn(act)
        residual_layers = [
            make_act(),
            nn.Conv2d(in_channel, channel, kernels[0], padding=1),
            nn.BatchNorm2d(channel),
            make_act(),
            nn.Conv2d(channel, in_channel, kernels[1]),
            nn.BatchNorm2d(in_channel),
        ]
        # A single Sequential named `conv` keeps state_dict keys as `conv.<idx>.*`.
        self.conv = nn.Sequential(*residual_layers)

    def forward(self, x):
        # Add the skip connection in place onto the branch output (x itself is untouched).
        return self.conv(x).add_(x)

class Encoder(torch.nn.Module):
    """Convolutional encoder with BatchNorm after the stem convolutions.

    Layout: a convolutional stem selected by ``stride``, then ``n_res_block``
    ``ResBlock``s, an activation, and a final 1x1 conv to ``out_channel``.

    ``stride`` selects a preset rather than being the literal reduction factor:

    * ``stride`` in (4, 6, 8, 10): ``stride // 2`` stride-2 convs (BN + act after
      each), then one stride-1 conv + BN at width ``channel``. All stride-2
      convs but the last use ``channel // 2`` outputs; the last widens to
      ``channel``. Each stride-2 conv (padding 1) roughly halves H and W.
    * ``stride == 2``: one stride-2 conv to ``channel // 2`` (BN + act), then a
      stride-1 conv to ``channel`` + BN.
    * ``stride == 1``: a single 3x3 conv to ``channel`` + BN (``kernels`` unused).

    Args:
        in_channel: input channel count.
        out_channel: output channel count of the final 1x1 conv.
        channel: working width of the encoder.
        n_res_block: number of ``ResBlock``s after the stem.
        n_res_channel: hidden width inside each ``ResBlock``.
        stride: stem preset, one of (1, 2, 4, 6, 8, 10).
        kernels: stem kernel sizes. ``stride // 2 + 1`` entries are needed for
            stride >= 4, and 2 entries for stride 2.
        res_kernels: kernel sizes forwarded to each ``ResBlock``.
        act: activation name resolved through ``get_act_fn``.

    Raises:
        ValueError: if ``stride`` is not a supported preset.
    """

    def __init__(self, in_channel, out_channel, channel, n_res_block,
        n_res_channel, stride, kernels, res_kernels, act='leakyrelu'):
        super().__init__()
        self.out_channel = out_channel

        supported = (1, 2, 4, 6, 8, 10)
        if stride not in supported:
            # Previously fell through and failed later with UnboundLocalError on `blocks`.
            raise ValueError(
                f"Encoder: unsupported stride {stride!r}; expected one of {supported}"
            )

        if stride == 1:
            blocks = [
                nn.Conv2d(in_channel, channel, 3, padding=1),
                nn.BatchNorm2d(channel),
            ]
        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, kernels[0], stride=2, padding=1),
                nn.BatchNorm2d(channel // 2),
                get_act_fn(act)(),
                nn.Conv2d(channel // 2, channel, kernels[1], padding=1),
                nn.BatchNorm2d(channel),
            ]
        else:
            # Layers are created in the same order as the former hand-written
            # branches, so parameter names and RNG-driven initialization match.
            n_down = stride // 2
            blocks = []
            c_in = in_channel
            for i in range(n_down):
                # Intermediate downsampling convs use half width; the last widens to `channel`.
                c_out = channel if i == n_down - 1 else channel // 2
                blocks += [
                    nn.Conv2d(c_in, c_out, kernels[i], stride=2, padding=1),
                    nn.BatchNorm2d(c_out),
                    get_act_fn(act)(),
                ]
                c_in = c_out
            # Stride-1 conv at full width; no activation here (ResBlocks start with one).
            blocks += [
                nn.Conv2d(channel, channel, kernels[n_down], padding=1),
                nn.BatchNorm2d(channel),
            ]

        for _ in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, res_kernels, act=act))

        blocks.append(get_act_fn(act)())
        blocks.append(nn.Conv2d(channel, out_channel, 1))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)

    
class Decoder(nn.Module):
    """Convolutional decoder with transposed-conv upsampling and a sigmoid output.

    Layout of ``self.blocks``: a conv to ``channel`` (``kernels[0]``, given
    ``padding``), then ``n_res_block`` ``ResBlock``s, an activation, and for
    stride >= 4 a stack of stride-2 transposed convs (BN + act after each).
    ``forward`` then applies ``self.conv_out`` and ``nn.Sigmoid``.

    ``stride`` selects a preset:

    * ``stride`` in (4, 6, 8, 10): ``stride // 2`` stride-2 transposed convs in
      total. The first ``stride // 2 - 1`` live in ``blocks`` (``channel`` ->
      ``channel // 2`` -> ... ``channel // 2``), and the last is ``conv_out``
      (``channel // 2`` -> ``out_channel``).
    * ``stride == 2``: ``conv_out`` is a single stride-2 transposed conv with
      ``output_padding=1``.
    * ``stride == 1``: ``conv_out`` is a stride-1 transposed conv without padding.

    Args:
        in_channel: input channel count.
        out_channel: output channel count.
        channel: working width of the decoder.
        n_res_block: number of ``ResBlock``s.
        n_res_channel: hidden width inside each ``ResBlock``.
        stride: upsampling preset, one of (1, 2, 4, 6, 8, 10).
        kernels: ``kernels[0]`` for the input conv. ``kernels[1..stride // 2]``
            are the upsampling kernels for stride >= 4; ``kernels[1]`` is
            ``conv_out``'s kernel for stride 1 or 2.
        res_kernels: kernel sizes forwarded to each ``ResBlock``.
        act: activation name resolved through ``get_act_fn``.
        padding: padding of the input conv.
        ablation_prob: stored as ``self.ablation_prob``; not read anywhere in this class.

    Raises:
        ValueError: if ``stride`` is not a supported preset.
    """

    def __init__(
        self, in_channel, out_channel, channel, n_res_block, 
        n_res_channel, stride, kernels, res_kernels, act='relu', 
        padding=1, ablation_prob=0.0
    ):
        super().__init__()
        self.ablation_prob = ablation_prob
        self.in_channel = in_channel

        supported = (1, 2, 4, 6, 8, 10)
        if stride not in supported:
            # Previously `self.conv_out` was silently never created and the
            # failure only surfaced as an AttributeError inside forward().
            raise ValueError(
                f"Decoder: unsupported stride {stride!r}; expected one of {supported}"
            )

        blocks = [nn.Conv2d(in_channel, channel, kernels[0], padding=padding)]

        for _ in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel, res_kernels, act))

        blocks.append(get_act_fn(act)())

        # conv_out is assigned before self.blocks (as in the original), keeping
        # module registration order and RNG-driven initialization unchanged.
        if stride == 1:
            self.conv_out = nn.ConvTranspose2d(channel, out_channel, kernels[1])
        elif stride == 2:
            self.conv_out = nn.ConvTranspose2d(
                channel, out_channel, kernels[1], stride=2, padding=1, output_padding=1
            )
        else:
            n_up = stride // 2
            c_in = channel
            for i in range(1, n_up):
                blocks += [
                    nn.ConvTranspose2d(c_in, channel // 2, kernels[i], stride=2, padding=1),
                    nn.BatchNorm2d(channel // 2),
                    get_act_fn(act)(),
                ]
                c_in = channel // 2
            self.conv_out = nn.ConvTranspose2d(
                channel // 2, out_channel, kernels[n_up], stride=2, padding=1
            )

        self.blocks = nn.Sequential(*blocks)
        self.final = nn.Sigmoid()

    def forward(self, x):
        x = self.blocks(x)

        return self.final(self.conv_out(x))