# Pref-Restoration/tok/ar_dtok/ar_model.py
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from .. import models
from .generate import generate as ar_generate


def find_multiple(n: int, k: int):
    """Round ``n`` up to the nearest multiple of ``k``.

    Values that are already a multiple of ``k`` are returned unchanged.
    """
    remainder = n % k
    if not remainder:
        return n
    return n + k - remainder


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, scale_factor=10000):
    """Build fixed 1-D sine/cosine positional embeddings.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: positions to encode -- any array-like (list, tuple or ndarray)
            that flattens to shape (M,).
        scale_factor: base of the geometric frequency progression
            (default 10000).

    Returns:
        np.ndarray of shape (M, embed_dim): the first half of each row holds
        the sines, the second half the cosines.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / scale_factor ** omega  # frequencies, (D/2,)

    # np.asarray lets callers pass plain Python lists, as documented.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)


@dataclass
class ModelArgs:
    """Configuration for ARModel and its sub-modules (attention, FFN, embeddings)."""
    # Hidden size, number of transformer layers and number of attention (query) heads.
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32

    # Key/value heads for attention; None falls back to n_head (see Attention.__init__).
    n_kv_head: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    # Optional factor applied to the SwiGLU hidden size before rounding to multiple_of.
    ffn_dim_multiplier: Optional[float] = None
    # NOTE(review): rope_base is not read anywhere in this file (positions use abs_pe).
    rope_base: float = 10000
    norm_eps: float = 1e-5  # epsilon used by every RMSNorm
    initializer_range: float = 0.02  # std of the normal init for Linear / Embedding weights

    # Regularisation: dropout on the input embeddings, inside attention (training only),
    # after the attention output projection and after the FFN, plus the stochastic-depth
    # rate reached by the last layer (per-layer rates grow linearly from 0).
    token_dropout_p: float = 0.1
    attn_dropout_p: float = 0.0
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.0

    # Conditioning. num_classes is used by 'class_cond'; class_dropout_prob is the
    # probability of dropping the condition during training (classifier-free guidance).
    num_classes: int = 1000
    class_dropout_prob: float = 0.1
    model_type: str = 'class_cond' # one of 'class_cond', 'clip_cond', 'indice_cond'
    cond_dim: int = 1152  # input feature size of the 'clip_cond' linear projection
    cond_vocab_size: int = 8192  # condition indices for 'indice_cond'; index cond_vocab_size marks "dropped"

    # Size of the token embedding table and of the output logits.
    vocab_size: int = 8192
    # Number of conditioning positions; also enters the abs_pe length and the KV-cache size check.
    cls_token_num: int = 1

    # NOTE(review): max_batch_size is not read in this file; setup_caches takes its own value.
    max_batch_size: int = 32
    # Maximum token sequence length; sizes abs_pe and is the default length for ARModel.sample.
    max_seq_len: int = 2048

    # True: fixed sin/cos position-embedding buffer; False: learned position embedding.
    use_fixed_pe: bool = False

    # Stored as ARModel.frame_prediction; not otherwise used in this file.
    frame_prediction: bool = False


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    The statistic is computed in float32 with CUDA autocast disabled; the
    normalised tensor is cast back to the input dtype before the gain
    ``weight`` is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    @torch.autocast(device_type='cuda', enabled=False)
    def _norm(self, x):
        mean_square = (x * x).mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight


class MLP(nn.Module):
    """Two bias-free linear layers with a tanh-approximated GELU in between.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are falsy (e.g. ``None``).
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden, out, bias=False)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


#################################################################################
#                            Drop Path Implementation                           #
#################################################################################

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    One Bernoulli(1 - drop_prob) draw is made per sample (first dimension of
    ``x``) and broadcast over every remaining dimension, so any tensor rank
    works. When ``scale_by_keep`` is true, surviving samples are divided by
    the keep probability so the expected value is unchanged.

    Older code calls this "DropConnect"; see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 for
    why "drop path" is the more accurate name.

    Returns ``x`` itself (no copy) when ``drop_prob`` is 0 or ``training`` is false.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # (B, 1, 1, ...) so the per-sample mask broadcasts over all other dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(torch.nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Dropping only happens while the module is in training mode.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, training=self.training, scale_by_keep=self.scale_by_keep)

    def extra_repr(self):
        return 'drop_prob={:0.3f}'.format(round(self.drop_prob, 3))


#################################################################################
#                                   AR Model                                    #
#################################################################################

class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``dropout(w2(silu(w1 x) * w3 x))``.

    The hidden width starts at 2/3 of ``4 * config.dim``, is optionally scaled
    by ``config.ffn_dim_multiplier`` and is finally rounded up to a multiple
    of ``config.multiple_of``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        hidden_dim = int(2 * (4 * config.dim) / 3)
        multiplier = config.ffn_dim_multiplier
        if multiplier is not None:
            hidden_dim = int(multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        # Creation order (w1, w3, w2) is kept so parameter init consumes the RNG identically.
        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.ffn_dropout(self.w2(gated))
    

class KVCache(nn.Module):
    """Pre-allocated key/value buffers for incremental decoding.

    Both buffers have shape ``(max_batch_size, n_head, max_seq_length, head_dim)``
    and are registered as module buffers named ``k_cache`` and ``v_cache``.
    """

    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_head, max_seq_length, head_dim)
        for buffer_name in ('k_cache', 'v_cache'):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write new keys/values in place at sequence positions ``input_pos``.

        Args:
            input_pos: index tensor of shape (S,).
            k_val, v_val: tensors of shape (B, H, S, D); cast to the cache dtype.

        Returns:
            The full ``(k_cache, v_cache)`` buffers, not just the written slice.
        """
        n_new = k_val.shape[2]
        assert input_pos.shape[0] == n_new, f"{input_pos.shape[0]} != {n_new}"
        self.k_cache[:, :, input_pos] = k_val.to(self.k_cache.dtype)
        self.v_cache[:, :, input_pos] = v_val.to(self.v_cache.dtype)
        return self.k_cache, self.v_cache


class Attention(nn.Module):
    """Multi-head self-attention with optional grouped KV heads and a KV cache.

    A single fused projection ``wqkv`` yields queries (``n_head`` heads) and
    keys/values (``n_kv_head`` heads each). Keys/values are repeated so every
    query head has a partner before ``scaled_dot_product_attention``.
    ``kv_cache`` stays ``None`` until a cache object is attached from outside.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.n_head = config.n_head
        self.head_dim = config.dim // config.n_head
        self.n_kv_head = config.n_head if config.n_kv_head is None else config.n_kv_head
        qkv_out = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # Fused query/key/value projection for all heads.
        self.wqkv = nn.Linear(config.dim, qkv_out, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        # Regularisation.
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ):
        """Attend over ``x`` of shape (B, S, dim) and return a (B, S, dim) tensor.

        ``input_pos`` gives the sequence positions written into ``kv_cache``
        (only used when a cache is attached). ``mask`` is passed to SDPA as
        ``attn_mask``; when it is ``None`` SDPA applies its own causal mask.
        """
        bsz, seqlen, _ = x.shape
        kv_width = self.n_kv_head * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_width, kv_width], dim=-1)

        # (B, S, heads, head_dim) -> (B, heads, S, head_dim)
        q = q.view(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seqlen, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seqlen, self.n_kv_head, self.head_dim).transpose(1, 2)

        if self.kv_cache is None:
            keys, values = k, v
        else:
            keys, values = self.kv_cache.update(input_pos, k, v)

        # Share each KV head across n_head // n_kv_head query heads.
        group = self.n_head // self.n_kv_head
        keys = keys.repeat_interleave(group, dim=1)
        values = values.repeat_interleave(group, dim=1)

        attn = F.scaled_dot_product_attention(
            q, keys, values,
            attn_mask=mask,
            # An explicit mask (e.g. the KV-cache path) disables SDPA's built-in causal mask.
            is_causal=mask is None,
            dropout_p=self.attn_dropout_p if self.training else 0,
        )
        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        return self.resid_dropout(self.wo(attn))


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer.

    ``x + drop_path(attention(norm(x)))`` followed by
    ``h + drop_path(feed_forward(norm(h)))``; the drop-path module is an
    identity unless a positive rate is given.
    """

    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(
        self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        # NOTE: start_pos is forwarded to Attention as its `input_pos`; ARModel passes
        # a position tensor (or None) here despite the int annotation.
        attn_out = self.attention(self.attention_norm(x), start_pos, mask)
        h = x + self.drop_path(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + self.drop_path(ffn_out)


class LabelEmbedder(nn.Module):
    """
    Map integer labels to embedding vectors, with label dropout for
    classifier-free guidance.

    When ``dropout_prob > 0`` one extra "null" row is allocated at index
    ``num_classes``; dropped labels and negative labels are redirected to it.
    NOTE(review): with ``dropout_prob == 0`` no extra row exists, so forced
    drops or negative labels would index out of range.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        n_rows = num_classes + (1 if dropout_prob > 0 else 0)
        self.embedding_table = nn.Embedding(n_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace dropped labels with the null id ``num_classes``.

        Entries of ``force_drop_ids`` equal to 1 are dropped deterministically;
        without it each label is dropped with probability ``dropout_prob``.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        should_drop = force_drop_ids is not None or (train and self.dropout_prob > 0)
        if should_drop:
            labels = self.token_drop(labels, force_drop_ids)
        # Negative labels mean "unconditional" and share the null row.
        labels = torch.where(labels < 0, self.num_classes, labels)
        return self.embedding_table(labels)


class ARModel(nn.Module):
    """Decoder-only autoregressive transformer over a discrete token vocabulary.

    A conditioning prefix built from ``cond_idx`` is placed in front of the
    token sequence. How the prefix is built depends on ``config.model_type``:

    * ``'class_cond'``  - labels embedded by ``LabelEmbedder`` (one prefix position);
    * ``'clip_cond'``   - continuous features projected by ``nn.Linear(cond_dim, dim)``;
    * ``'indice_cond'`` - condition indices embedded by ``LabelEmbedder``.

    Positions are encoded with an absolute embedding ``abs_pe`` that is either
    a fixed sin/cos buffer (``config.use_fixed_pe``) or a learned parameter.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.max_seq_length = config.max_seq_len
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.is_sampling = False
        self.frame_prediction = config.frame_prediction

        if self.model_type == 'class_cond':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 'clip_cond':
            self.clip_proj = nn.Linear(config.cond_dim, config.dim)
        elif self.model_type == 'indice_cond':
            # Extra index cond_vocab_size is the "dropped condition" id used by clip_embedding.
            self.clip_proj = LabelEmbedder(config.cond_vocab_size + 1, config.dim, 0.0)
        else:
            raise Exception("please check model type")

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        # transformer blocks; stochastic-depth rate grows linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        n_positions = config.max_seq_len + config.cls_token_num - 1
        if config.use_fixed_pe:
            self.register_buffer('abs_pe', torch.zeros(1, n_positions, config.dim))
            abs_pe = get_1d_sincos_pos_embed_from_grid(embed_dim=config.dim, pos=np.arange(n_positions))
            self.abs_pe.copy_(torch.from_numpy(abs_pe).float().reshape_as(self.abs_pe))
            print("Using fixed absolute PE")
        else:
            self.abs_pe = nn.Parameter(torch.randn(1, n_positions, config.dim) * 0.02)
            print("Using learned absolute PE")

        self.initialize_weights()

    def initialize_weights(self):
        """Normal-initialise Linear/Embedding weights, then zero the output head."""
        self.apply(self._init_weights)

        # Zero-out output layers:
        if hasattr(self.output, 'weight') and isinstance(self.output.weight, nn.Parameter):
            nn.init.constant_(self.output.weight, 0)

    def _init_weights(self, module):
        """Per-module init: N(0, initializer_range) weights, zero Linear biases."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    @contextmanager
    def sampling(self):
        """Set ``is_sampling`` for the duration of the ``with`` block (reset even on error)."""
        self.is_sampling = True
        try:
            yield
        finally:
            self.is_sampling = False

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        """Attach a ``KVCache`` to every attention layer and build ``self.causal_mask``.

        ``max_seq_length`` must equal ``self.max_seq_length + self.cls_token_num``;
        the buffers are allocated for that length rounded up to a multiple of 8.
        """
        assert max_seq_length == self.max_seq_length + self.cls_token_num, f'{max_seq_length} != {self.max_seq_length} + {self.cls_token_num=}'

        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)

        for b in self.layers:
            # Size by the layer's KV heads: Attention writes n_kv_head heads into the
            # cache and only repeats them up to n_head after reading (GQA).
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, b.attention.n_kv_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(max_batch_size, 1, 1)

    def reset_caches(self):
        """Detach the KV caches from all attention layers."""
        for b in self.layers:
            b.attention.kv_cache = None

    def clip_embedding(self, x):
        """Embed a 'clip_cond' / 'indice_cond' condition, with training-time dropout.

        Dropped samples are zeroed ('clip_cond') or set to ``cond_vocab_size``
        ('indice_cond'). The caller's tensor is never modified.
        """
        if self.model_type == 'clip_cond':
            if self.training and self.config.class_dropout_prob > 0:
                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
                x = x.clone()  # do not write dropped rows into the caller's tensor
                x[drop_ids] = 0.
            x = self.clip_proj(x.to(self.dtype)) # Linear
        elif self.model_type == 'indice_cond':
            if self.training and self.config.class_dropout_prob > 0:
                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
                x = x.clone()  # do not write dropped rows into the caller's tensor
                x[drop_ids] = self.config.cond_vocab_size
            x = self.clip_proj(x, train=self.training) # Embedding
        return x

    def _embed_condition(self, cond_idx):
        """Build the conditioning prefix; dim 1 indexes prefix positions."""
        if self.model_type == 'class_cond':
            return self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:, :self.cls_token_num]
        if self.model_type in ('clip_cond', 'indice_cond'):
            return self.clip_embedding(cond_idx)
        raise ValueError(f"unsupported model_type: {self.model_type}")

    def forward(
        self,
        idx: Optional[torch.Tensor], # (b, n)
        cond_idx: Optional[torch.Tensor],  # cond_idx_or_embed
        input_pos:  Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
    ):
        """Run the transformer and optionally compute the cross-entropy loss.

        Call patterns:

        * ``idx`` and ``cond_idx`` (training / naive inference): the conditioning
          prefix is concatenated in front of the token embeddings.
        * only ``cond_idx`` (prefill with KV cache): the prefix alone is fed.
        * only ``idx`` (cached decoding of new tokens).

        In both cached modes ``mask`` is replaced by rows of ``self.causal_mask``
        selected by ``input_pos`` (so :meth:`setup_caches` must have been called).

        Returns:
            ``(logits, loss)``. When ``cond_idx`` is given, logits are cropped to
            start at the last conditioning position. ``loss`` is ``None`` unless
            ``targets`` is given; with ``valid`` (one flag per sample) the loss is
            averaged only over the tokens of valid samples.
        """
        cond_len = None  # number of conditioning positions prepended to h, if any
        if idx is not None and cond_idx is not None: # training or naive inference
            cond_embeddings = self._embed_condition(cond_idx)
            cond_len = cond_embeddings.shape[1]
            token_embeddings = self.tok_embeddings(idx) # (b, n, d)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)  # (b, cond_len + n, d)
            h = self.tok_dropout(token_embeddings)
        else:
            if cond_idx is not None: # prefill in inference
                token_embeddings = self._embed_condition(cond_idx)
                cond_len = token_embeddings.shape[1]
            else: # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)

            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)

        if self.is_sampling:
            h = h + self.abs_pe[:, input_pos]
        else:
            h = h + self.abs_pe[:, :h.shape[1]]

        # transformer blocks
        for layer in self.layers:
            h = layer(h, input_pos, mask)

        # output layers
        h = self.norm(h)
        logits = self.output(h)
        if cond_len is not None:
            # Use the real prefix length: cond_idx.size(1) fails for 1-D class labels.
            logits = logits[:, cond_len - 1:].contiguous()

        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.inference_mode()
    def sample(
        self,
        c,
        cfg_scale=2.0,
        cfg_interval=-1,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        seq_length=None,
    ):
        """Generate token sequences conditioned on ``c`` via the project's ``generate``.

        ``seq_length`` defaults to ``self.max_seq_length``; ``is_sampling`` is set
        for the duration of generation.
        """
        seq_length = self.max_seq_length if seq_length is None else seq_length
        with self.sampling():
            sampled_seqs = ar_generate(
                self, c, seq_length,
                cfg_scale=cfg_scale, cfg_interval=cfg_interval,
                temperature=temperature, top_k=top_k,
                top_p=top_p, sample_logits=True,
            )
        return sampled_seqs

    @classmethod
    def from_checkpoint(cls, ckpt, load_state_dict=True):
        """Rebuild a model from a checkpoint dict (or a path to one) via ``models.make``.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        if isinstance(ckpt, str):
            assert os.path.exists(ckpt), f"checkpoint {ckpt} does not exist"
            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
        else:
            assert isinstance(
                ckpt, dict
            ), f"checkpoint must be a dict or a path to a checkpoint"
        model = models.make(ckpt["model"], load_sd=load_state_dict)
        return model


#################################################################################
#                             LLAMA-ABS Configs                                 #
#################################################################################

def LLAMA_ABS_XXXL(**kwargs):
    """Build the ~3.9B-parameter ARModel (48 layers, 40 heads, width 2560)."""
    config = ModelArgs(dim=2560, n_head=40, n_layer=48, **kwargs)
    return ARModel(config)

def LLAMA_ABS_XXL(**kwargs):
    """Build the ~1.4B-parameter ARModel (48 layers, 24 heads, width 1536)."""
    config = ModelArgs(dim=1536, n_head=24, n_layer=48, **kwargs)
    return ARModel(config)

def LLAMA_ABS_XL(**kwargs):
    """Build the ~775M-parameter ARModel (36 layers, 20 heads, width 1280)."""
    config = ModelArgs(dim=1280, n_head=20, n_layer=36, **kwargs)
    return ARModel(config)

def LLAMA_ABS_LP(**kwargs):
    """Build the ~632M-parameter ARModel (30 layers, 20 heads, width 1280)."""
    config = ModelArgs(dim=1280, n_head=20, n_layer=30, **kwargs)
    return ARModel(config)

def LLAMA_ABS_L(**kwargs):
    """Build the ~343M-parameter ARModel (24 layers, 16 heads, width 1024)."""
    config = ModelArgs(dim=1024, n_head=16, n_layer=24, **kwargs)
    return ARModel(config)

def LLAMA_ABS_B(**kwargs):
    """Build the ~111M-parameter ARModel (12 layers, 12 heads, width 768)."""
    config = ModelArgs(dim=768, n_head=12, n_layer=12, **kwargs)
    return ARModel(config)

def LLAMA_ABS_S(**kwargs):
    """Build the ~21.7M-parameter ARModel (12 layers, 6 heads, width 384)."""
    config = ModelArgs(dim=384, n_head=6, n_layer=12, **kwargs)
    return ARModel(config)

# Registry name -> constructor, merged into the project-wide `models.models` table.
ar_models = {
    'llama-abs-S': LLAMA_ABS_S,
    'llama-abs-B': LLAMA_ABS_B,
    'llama-abs-L': LLAMA_ABS_L,
    'llama-abs-LP': LLAMA_ABS_LP,
    'llama-abs-XL': LLAMA_ABS_XL,
    'llama-abs-XXL': LLAMA_ABS_XXL,
    'llama-abs-XXXL': LLAMA_ABS_XXXL,
}

models.models.update(ar_models)