import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from torch import nn
import torch.autograd as autograd
import torch.nn.functional as F


class Reshape(nn.Module):
    """Reshape the input tensor to the target shape given at construction.

    Example: ``Reshape(-1, 1, 28, 28)`` turns a ``(B, 784)`` tensor into
    ``(B, 1, 28, 28)``.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class STEFunction(autograd.Function):
    """Binarization with a straight-through gradient estimator.

    Forward hard-thresholds the input to {0, 1}; backward passes the
    incoming gradient straight through, clipped to [-1, 1], instead of
    the true (zero almost everywhere) derivative of the step function.
    """

    @staticmethod
    def forward(ctx, input):
        # Hard threshold: 1.0 where input > 0, else 0.0 (exactly 0 maps to 0).
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: ignore the step function's true
        # gradient and pass grad_output through, clamped to [-1, 1].
        return F.hardtanh(grad_output)


class StraightThroughEstimator(nn.Module):
    """nn.Module wrapper around :class:`STEFunction` for use in Sequential/forward code."""

    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        return STEFunction.apply(x)


class EncoderBlock(nn.Module):
    """VAE encoder: flattens the input and produces a sampled latent code.

    Args:
        input_dim: flattened input size (e.g. 784 for 28x28 images).
        hidden_dim: width of the hidden layer.
        latent_dim: dimensionality of the latent space.

    Returns (from ``forward``):
        (z, z_mean, z_log_var) where ``z`` is a reparameterized sample.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim) -> None:
        super(EncoderBlock, self).__init__()
        self.base_encoding = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.latent_mean = nn.Linear(hidden_dim, latent_dim)
        self.latent_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.base_encoding(x)
        z_mean = self.latent_mean(h)
        z_log_var = self.latent_log_var(h)
        z = self._reparameterize(z_mean, z_log_var)
        return z, z_mean, z_log_var

    def _reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        # randn_like inherits device and dtype from mu — the previous
        # torch.randn(batch, dim) call always allocated on CPU/float32 and
        # broke when the model ran on GPU or under a different dtype.
        eps = torch.randn_like(mu)
        # log_var is a log-variance, so sigma = exp(0.5 * log_var).
        return mu + torch.exp(0.5 * log_var) * eps


class DecoderBlock(nn.Module):
    """VAE decoder: maps a latent code back to a (B, 1, 28, 28) image in [0, 1]."""

    def __init__(self, latent_dim, hidden_dim, output_dim) -> None:
        super(DecoderBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),  # pixel intensities in [0, 1]
            Reshape(-1, 1, 28, 28),
        )

    def forward(self, x):
        return self.model(x)


class VAE(nn.Module):
    """Variational autoencoder with a binarized (straight-through) latent code.

    After ``forward``, ``self.mu`` and ``self.log_var`` hold the encoder's
    latent statistics so callers can compute the KL term of the loss.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim) -> None:
        super(VAE, self).__init__()
        self.encoder = EncoderBlock(input_dim, hidden_dim, latent_dim)
        self.quantize = StraightThroughEstimator()
        self.decoder = DecoderBlock(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        # Stash mu/log_var on the module: callers read them for the KL loss.
        z, self.mu, self.log_var = self.encoder(x)
        z = self.quantize(z)
        # z = z + torch.randn_like(z, requires_grad=False)  # optional noise injection (disabled)
        return self.decoder(z)