# tok/ar_dtok/bottleneck.py
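"""Bottleneck modules for discrete tokenization.

`Bottleneck` projects features into a (typically lower-dimensional) bottleneck space,
regularizes them with a configurable regularizer, and projects them back out.
`SimVectorQuantizer` is the SimVQ-style vector-quantization regularizer registered here:
a codebook passed through a linear projection, with nearest-neighbour (or stochastic)
assignment and a straight-through estimator.
"""
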
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .. import models
from ..models import register


@register("bottleneck")
class Bottleneck(nn.Module):
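    """Projects inputs into a bottleneck space, applies a regularizer (e.g. vector
    quantization) there, and projects the regularized representation back out."""
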
    def __init__(
        self,
        bottleneck_dim: int,
        input_dim: int,
        output_dim: int,
        token_nums: int,
        regularizer=None,
        **kwargs
    ):  
        super().__init__()
        self.token_nums = token_nums
        self.input_dim = input_dim
        self.output_dim = output_dim
        if bottleneck_dim > 0:
            self.bottleneck_dim = bottleneck_dim
        else:
            assert self.input_dim == self.output_dim, "input_dim and output_dim must be the same when bottleneck_dim is not specified"
            self.bottleneck_dim = self.input_dim
        
        self.project_dim = self.bottleneck_dim

        if bottleneck_dim > 0:
            self.in_linear = nn.Linear(self.input_dim, self.project_dim)
            self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
        else:
            # no explicit bottleneck: features pass through unchanged
            self.in_linear = self.out_linear = nn.Identity()
        
        # pass the bottleneck dim and token count into the regularizer spec before building it
        assert regularizer is not None, "Bottleneck requires a regularizer spec"
        regularizer['args']['dim'] = self.bottleneck_dim
        regularizer['args']['token_nums'] = self.token_nums
        self.regularizer = models.make(regularizer)

    def project_in(self, x):
        assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
        z = self.in_linear(x)
        return z

    def project_out(self, z_cat):
        z = self.out_linear(z_cat)
        return z

    def decode(self, bottleneck_rep):
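        """Reconstruct output features directly from the discrete bottleneck representation
        (e.g. codebook indices) produced by the regularizer."""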
        regularized_z = self.regularizer.decode(bottleneck_rep)
        return self.project_out(regularized_z)

    def forward(self, x):  
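        """Project x (batch, n_tokens, input_dim) into the bottleneck space, regularize it,
        and project back out; also returns the discrete bottleneck representation."""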
        z = self.project_in(x)
        projected_z = z
        regularized_output = self.regularizer(z)
        x_hat = self.project_out(regularized_output['regularized_z'])
        bottleneck_rep = regularized_output.pop('bottleneck_rep')
        return {
            'output': x_hat,
            'bottleneck_rep': bottleneck_rep,
            'projected_z': projected_z,
            **regularized_output,
        }
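
# Usage sketch (illustrative only, not executed). The regularizer spec below assumes the
# project's `models.make` registry accepts a {'name': ..., 'args': ...} dict; the 'args'
# key is what Bottleneck.__init__ writes into, while the 'name' key is an assumption here.
#
#   bottleneck = Bottleneck(
#       bottleneck_dim=32, input_dim=768, output_dim=768, token_nums=256,
#       regularizer={'name': 'simvq', 'args': {'codebook_size': 8192, 'l2_normalized': True}},
#   )
#   out = bottleneck(torch.randn(2, 256, 768))
#   tokens = out['bottleneck_rep']     # discrete indices, shape (2, 256)
#   recon = bottleneck.decode(tokens)  # reconstruct features from indices alone, (2, 256, 768)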


@register("simvq")
class SimVectorQuantizer(nn.Module):
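    """SimVQ-style vector quantizer: the codebook is an embedding table passed through a
    linear projection, and inputs are assigned to their nearest (or stochastically
    sampled) codebook entry with straight-through gradients."""
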
    def __init__(
        self,
        dim,
        codebook_size,
        l2_normalized=False,
        same_index_shape=True,
        stochastic=False,
        stochastic_temperature=1.0,
        **kwargs,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.dim = dim
        assert isinstance(l2_normalized, bool)
        self.l2_normalized = l2_normalized
        self.stochastic = stochastic
        self.eval_deterministic = False
        self.default_stochastic_temperature = stochastic_temperature
        
        if self.stochastic:
            if stochastic_temperature > 0:  # fixed temperature
                self.stochastic_temperature_inv = 1 / stochastic_temperature
            else:  # a non-positive value enables a learnable inverse temperature
                self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))

        # for clarity in the inference code, the codebook is not initialized from the LLM's embedding here
        self.embedding = nn.Embedding(self.codebook_size, self.dim)
        self.embedding_proj = nn.Linear(self.dim, self.dim)

        self.same_index_shape = same_index_shape

    def set_eval_deterministic(self, deterministic=True):
        self.eval_deterministic = deterministic

    def set_stochastic_temperature(self, temperature):
        if isinstance(getattr(self, 'stochastic_temperature_inv', None), nn.Parameter):
            del self.stochastic_temperature_inv  # a plain float cannot overwrite a registered Parameter
        self.stochastic_temperature_inv = 1 / temperature

    @torch.autocast(device_type='cuda', enabled=False)
    def get_emb(self):
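        """Return the effective codebook: the raw embedding table passed through the
        linear projection, optionally l2-normalized."""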
        emb = self.embedding_proj(self.embedding.weight)
        if self.l2_normalized:
            emb = F.normalize(emb, p=2, dim=-1)
        # assert emb.dtype == torch.float32, f"Embedding weight dtype is {emb.dtype}, expected float32"
        return emb

    @torch.autocast(device_type='cuda', enabled=False)
    def forward(self, z):
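        """Quantize z of shape (batch, n_tokens, dim) against the projected codebook.

        Assignment is deterministic (argmin of squared Euclidean distance) or, when
        `stochastic` is set, sampled from a softmax over cosine similarities; gradients
        reach z through the straight-through estimator.
        """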
        emb = self.get_emb()
        z = z.to(emb)
        # z = z.float()
        assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
        if self.l2_normalized:
            z = F.normalize(z, p=2, dim=-1)

        z_flattened = rearrange(z, 'b n d -> (b n) d')

        if self.stochastic:
            # sample the softmaxed cosine similarity
            assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
            cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
            probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
            if self.eval_deterministic and not self.training:
                q_indices = torch.argmax(probs, dim=-1)
            else:
                q_indices = torch.multinomial(probs, 1).squeeze(-1)
        else:
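            # deterministic nearest-neighbour search via the expansion
            # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 <z, e>, computed without materializing z - e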
            d = (
                torch.sum(z_flattened**2, dim=1, keepdim=True)
                + torch.sum(emb**2, dim=1)
                - 2
                * torch.einsum(
                    "bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n")
                )
            )
            q_indices = torch.argmin(d, dim=1)

        quantized = F.embedding(
            q_indices, emb, self.embedding.padding_idx, self.embedding.max_norm,
            self.embedding.norm_type, self.embedding.scale_grad_by_freq, self.embedding.sparse,
        ).view(z.shape)  # (b, n, d)
        
        # straight-through estimator: outputs take codebook values, gradients flow back to z
        quantized = z + (quantized - z).detach()

        if self.same_index_shape:
            q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])

        return_dict = {
            'unregularized_z': z,  # z before quantization (l2-normalized if l2_normalized=True)
            'emb': emb,  # effective codebook (l2-normalized if l2_normalized=True)
            'regularized_z': quantized,
            'bottleneck_rep': q_indices
        }
        return return_dict
    
    def get_codebook_entry(self, indices, shape=None):
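        """Look up codebook vectors for `indices`, reshaped to `shape` if given and to
        indices.shape + (dim,) otherwise."""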
        # shape specifying (batch, height, width, channel)
        indices_shape = indices.shape
        indices_flatten = rearrange(indices, '... -> (...)')

        # get quantized latent vectors
        emb = self.get_emb()
        z_q = F.embedding(indices_flatten, emb)
        # z_q = self.embedding(indices_flatten)
        if self.l2_normalized:
            z_q = F.normalize(z_q, p=2, dim=-1)

        if shape is not None:
            z_q = z_q.reshape(shape)
        else:
            z_q = z_q.reshape([*indices_shape, self.dim])
        return z_q

    def decode(self, indices):
        return self.get_codebook_entry(indices)
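

if __name__ == "__main__":
    # Minimal smoke test for the quantizer alone (illustrative; run as a module, e.g.
    # `python -m tok.ar_dtok.bottleneck` from the package root, so the relative imports
    # above resolve).
    vq = SimVectorQuantizer(dim=16, codebook_size=64, l2_normalized=True)
    z = torch.randn(2, 8, 16)
    out = vq(z)
    indices = out['bottleneck_rep']  # (2, 8) codebook indices
    z_q = vq.decode(indices)         # (2, 8, 16) codebook vectors for those indices
    assert z_q.shape == z.shape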