# mvq/models/mh_dropout/mhd_helper.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn as nn

def get_topk_batch(topk_idx, x):
    """Select per-sample entries of ``x`` and stack them into one tensor.

    ``topk_idx`` and ``x`` are walked in lockstep along their first
    dimension. Each element of ``x`` is indexed with the matching element of
    ``topk_idx`` and the selections are stacked along a new leading
    dimension.

    Args:
        topk_idx: Iterable of index objects, one per sample (presumably the
            top-k indices of each sample -- confirm against the caller).
        x: Iterable of tensors, one per sample, indexable by the matching
            entry of ``topk_idx``.

    Returns:
        A tensor holding the selected entries of every sample, stacked on
        dimension 0.
    """
    selected = [sample[indices] for indices, sample in zip(topk_idx, x)]
    return torch.stack(selected, dim=0)


def unflat_tensor(x, hypothese_count):
    """Split the leading dimension of a 4-D tensor into two dimensions.

    Args:
        x: Tensor with exactly four dimensions; the last three are kept
            unchanged.
        hypothese_count: Size of the new second dimension.

    Returns:
        ``x`` reshaped to ``(-1, hypothese_count, d1, d2, d3)`` where
        ``d1, d2, d3`` are the last three sizes of ``x``.
    """
    # Unpacking three values also enforces that ``x`` is 4-D.
    channels, height, width = x.shape[1:]
    target_shape = (-1, hypothese_count, channels, height, width)
    return x.reshape(target_shape)

def get_reduce_fn(name):
    """Return the torch reduction function registered under ``name``.

    Args:
        name: ``'mean'`` or ``'sum'``.

    Returns:
        ``torch.mean`` or ``torch.sum`` (the function object itself).

    Raises:
        NotImplementedError: If ``name`` is not a supported reduction.
    """
    reduce_fns = {'mean': torch.mean, 'sum': torch.sum}
    try:
        return reduce_fns[name]
    except (KeyError, TypeError):
        # The previous bare ``raise`` surfaced as an opaque RuntimeError with
        # no hint about the offending value. NotImplementedError subclasses
        # RuntimeError, so existing ``except RuntimeError`` handlers still
        # match while the message now names the bad input.
        raise NotImplementedError(
            f"Unsupported reduce function {name!r}; "
            f"expected one of {sorted(reduce_fns)}"
        ) from None

def get_act_fn(name):
    """Return the activation module class registered under ``name``.

    The class is returned uninstantiated; call the result to build a layer.

    Args:
        name: One of ``'relu'``, ``'gelu'``, ``'prelu'``, ``'tanh'`` or
            ``'leakyrelu'``.

    Returns:
        The matching ``torch.nn`` activation class.

    Raises:
        NotImplementedError: If ``name`` is not a supported activation.
    """
    act_fns = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'prelu': nn.PReLU,
        'tanh': nn.Tanh,
        'leakyrelu': nn.LeakyReLU,
    }
    try:
        return act_fns[name]
    except (KeyError, TypeError):
        # The previous bare ``raise`` surfaced as an opaque RuntimeError with
        # no hint about the offending value. NotImplementedError subclasses
        # RuntimeError, so existing ``except RuntimeError`` handlers still
        # match while the message now names the bad input.
        raise NotImplementedError(
            f"Unsupported activation {name!r}; "
            f"expected one of {sorted(act_fns)}"
        ) from None

def get_dist_loss(name):
    """Return the distance-loss module class registered under ``name``.

    The class is returned uninstantiated; call the result to build the loss.

    Args:
        name: ``'mse'`` or ``'smoothl1'``.

    Returns:
        ``nn.MSELoss`` or ``nn.SmoothL1Loss``.

    Raises:
        NotImplementedError: If ``name`` is not a supported loss.
    """
    dist_losses = {'mse': nn.MSELoss, 'smoothl1': nn.SmoothL1Loss}
    try:
        return dist_losses[name]
    except (KeyError, TypeError):
        # The previous bare ``raise`` surfaced as an opaque RuntimeError with
        # no hint about the offending value. NotImplementedError subclasses
        # RuntimeError, so existing ``except RuntimeError`` handlers still
        # match while the message now names the bad input.
        raise NotImplementedError(
            f"Unsupported distance loss {name!r}; "
            f"expected one of {sorted(dist_losses)}"
        ) from None

class ResBlock(nn.Module):
    """Residual block: two activation -> conv -> batch-norm stages plus a skip.

    Args:
        in_channel: Channel count of the input; the second convolution maps
            back to this count so the input can be added to the result.
        channel: Channel count produced by the first convolution.
        kernels: Indexable holding two kernel sizes. ``kernels[0]`` is used
            by the first convolution (with ``padding=1``) and ``kernels[1]``
            by the second (no padding).
        act: Activation name resolved through ``get_act_fn``.
    """

    def __init__(self, in_channel, channel, kernels, act='relu'):
        super().__init__()

        act_cls = get_act_fn(act)
        # Layers are created in the order they run.
        stages = [
            act_cls(),
            nn.Conv2d(in_channel, channel, kernels[0], padding=1),
            nn.BatchNorm2d(channel),
            act_cls(),
            nn.Conv2d(channel, in_channel, kernels[1]),
            nn.BatchNorm2d(in_channel),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the convolution stack on ``x`` and add ``x`` to the result."""
        residual = self.conv(x)
        # In-place accumulation of the skip connection.
        residual += x
        return residual

class MLP2d(nn.Module):
    """Two-convolution block with optional skip connection and residual tail.

    The core ``f`` is ``Conv2d -> activation -> Conv2d``. When ``res`` is
    true the input is added to the output of ``f``, going through a 1x1
    convolution first if the channel counts differ. When ``n_res_block > 0``
    the result is passed through that many ``ResBlock`` modules followed by
    an activation and a 1x1 convolution.

    Args:
        inp_dim: Input channel count.
        hid_dim: Channel count between the two convolutions of ``f``.
        out_dim: Output channel count.
        act: Activation name resolved through ``get_act_fn``.
        kernels: Two kernel sizes for the convolutions of ``f``.
        padding: Padding of the first convolution of ``f``.
        res: Whether to add the (possibly projected) input to ``f``'s output.
        n_res_block: Number of ``ResBlock`` modules appended after ``f``.
        res_kernels: Kernel sizes forwarded to each ``ResBlock``.
        n_res_channel: Inner channel count forwarded to each ``ResBlock``.
        **kwargs: Accepted and ignored.
    """

    # Sequence defaults are tuples, not lists: a list default is one shared
    # object reused by every call. Tuples index identically.
    def __init__(self, inp_dim, hid_dim, out_dim, act, kernels=(3, 1), padding=1, res=False, n_res_block=0,
                 res_kernels=(3, 1), n_res_channel=64, **kwargs):
        super().__init__()
        self.inp_dim = inp_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.res = res

        # A projection is only needed when the skip connection is enabled and
        # the input cannot be added to the output as-is.
        if self.inp_dim != self.out_dim and self.res:
            self.pre = nn.Conv2d(inp_dim, out_dim, 1)

        self.f = nn.Sequential(
            nn.Conv2d(inp_dim, hid_dim, kernels[0], padding=padding),
            get_act_fn(act)(),
            nn.Conv2d(hid_dim, out_dim, kernels[1]),
        )

        self.n_res_block = n_res_block

        blocks = [
            ResBlock(out_dim, n_res_channel, res_kernels, act=act)
            for _ in range(n_res_block)
        ]

        # ``self.blocks`` only exists when residual blocks were requested.
        if n_res_block > 0:
            blocks.append(get_act_fn(act)())
            blocks.append(nn.Conv2d(out_dim, out_dim, 1))
            self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply ``f``, the optional skip connection and the optional tail."""
        out = self.f(x)
        if self.res:
            if self.inp_dim != self.out_dim:
                x = self.pre(x)
            out += x

        if self.n_res_block > 0:
            out = self.blocks(out)

        return out


class MLP1d(nn.Module):
    """Two-layer perceptron with an optional skip connection.

    The core ``f`` is ``Linear -> activation -> Linear``. When ``res`` is
    true the input is added to the output of ``f``, going through a linear
    projection first if ``inp_dim`` and ``out_dim`` differ.

    Args:
        inp_dim: Input feature count.
        hid_dim: Feature count between the two linear layers.
        out_dim: Output feature count.
        act: Activation name resolved through ``get_act_fn``.
        res: Whether to add the (possibly projected) input to the output.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, inp_dim, hid_dim, out_dim, act, res=False, **kwargs):
        super().__init__()
        self.inp_dim, self.hid_dim, self.out_dim = inp_dim, hid_dim, out_dim
        self.res = res

        # Projection for the skip path, only when it is both used and needed.
        if res and inp_dim != out_dim:
            self.pre = nn.Linear(inp_dim, out_dim)

        hidden = nn.Linear(inp_dim, hid_dim)
        activation = get_act_fn(act)()
        head = nn.Linear(hid_dim, out_dim)
        self.f = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Apply ``f`` and, when enabled, add the skip connection."""
        out = self.f(x)
        if not self.res:
            return out

        skip = self.pre(x) if self.inp_dim != self.out_dim else x
        out += skip
        return out

class UpSampler2d(nn.Module):
    """Stack of transposed convolutions that enlarges the spatial size.

    Every ``ConvTranspose2d`` is built with its default stride and padding,
    so a layer with kernel size ``k`` grows each spatial dimension by
    ``k - 1``. For an integer ``up_sample_ratio >= 1`` the kernel sizes are
    chosen so the whole stack grows each spatial dimension by
    ``up_sample_ratio - 1`` (a 1x1 input becomes
    ``up_sample_ratio x up_sample_ratio``).

    * ``up_sample_ratio <= 3``: one transposed convolution with that kernel
      size and no activation.
    * otherwise: a first layer with kernel 2 (even ratio) or 3 (odd ratio),
      then kernel-2 layers (kernel 3 when exactly 3 remains) until the
      target is reached; every layer is followed by an activation.

    Args:
        inp_dim: Input channel count.
        out_dim: Output channel count of every layer.
        up_sample_ratio: Target growth as described above; stored unchanged
            on ``self.up_sample_ratio``.
        act: Activation name resolved through ``get_act_fn`` (only used when
            ``up_sample_ratio > 3``).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, inp_dim, out_dim, up_sample_ratio, act='relu', **kwargs):
        super().__init__()
        self.up_sample_ratio = up_sample_ratio

        if up_sample_ratio <= 3:
            stack = [torch.nn.ConvTranspose2d(inp_dim, out_dim, up_sample_ratio)]
        else:
            first_kernel = 3 if up_sample_ratio % 2 else 2
            stack = [
                torch.nn.ConvTranspose2d(inp_dim, out_dim, first_kernel),
                get_act_fn(act)(),
            ]

            # Amount still to be covered by the following layers.
            remaining = up_sample_ratio - first_kernel
            while remaining > 0:
                kernel = 3 if remaining == 3 else 2
                stack.append(torch.nn.ConvTranspose2d(out_dim, out_dim, kernel))
                stack.append(get_act_fn(act)())
                # A kernel of size k only contributes k - 1.
                remaining -= kernel - 1

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the transposed-convolution stack."""
        return self.layers(x)