import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim) -> None:
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim    # length (dim) of each embedding vector, L
        self.num_embeddings = num_embeddings  # number of embedding vectors, K
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.noise_std = 0.05

    def forward(self, inputs, noise=None):
        if noise is not None:
            self.noise_std = noise
            # self.noise_std = torch.randint(low=1, high=5, size=(1,), device=inputs.device)

        # convert the input from [B, C, H, W] to [B, H, W, C]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # flatten the inputs to [B*H*W, C]
        flat_input = inputs.view(-1, self.embedding_dim)

        # squared Euclidean distance to each codebook vector: ||x||^2 + ||e||^2 - 2 x.e
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # soft assignment over the codebook (a differentiable relaxation of the argmin)
        encodings = F.softmax(-distances, dim=1)
        if noise is not None:
            noise_dist = MultivariateNormal(
                torch.zeros(encodings.shape[-1], device=inputs.device),
                (self.noise_std ** 2) * torch.eye(encodings.shape[-1], device=inputs.device))
            # sample_n is deprecated; sample((n,)) draws n i.i.d. noise vectors
            encodings = encodings + noise_dist.sample((encodings.size(0),))

        # quantize by taking the assignment-weighted sum of the embedding vectors
        quantized = torch.matmul(encodings, self.embedding.weight)
        quantized = quantized.view(input_shape)

        loss = 1e-3 * F.mse_loss(quantized, inputs)

        # revert back to [B, C, H, W]
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return loss, quantized


class Encoder(nn.Module):
    def __init__(self, in_channels=3):
        super(Encoder, self).__init__()
        # four conv + max-pool stages, each halving the resolution (16x downsampling total)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(32, 64, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(64, 128, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(128, 256, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )

    def forward(self, inputs):
        return self.encoder(inputs)


class Decoder(nn.Module):
    def __init__(self, out_channels=3):
        super(Decoder, self).__init__()
        # four transposed-conv stages, each doubling the resolution (16x upsampling total)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, stride=2, kernel_size=2, padding=0),
            nn.Sigmoid()  # reconstructions constrained to [0, 1]
        )

    def forward(self, inputs):
        return self.decoder(inputs)


class VQVAE(nn.Module):
    def __init__(self, num_channels, num_embeddings, embedding_dim):
        super(VQVAE, self).__init__()
        # note: embedding_dim must equal the encoder's output channels (256 here)
        self.encoder = Encoder(num_channels)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = Decoder(num_channels)

    def forward(self, inputs, noise=None):
        z = self.encoder(inputs)
        loss, quantized = self.quantizer(z, noise)
        x_recon = self.decoder(quantized)
        return loss, x_recon
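

# A minimal smoke-test sketch, not part of the original module: the shapes,
# hyperparameters, and loss weighting below are illustrative assumptions.
# Since the encoder downsamples by 16x, the input resolution must be a
# multiple of 16, and embedding_dim is set to 256 to match the encoder's
# output channels.
if __name__ == "__main__":
    model = VQVAE(num_channels=3, num_embeddings=512, embedding_dim=256)
    x = torch.rand(2, 3, 64, 64)  # hypothetical batch of 64x64 RGB images in [0, 1]
    vq_loss, x_recon = model(x, noise=0.05)  # noise sets the quantizer's noise_std
    recon_loss = F.mse_loss(x_recon, x)      # reconstruction term
    total_loss = recon_loss + vq_loss        # the quantizer term is pre-scaled by 1e-3
    print(x_recon.shape, vq_loss.item(), recon_loss.item())  # torch.Size([2, 3, 64, 64]) ...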