# JSCC/VAE/model.py
import torch
from torch import nn
import torch.autograd as autograd
import torch.nn.functional as F

class Reshape(nn.Module):
    """Reshape the input to the shape given at construction, e.g. (-1, 1, 28, 28)."""
    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)
    
class STEFunction(autograd.Function):
    """Binarize activations with a straight-through gradient estimator."""

    @staticmethod
    def forward(ctx, input):
        # Hard threshold at zero: outputs are in {0, 1}.
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through, clipped to [-1, 1].
        return F.hardtanh(grad_output)


class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        x = STEFunction.apply(x)
        return x

class EncoderBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim) -> None:
        super(EncoderBlock, self).__init__()
        self.base_encoding = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        # Two heads parameterize the approximate posterior q(z|x) = N(mu, diag(sigma^2)).
        self.latent_mean = nn.Linear(hidden_dim, latent_dim)
        self.latent_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        x = self.base_encoding(x)
        z_mean = self.latent_mean(x)
        z_log_var = self.latent_log_var(x)
        reparam = self.__reparameterize(z_mean, z_log_var)
        return reparam, z_mean, z_log_var

    def __reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I);
        # randn_like keeps eps on the same device/dtype as mu.
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * log_var) * eps
        return z

class DecoderBlock(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim) -> None:
        super(DecoderBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),            # pixel intensities in [0, 1]
            Reshape(-1, 1, 28, 28)   # back to 1x28x28 images
        )

    def forward(self, x):
        return self.model(x)

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim) -> None:
        super(VAE, self).__init__()
        self.encoder = EncoderBlock(input_dim, hidden_dim, latent_dim)
        self.quantize = StraightThroughEstimator()
        self.decoder = DecoderBlock(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        # mu and log_var are kept on the module so the training loop can read
        # them after forward() (e.g. for the KL term of the ELBO).
        z, self.mu, self.log_var = self.encoder(x)
        z = self.quantize(z)  # binarize the latent via the straight-through estimator
        # z = z + torch.randn_like(z, requires_grad=False) # noise
        z = self.decoder(z)
        return z
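
# --- Minimal usage sketch (not part of the original file) ---
# Assumes MNIST-shaped inputs (1x28x28, i.e. input_dim = 784); the hidden and
# latent sizes below are illustrative. The ELBO-style loss shown here (binary
# cross-entropy reconstruction + KL to a standard normal) is one common choice,
# not necessarily the repository's training objective.
if __name__ == "__main__":
    model = VAE(input_dim=28 * 28, hidden_dim=256, latent_dim=32)
    x = torch.rand(8, 1, 28, 28)      # dummy batch of images in [0, 1]
    x_hat = model(x)                  # reconstruction, shape (8, 1, 28, 28)

    # Reconstruction term + KL divergence of q(z|x) from N(0, I),
    # using the mu/log_var stored on the model during forward().
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + model.log_var - model.mu.pow(2) - model.log_var.exp())
    loss = recon + kl
    print(x_hat.shape, loss.item())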