# honeyplotnet/models/conv/conv.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


import torch
import torch.nn as nn

def get_act_fn(name):
    """Return the torch.nn activation *class* (not an instance) for `name`."""
    if name == 'relu':
        fn = torch.nn.ReLU
    elif name == 'gelu':
        fn = torch.nn.GELU
    elif name == 'prelu':
        fn = torch.nn.PReLU
    elif name == 'tanh':
        fn = torch.nn.Tanh
    elif name == 'leakyrelu':
        fn = torch.nn.LeakyReLU
    else:
        raise ValueError(f"unknown activation function: {name!r}")
    return fn
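

# Hedged usage sketch (not part of the original file): get_act_fn returns
# the activation *class*, not an instance, so callers instantiate it.
def _demo_get_act_fn():
    act = get_act_fn('leakyrelu')()  # instantiate torch.nn.LeakyReLU
    return act(torch.randn(4))
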

class ResBlock(nn.Module):
    """1D residual block: two convolutions plus a skip connection.

    The addition in forward() requires the sequence length to be
    preserved, which holds when 2 * padding == kernels[0] + kernels[1] - 2
    (e.g. kernels (3, 1) with padding=1).
    """

    def __init__(self, in_channel, channel, kernels, padding=0, act='relu'):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv1d(in_channel, channel, kernels[0], padding=padding),
            nn.BatchNorm1d(channel),
            get_act_fn(act)(),
            nn.Conv1d(channel, in_channel, kernels[1]),
            nn.BatchNorm1d(in_channel),
            get_act_fn(act)(),
        )

    def forward(self, x):
        out = self.conv(x)
        out = out + x
        return out
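

# Hedged shape check (not part of the original file): ResBlock's skip
# connection only works when both convolutions together preserve the
# sequence length, i.e. 2 * padding == kernels[0] + kernels[1] - 2.
def _demo_res_block():
    block = ResBlock(in_channel=8, channel=16, kernels=(3, 1), padding=1)
    x = torch.randn(2, 8, 32)  # (batch, channels, length)
    assert block(x).shape == x.shape
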

class Conv1dEncoder(nn.Module):
    """Strided Conv1d downsampling blocks followed by residual blocks.

    channels[0] is the input width; each later entry adds one strided
    convolution. With use_proj=True, the final entry is instead used as
    the output width of a trailing 1x1 projection convolution.
    """
    def __init__(self, channels, kernels, stride, padding, 
        n_res_block, n_res_channel, res_kernels, res_padding, 
        act='leakyrelu', use_bn=False, use_proj=False):
        super().__init__()
        assert isinstance(channels, list), "channels must be a list"
        assert len(res_kernels) == 2, "res_kernels must have exactly two entries"

        self.nblocks = len(channels)
        blocks = []
        prev_chn = channels[0]
        kidx = 0
        if use_proj:
            out_channel    = channels[-1]
            block_channels = channels[1:-1]
        else:
            block_channels = channels[1:]
        for chn in block_channels:
            new_block = [nn.Conv1d(
                prev_chn, chn, kernel_size=kernels[kidx], stride=stride, padding=padding)]
            
            if use_bn:
                new_block += [nn.BatchNorm1d(chn)]

            new_block += [get_act_fn(act)()]
            blocks += new_block
            kidx += 1

            prev_chn = chn

        for _ in range(n_res_block):
            blocks.append(ResBlock(prev_chn, n_res_channel, res_kernels, padding=res_padding, act=act))
        
        if use_proj:
            blocks.append(nn.Conv1d(prev_chn, out_channel, 1))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)
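

# Hedged usage sketch (not part of the original file): with kernel 4,
# stride 2 and padding 1, each downsampling convolution halves the
# sequence length, so a length-64 input comes out at length 16 here.
def _demo_encoder():
    enc = Conv1dEncoder(
        channels=[1, 32, 64], kernels=[4, 4], stride=2, padding=1,
        n_res_block=1, n_res_channel=32, res_kernels=(3, 1), res_padding=1)
    x = torch.randn(2, 1, 64)
    assert enc(x).shape == (2, 64, 16)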


class Conv1dDecoder(nn.Module):
    """Residual blocks followed by ConvTranspose1d upsampling blocks.

    Mirrors Conv1dEncoder: channels[0] is the input width; each later
    entry adds one transposed convolution. With use_proj=True, the final
    entry becomes the output width of a trailing 1x1 projection.
    """
    def __init__(self, channels, kernels, stride, padding, 
        n_res_block, n_res_channel, res_kernels, res_padding, 
        act='relu', use_bn=False, use_proj=False):
        super().__init__()
        assert isinstance(channels, list), "channels must be a list"
        assert len(res_kernels) == 2, "res_kernels must have exactly two entries"

        self.nblocks = len(channels)
        prev_chn = channels[0]
        blocks = []
        kidx = 0

        for _ in range(n_res_block):
            blocks.append(ResBlock(prev_chn, n_res_channel, res_kernels, padding=res_padding, act=act))
        
        if use_proj:
            out_channel    = channels[-1]
            block_channels = channels[1:-1]
        else:
            block_channels = channels[1:]

        for chn in block_channels:
            new_block = [nn.ConvTranspose1d(
                prev_chn, chn, kernel_size=kernels[kidx], stride=stride, padding=padding)]
            
            if use_bn:
                new_block += [nn.BatchNorm1d(chn)]

            new_block += [get_act_fn(act)()]
            blocks += new_block
            kidx += 1

            prev_chn = chn

        if use_proj:
            blocks.append(nn.Conv1d(prev_chn, out_channel, 1))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)
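

# Hedged end-to-end smoke test (not part of the original file): a decoder
# with mirrored hyperparameters inverts the encoder's downsampling, since
# each ConvTranspose1d(kernel=4, stride=2, padding=1) doubles the length.
if __name__ == "__main__":
    _demo_get_act_fn()
    _demo_res_block()
    _demo_encoder()

    enc = Conv1dEncoder(
        channels=[1, 32, 64], kernels=[4, 4], stride=2, padding=1,
        n_res_block=1, n_res_channel=32, res_kernels=(3, 1), res_padding=1)
    dec = Conv1dDecoder(
        channels=[64, 32, 1], kernels=[4, 4], stride=2, padding=1,
        n_res_block=1, n_res_channel=32, res_kernels=(3, 1), res_padding=1)
    x = torch.randn(2, 1, 64)
    z = enc(x)                      # (2, 64, 16)
    assert dec(z).shape == x.shape  # lengths restored: 16 -> 64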