# JSCC / VQ-VAE / model_attn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
import math

    
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Reference implementation of scaled dot-product attention, adapted from the
    # PyTorch documentation for F.scaled_dot_product_attention.
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)  # softmax over the key/codeword axis
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)  # no-op when dropout_p == 0.0
    return attn_weight @ value
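
# Shape convention for the reference implementation above (matching
# torch.nn.functional.scaled_dot_product_attention): query is [..., L, E],
# key is [..., S, E], value is [..., S, Ev], and the output is [..., L, Ev].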
    
class Attention(nn.Module):
    def __init__(self, num_embeddings) -> None:
        super(Attention, self).__init__()
        # num_embeddings is accepted to mirror the VectorQuantizer interface but is unused:
        # this module is a stateless wrapper around scaled_dot_product_attention.

    def forward(self, query, key, value):
        return scaled_dot_product_attention(query, key, value)

    
class VectorQuantizer(nn.Module):
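    """Attention-based (soft) vector quantizer.

    Instead of the hard nearest-neighbour codebook lookup of the original VQ-VAE,
    each flattened latent vector attends over the rows of self.embedding.weight and
    is replaced by the resulting convex combination of codewords. Because this
    lookup is differentiable, no straight-through gradient estimator is needed.
    Optional additive Gaussian noise on the quantized vectors models a noisy channel.
    """
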
    def __init__(self, num_embeddings, embedding_dim) -> None:
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim   # dimensionality of each embedding (codeword) vector
        self.num_embeddings = num_embeddings # number of embedding vectors K in the codebook
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.attention = Attention(num_embeddings)
        self.noise_std = 0.05                # default std of the additive Gaussian channel noise

    def forward(self, inputs, noise=None):
        if noise is not None:
            self.noise_std = noise  # caller can override the channel-noise std per forward pass

        # self.noise_std = torch.randint(low=1, high=5, size=(1,), device=inputs.device)

        inputs = inputs.permute(0, 2, 3, 1).contiguous() # convert the input [B, C, H, W] to [B, H, W, C]
        input_shape = inputs.shape

        flat_input = inputs.view(-1, self.embedding_dim) # flatten the inputs to [B*H*W, C]

        # Soft codebook lookup: each latent vector attends over the codebook rows
        # (keys and values are both the embedding weights).
        encodings = self.attention(flat_input, self.embedding.weight, self.embedding.weight)

        if noise is not None:
            # Additive Gaussian channel noise on the (soft-)quantized vectors.
            noise_dist = MultivariateNormal(torch.zeros(encodings.shape[-1], device=inputs.device),
                                            (self.noise_std**2)*torch.eye(encodings.shape[-1], device=inputs.device))
            encodings = encodings + noise_dist.sample((encodings.size(0),))

        # The attention output is already a weighted combination of codewords,
        # so it can be reshaped directly back to [B, H, W, C].
        quantized = encodings.view(input_shape)
        
        # Quantization loss: a small MSE penalty pulling the quantized vectors and the
        # encoder outputs together (in place of the separate codebook and commitment
        # terms of the original VQ-VAE).
        loss = 1e-3*F.mse_loss(quantized, inputs)
        # log_quantized = quantized.log()
        # log_quantized = torch.where(torch.isnan(log_quantized), torch.tensor(-1e10), log_quantized)
        # probs_intput = F.softmax(inputs, dim=-1)
        # probs_quantized = F.softmax(quantized, dim=-1)
        # loss = F.kl_div(probs_quantized.log(), inputs, reduction='batchmean')

        # revert back to [B, C, H, W]
        quantized = quantized.permute(0, 3, 1, 2).contiguous() 

        return loss, quantized
    
class Encoder(nn.Module):
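    """Convolutional encoder: four Conv-BN-ReLU-MaxPool stages, each halving the
    spatial resolution, mapping [B, in_channels, H, W] to [B, 256, H/16, W/16]."""
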
    def __init__(self, in_channels=3):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(32, 64, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(64, 128, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(128, 256, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            )    
        
    def forward(self, inputs):
        return self.encoder(inputs)

class Decoder(nn.Module):
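    """Convolutional decoder mirroring the Encoder: four ConvTranspose2d stages, each
    doubling the spatial resolution, mapping [B, 256, H/16, W/16] back to
    [B, out_channels, H, W]; the final Sigmoid keeps outputs in [0, 1]."""
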
    def __init__(self, out_channels=3):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, stride=2, kernel_size=2, padding=0),
            nn.Sigmoid()        
        )
        
    def forward(self, inputs):
        return self.decoder(inputs)
    
class VQVAE(nn.Module):
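    """Full model: Encoder -> attention-based VectorQuantizer (with optional channel
    noise) -> Decoder.

    embedding_dim must equal the encoder's output channel count (256 here), since the
    quantizer treats each spatial position's channel vector as one latent vector.
    """
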
    def __init__(self, num_channels, num_embeddings, embedding_dim):
        super(VQVAE, self).__init__()
        self.encoder = Encoder(num_channels)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = Decoder(num_channels)

    def forward(self, inputs, noise=None):
        z = self.encoder(inputs)
        loss, quantized = self.quantizer(z, noise)
        x_recon = self.decoder(quantized)
        return loss, x_recon
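

if __name__ == "__main__":
    # Minimal shape/smoke test (not part of the original training code; the 32x32
    # input size and 512-entry codebook are assumed here for illustration only).
    # embedding_dim is set to 256 to match the encoder's output channels.
    model = VQVAE(num_channels=3, num_embeddings=512, embedding_dim=256)
    x = torch.rand(2, 3, 32, 32)        # a batch of two RGB 32x32 images
    loss, x_recon = model(x, noise=0.1) # noise overrides the channel-noise std
    print(loss.item(), x_recon.shape)   # expected: scalar loss, torch.Size([2, 3, 32, 32])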