# honeyplotnet/models/vq/vq_layer.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn.functional as F

class VectorQuantizer(torch.nn.Module):
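    """Vector quantization layer (VQ-VAE style).

    Maintains a codebook of `num_embeddings` vectors of size `embedding_dim`,
    snaps each input vector to its nearest codebook entry, and returns a
    straight-through output together with the commitment/codebook loss,
    where the commitment term is weighted by `beta`.
    """
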
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, use_fp16=False, **kwargs):
        super().__init__()

        self.name = 'vq'
        # use_fp16 is currently ignored; the one-hot encodings are kept in
        # float32 so they match the codebook weight dtype in the matmuls.
        self.dtype_float = torch.float32
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        self.embeddings = torch.nn.Embedding(
            num_embeddings,
            embedding_dim
        )

        # Initialize the codebook uniformly in [0, 1) (PyTorch's default range).
        torch.nn.init.uniform_(self.embeddings.weight)

    def forward(self, x, is_indices=False, **kwargs):
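        """Quantize `x`, or decode it as codebook indices when `is_indices` is True."""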
        if is_indices:
            return self.sample_decoder(x)
        else:
            return self._forward(x)

    def sample_decoder(self, encoding_indices):
        """Decode codebook indices of shape [bs, x_dim] into embedding
        vectors of shape [bs, x_dim, embedding_dim]."""
        bs, x_dim = encoding_indices.shape
        output_shape = [bs, x_dim, self.embedding_dim]

        # One-hot encode the indices and select codebook rows by matmul.
        # F.one_hot already places its output on the input's device.
        encodings = F.one_hot(encoding_indices, self.num_embeddings).to(self.dtype_float)
        quantized = torch.matmul(encodings, self.embeddings.weight)
        quantized = torch.reshape(quantized, output_shape)
        return quantized

    def _forward(self, x):
        input_shape = x.shape
        assert input_shape[-1] == self.embedding_dim

        # Find the nearest codebook entry for every input vector.
        flattened = torch.reshape(x, [-1, self.embedding_dim])
        encoding_indices = self.get_code_indices(flattened)

        # One-hot encode the indices and select codebook rows by matmul.
        encodings = F.one_hot(encoding_indices, self.num_embeddings).to(self.dtype_float)
        quantized = torch.matmul(encodings, self.embeddings.weight)
        quantized = torch.reshape(quantized, input_shape)

        # Commitment loss pulls the encoder output towards its code;
        # the codebook loss pulls the code towards the encoder output.
        commitment_loss = self.beta * torch.mean((quantized.detach() - x) ** 2)
        codebook_loss = torch.mean((quantized - x.detach()) ** 2)
        loss = commitment_loss + codebook_loss

        # Straight-through estimator: gradients flow to x as if quantization
        # were the identity.
        out = x + (quantized - x).detach()

        return out, quantized, loss

    def get_code_indices(self, q_z):
        """Return the index of the nearest codebook vector for each row of q_z."""
        flattened_inputs = torch.reshape(q_z, [-1, self.embedding_dim])

        # Squared L2 distance, expanded as ||x||^2 + ||e||^2 - 2 * x.e.
        similarity = torch.matmul(flattened_inputs, self.embeddings.weight.t())
        s1 = torch.sum(flattened_inputs ** 2, dim=1, keepdim=True)
        s2 = torch.sum(self.embeddings.weight.t() ** 2, dim=0)
        s3 = -2 * similarity
        distances = s1 + s2 + s3

        # Pick the index of the minimum distance for each input vector.
        return torch.argmin(distances, dim=1)
    
    def get_embed(self, **kwargs):
        return self.embeddings.weight
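

# A minimal smoke test (not part of the original module): the shapes and
# hyperparameters below are illustrative assumptions, not the upstream
# training configuration.
if __name__ == "__main__":
    vq = VectorQuantizer(num_embeddings=512, embedding_dim=64)
    x = torch.randn(8, 16, 64)  # [batch, sequence, embedding_dim]

    out, quantized, loss = vq(x)
    assert out.shape == x.shape and quantized.shape == x.shape

    # Round-trip: recover indices, then decode them back to embeddings.
    indices = vq.get_code_indices(x).reshape(8, 16)
    decoded = vq(indices, is_indices=True)
    assert torch.allclose(decoded, quantized)

    print(f"loss={loss.item():.4f}, codebook={tuple(vq.get_embed().shape)}")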