# JSCC / VQ-VAE / model_soft.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
import math

    
class VectorQuantizer(nn.Module):
    """Soft vector quantizer with optional Gaussian noise on the soft assignments.

    Every spatial position of the input feature map is treated as a vector of
    length ``embedding_dim``. A softmax over the negative squared Euclidean
    distances to the ``num_embeddings`` codebook vectors yields soft assignment
    weights, and the output is the correspondingly weighted sum of codebook
    vectors. When ``noise`` is supplied, zero-mean isotropic Gaussian noise
    with that standard deviation is added to the assignment weights before the
    codebook is mixed.
    """

    def __init__(self, num_embeddings, embedding_dim) -> None:
        """Create the codebook.

        Args:
            num_embeddings: number of codebook vectors (K).
            embedding_dim: length of each codebook vector (L).
        """
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim  # length or dim of the embedding vector L
        self.num_embeddings = num_embeddings  # number of embedding vectors K
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Last noise standard deviation used; only consulted when `noise` is
        # passed to forward(), which overwrites it first.
        self.noise_std = 0.05

    def forward(self, inputs, noise=None):
        """Softly quantize ``inputs`` against the codebook.

        Args:
            inputs: feature map laid out as [B, C, H, W]; it is flattened into
                rows of length ``embedding_dim``.
            noise: optional standard deviation of the Gaussian noise added to
                the soft assignment weights. ``None`` disables the noise;
                ``0`` is accepted and adds nothing.

        Returns:
            A ``(loss, quantized)`` tuple: ``loss`` is ``1e-3`` times the MSE
            between the quantized tensor and the input, and ``quantized`` has
            the same [B, C, H, W] layout as ``inputs``.
        """
        if noise is not None:
            # Kept for backward compatibility: the most recent noise level
            # stays readable on the module after the call.
            self.noise_std = noise

        # [B, C, H, W] -> [B, H, W, C] so that each row of the flattened view
        # below is one vector to quantize.
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        flat_input = inputs.view(-1, self.embedding_dim)

        codebook = self.embedding.weight

        # Squared Euclidean distance via ||x||^2 + ||e||^2 - 2 * x.e
        distances = (
            torch.sum(flat_input ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(flat_input, codebook.t())
        )
        # Soft assignment: closer codebook vectors receive larger weights.
        encodings = F.softmax(-distances, dim=1)

        if noise is not None:
            # Isotropic N(0, noise_std^2 * I) noise drawn element-wise. This is
            # the same distribution as a MultivariateNormal with a scaled
            # identity covariance, but avoids building and factorising a K x K
            # matrix on every call and stays valid for noise_std == 0.
            encodings = encodings + self.noise_std * torch.randn_like(encodings)

        # Mix the codebook vectors with the (possibly noisy) soft weights and
        # restore the [B, H, W, C] shape.
        quantized = torch.matmul(encodings, codebook).view(input_shape)

        loss = 1e-3 * F.mse_loss(quantized, inputs)

        # revert back to [B, C, H, W]
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return loss, quantized
    
class Encoder(nn.Module):
    """Convolutional encoder made of four identical downsampling stages.

    Each stage is a bias-free 3x3 convolution (stride 1, padding 1), batch
    normalisation, ReLU and a 2x2 max-pool with stride 2. Channel widths go
    ``in_channels -> 32 -> 64 -> 128 -> 256``.
    """

    def __init__(self, in_channels=3):
        """Build the stage stack.

        Args:
            in_channels: number of channels in the input passed to forward().
        """
        super().__init__()
        widths = (in_channels, 32, 64, 128, 256)
        stages = []
        # Pair consecutive widths to get (input, output) channels per stage.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through the stage stack and return the feature map."""
        features = self.encoder(inputs)
        return features

class Decoder(nn.Module):
    """Convolutional decoder that expects a 256-channel feature map.

    Three stages each apply a 2x2 stride-2 transposed convolution followed by
    a 3x3 refinement convolution, with batch normalisation and ReLU after
    both. Channel widths go ``256 -> 128 -> 64 -> 32``. A final 2x2 stride-2
    transposed convolution maps to ``out_channels`` and a sigmoid is applied.
    """

    def __init__(self, out_channels=3):
        """Build the layer stack.

        Args:
            out_channels: number of channels in the decoded output.
        """
        super().__init__()
        widths = (256, 128, 64, 32)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                # Upsampling step.
                nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2, padding=0),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                # Same-resolution refinement.
                nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        layers += [
            nn.ConvTranspose2d(widths[-1], out_channels, kernel_size=2, stride=2, padding=0),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, input):
        """Decode the feature map ``input`` and return the sigmoid output."""
        decoded = self.decoder(input)
        return decoded
    
class VQVAE(nn.Module):
    """Autoencoder that chains Encoder -> VectorQuantizer -> Decoder."""

    def __init__(self, num_channels, num_embeddings, embedding_dim):
        """Assemble the three sub-modules.

        Args:
            num_channels: channel count used for both the encoder input and
                the decoder output.
            num_embeddings: number of codebook vectors in the quantizer.
            embedding_dim: length of each codebook vector in the quantizer.
        """
        super().__init__()
        self.encoder = Encoder(in_channels=num_channels)
        self.quantizer = VectorQuantizer(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.decoder = Decoder(out_channels=num_channels)

    def forward(self, inputs, noise=None):
        """Encode, quantize and decode ``inputs``.

        Args:
            inputs: batch passed to the encoder.
            noise: forwarded unchanged to the quantizer as its ``noise``
                argument.

        Returns:
            A ``(loss, x_recon)`` tuple: the quantizer's loss and the
            decoder's output.
        """
        latent = self.encoder(inputs)
        loss, quantized = self.quantizer(latent, noise)
        return loss, self.decoder(quantized)