import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 is_causal=False, scale=None) -> torch.Tensor:
    # Reference (non-fused) implementation of scaled dot-product attention, adapted
    # from the PyTorch documentation for F.scaled_dot_product_attention, except that
    # the scores are negated before the softmax (see the note below).
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    # Note: unlike standard attention, the scores are negated here, so the softmax
    # favours the keys with the smallest (most negative) dot product with the query.
    attn_weight = torch.softmax(-attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


class Attention(nn.Module):
    def __init__(self, num_embeddings) -> None:
        super(Attention, self).__init__()

    def forward(self, query, key, value):
        attn = scaled_dot_product_attention(query, key, value)
        return attn


class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim) -> None:
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim    # length or dim of the embedding vector L
        self.num_embeddings = num_embeddings  # number of embedding vectors K
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.attention = Attention(num_embeddings)
        self.noise_std = 0.05

    def forward(self, inputs, noise=None):
        if noise is not None:
            self.noise_std = noise
            # self.noise_std = torch.randint(low=1, high=5, size=(1,), device=inputs.device)

        # convert the input [B, C, H, W] to [B, H, W, C]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # flatten the inputs to [B*H*W, C]
        flat_input = inputs.view(-1, self.embedding_dim)

        # soft assignment: each flattened input vector attends over the codebook entries
        encodings = self.attention(flat_input, self.embedding.weight, self.embedding.weight)
        if noise is not None:
            noise_dist = MultivariateNormal(
                torch.zeros(encodings.shape[-1], device=inputs.device),
                (self.noise_std ** 2) * torch.eye(encodings.shape[-1], device=inputs.device),
            )
            encodings = encodings + noise_dist.sample((encodings.size(0),))

        # the attention output is already a weighted combination of codebook vectors,
        # so it only needs to be reshaped back to the spatial layout
        quantized = encodings.view(input_shape)

        # commitment-style loss pulling the quantized output towards the encoder output
        loss = 1e-3 * F.mse_loss(quantized, inputs)
        # log_quantized = quantized.log()
        # log_quantized = torch.where(torch.isnan(log_quantized), torch.tensor(-1e10), log_quantized)
        # probs_input = F.softmax(inputs, dim=-1)
        # probs_quantized = F.softmax(quantized, dim=-1)
        # loss = F.kl_div(probs_quantized.log(), inputs, reduction='batchmean')

        # revert back to [B, C, H, W]
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return loss, quantized


class Encoder(nn.Module):
    def __init__(self, in_channels=3):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(32, 64, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(64, 128, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(128, 256, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

    def forward(self, inputs):
        return self.encoder(inputs)


class Decoder(nn.Module):
    def __init__(self, out_channels=3):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, stride=2, kernel_size=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, stride=2, kernel_size=2, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, inputs):
        return self.decoder(inputs)


class VQVAE(nn.Module):
    def __init__(self, num_channels, num_embeddings, embedding_dim):
        super(VQVAE, self).__init__()
        self.encoder = Encoder(num_channels)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = Decoder(num_channels)

    def forward(self, inputs, noise=None):
        z = self.encoder(inputs)
        loss, quantized = self.quantizer(z, noise)
        x_recon = self.decoder(quantized)
        return loss, x_recon
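

# --- Minimal usage sketch (not part of the original module) ---
# A hedged example of wiring the pieces together. The image size, batch size,
# codebook size (512) and the choice noise=0.05 are assumptions; embedding_dim
# must equal the encoder's 256 output channels, since the quantizer flattens the
# latent to [B*H*W, 256] before attending over the codebook.
if __name__ == "__main__":
    model = VQVAE(num_channels=3, num_embeddings=512, embedding_dim=256)
    images = torch.rand(4, 3, 64, 64)            # dummy batch of images in [0, 1]
    vq_loss, recon = model(images, noise=0.05)   # pass noise=None to disable codebook noise
    recon_loss = F.mse_loss(recon, images)       # reconstruction term
    total_loss = recon_loss + vq_loss            # combine with the quantizer loss before backprop
    print(recon.shape, total_loss.item())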