# mini-llama/src/llama.py
from contextlib import nullcontext
from typing import Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from base_llama import LlamaPreTrainedModel, LlamaConfig
from rope import apply_rotary_emb
from utils import *


# Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467)
# borrowed from the official Llama implementation:
# https://github.com/facebookresearch/llama/blob/main/llama/model.py
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        """
        Compute the root mean square normalization. Use Equation 4 under
        Section 4 of https://arxiv.org/abs/1910.07467 as a reference. Add
        the given epsilon value (self.eps) to the tensor's norm (i.e. inside
        the square root in Equation 4) before normalizing the tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
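
        Illustrative worked example (assuming dim=2 and ignoring eps): for
        x = [3.0, 4.0], the mean of squares is (3^2 + 4^2) / 2 = 12.5, so x is
        divided by sqrt(12.5) ~= 3.5355, giving approximately [0.8485, 1.1314].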
        """
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        """
        Apply the root mean square normalizer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class Attention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.n_kv_heads = (
            config.n_heads if config.n_kv_heads is None else config.n_kv_heads
        )
        assert config.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = config.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = config.dim // config.n_heads
        self.max_seq_len = config.max_seq_len
        self.compute_query = nn.Linear(
            config.dim, config.n_heads * self.head_dim, bias=False
        )
        self.compute_key = nn.Linear(
            config.dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.compute_value = nn.Linear(
            config.dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.compute_output = nn.Linear(
            config.n_heads * self.head_dim, config.dim, bias=False
        )
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout

    def compute_query_key_value_scores(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """
        Jointly compute Scaled Dot Product Attention (see Section 3.2.1 in
        https://arxiv.org/abs/1706.03762 for details). The query, key, and
        value tensors each have shape (bs, n_local_heads, seqlen, head_dim).
        An optimal implementation will jointly compute attention for all
        heads (n_local_heads of them) at once using matrix/tensor operations.

        Make sure to use attention_dropout (self.attn_dropout) on the computed
        attention matrix before applying it to the value tensor.
        """
        dim_query = query.size(-1)

        # scores shape = (bs, n_local_heads, seqlen, seqlen)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_query)
        softmax_scores = scores.softmax(dim=-1)  # Same shape as scores.

        # Returned attention shape = (bs, n_local_heads, seqlen, head_dim)
        # For each batch and head, each row has the attention-weighted sum
        # of values for a position in the sequence.
        return torch.matmul(self.attn_dropout(softmax_scores), value)

    def forward(self, x: torch.Tensor):
        """
        Llama2 uses Grouped-Query Attention. The details of GQA are actually
        not critical to solving this assignment; you are simply asked to
        compute Scaled Dot Product Attention (see above for details). GQA is
        a memory optimization to compute multi-head attention efficiently. See
        Section 2.2 in https://arxiv.org/abs/2305.13245 or
        https://ai.plainenglish.io/understanding-llama2-kv-cache-grouped-query-attention-rotary-embedding-and-more-c17e5f49a6d7
        for details.
        """
        batch_size, seqlen, _ = x.shape

        query = self.compute_query(x)
        key = self.compute_key(x)
        value = self.compute_value(x)
        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim)
        key = key.view(batch_size, seqlen, self.n_local_kv_heads, self.head_dim)
        value = value.view(batch_size, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        query, key = apply_rotary_emb(query, key, self.head_dim, self.max_seq_len)

        # Grouped-query attention: expand out keys and values so that each query head
        # has a matching key/value head. Convert both to:
        # (bs, seqlen, n_local_heads, head_dim)
        key = torch.repeat_interleave(key, dim=2, repeats=self.n_rep)
        value = torch.repeat_interleave(value, dim=2, repeats=self.n_rep)
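        # e.g., with n_local_heads = 8 and n_local_kv_heads = 2 (illustrative values),
        # n_rep = 4, so each key/value head is repeated 4 times along the head axis.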

        # make heads into a batch dimension
        query = query.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        output = self.compute_query_key_value_scores(query, key, value)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seqlen, -1)

        # final projection into the residual stream
        output = self.resid_dropout(self.compute_output(output))
        return output


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
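            # worked example: for dim = 4096 and multiple_of = 256 (the 7B Llama values),
            # 4 * 4096 = 16384 -> int(2 * 16384 / 3) = 10922 -> rounded up to 11008.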
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def SwiGLU(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the SwiGLU activation function (see Section 2 in
        https://arxiv.org/abs/2204.02311).
        """
        return F.silu(self.w1(x)) * self.w3(x)

    def forward(self, x):
        return self.dropout(self.w2(self.SwiGLU(x)))


class LlamaLayer(nn.Module):
    def __init__(self, layer_id: int, config: LlamaConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)
        self.feed_forward = FeedForward(
            dim=config.dim,
            hidden_dim=config.hidden_dim,
            multiple_of=config.multiple_of,
            dropout=config.dropout,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.layer_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        This is the forward pass of the basic transformer building block. This is a
        modernized version of the block shown on the left of Figure 1 on
        https://arxiv.org/pdf/1706.03762.pdf.

        The transformer block should consist of:
        1) layer normalization of the input (via Root Mean Square layer normalization)
        2) self-attention on the layer-normalized input
        3) a residual connection (i.e., add the input to the output of the self-attention)
        4) layer normalization on the output of step 3 (the post-attention residual stream)
        5) a feed-forward network on that layer-normalized output
        6) a second residual connection (i.e., add the unnormalized output of step 3 to the
           output of the feed-forward network)
        """

        # x has shape: (bs, seqlen, dim)

        #
        # Short version:
        #
        x_norm_attention_res = x + self.attention(
            self.attention_norm(x)
        )  # shape: (bs, seqlen, dim)
        return x_norm_attention_res + self.feed_forward(
            self.ffn_norm(x_norm_attention_res)
        )  # shape: (bs, seqlen, dim)

        #
        # Long version:
        #

        # # 1) Layer normalization of the input (via RMSNorm, self.attention_norm).
        # x_norm = self.attention_norm(x) # shape: (bs, seqlen, dim)

        # # 2) Self-attention on the layer-normalized input.
        # x_norm_attention = self.attention(x_norm) # shape: (bs, seqlen, dim)

        # # 3) Residual connection with input.
        # x_norm_attention_res = x + x_norm_attention # shape: (bs, seqlen, dim)

        # # 4) Layer normalization on the output of the self-attention.
        # x_norm_attention_res_norm = self.ffn_norm(x_norm_attention_res) # shape: (bs, seqlen, dim)

        # # 5) Feed-forward network on the layer-normalized output of the self-attention.
        # x_ffn = self.feed_forward(x_norm_attention_res_norm)

        # # 6) Residual connection from the unnormalized self-attention output
        # #    to the output of the feed-forward network.
        # return x_norm_attention_res + x_ffn # shape: (bs, seqlen, dim)


class Llama(LlamaPreTrainedModel):
    def __init__(self, config: LlamaConfig):
        """
        You will probably never need to call this function, unless you decide
        to pretrain a Llama model from scratch.
        """
        super().__init__(config)
        self.params = config
        self.vocab_size = config.vocab_size
        self.n_layers = config.n_layers

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layers):
            self.layers.append(LlamaLayer(layer_id, config))
        self.norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        # https://paperswithcode.com/method/weight-tying
        self.tok_embeddings.weight = self.output.weight

        # no RoPE precompute is stored here; the rotary relative positional embeddings
        # are applied inside each attention layer via rope.apply_rotary_emb

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
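        # (each of the n_layers blocks adds two residual contributions, attention and
        # feed-forward, so scaling the init std by 1 / sqrt(2 * n_layers) keeps the
        # variance of the residual stream roughly constant at initialization)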
        for pn, p in self.named_parameters():
            if pn.endswith("w3.weight") or pn.endswith("compute_output.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers)
                )

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        _batch_size, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)  # shape: (bs, seqlen, dim)
        h = self.dropout(h)  # shape: (bs, seqlen, dim)

        for layer in self.layers:
            h = layer(h)  # shape: (bs, seqlen, dim)
        h = self.norm(h)  # shape: (bs, seqlen, dim)

        # The transformer allows us to calculate model outputs at all timesteps in parallel,
        # if the inputs for those timesteps are available. This is why h has a seqlen dimension.
        #
        # At training time, because we use teacher-forcing, we use the true inputs at each
        # timestep, so they are all available at once. So we can pass them all at once through
        # the model. As part of teacher-forcing, we also want to calculate the loss at each
        # timestep, which is why we also pass model outputs for all timesteps through the
        # output unembedding layer.
        #
        # At inference time, we usually have a few inputs available at once,
        # i.e., the conditioning/prompt and what we already have generated. We can pass these
        # through the model at once, but we only use the model output at the last timestep
        # to generate one single next token, which is why we only pass the last model output
        # through the output unembedding layer. This also means that we have to do this generation
        # in a loop, where we append what the model just generated to the input we had and we
        # repeat the process.
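        # Illustrative shapes: with bs = 2, seqlen = 5, vocab_size = 32000, training
        # produces logits of shape (2, 5, 32000), while the inference branch below
        # produces logits of shape (2, 1, 32000).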
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)  # shape: (bs, seqlen, vocab_size)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(  # shape: (bs, 1, vocab_size)
                h[:, [-1], :]
            )  # note: indexing with the list [-1] preserves the time dim (now of size 1)

        return logits, h

    @torch.inference_mode()
    def generate(
        self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0
    ):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        We perform this generation using basic temperature sampling. Note that we are not using
        nucleus (top-p) sampling or top-k sampling (i.e. restricting sampling to a subset of the
        most probable tokens at each timestep), though these are often used in conjunction with
        temperature sampling. Most likely you'll want to be in model.eval() mode for this.
        Also note this is a super inefficient version of sampling with no key/value cache.
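
        Illustrative usage (the checkpoint name and prompt token id are assumptions,
        not part of this file):
            model = load_pretrained("stories42M.pt")
            model.eval()
            prompt = torch.tensor([[1]], dtype=torch.long)  # e.g. a BOS token id
            out = model.generate(prompt, max_new_tokens=50, temperature=0.8)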
        """
        # model.eval() and model.train() are called by the training/eval loop functions.

        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = (
                idx
                if idx.size(1) <= self.params.max_seq_len
                else idx[:, -self.params.max_seq_len :]
            )
            # forward the model to get the logits for the next position in the sequence
            logits: torch.Tensor = self(idx_cond)[0]  # shape: (bs, 1, vocab_size), since targets is None

            # Take the final time step and drop the time dimension; shape becomes (bs, vocab_size).
            logits = logits[:, -1, :]

            if temperature == 0.0:
                # Select the single most likely index.
                idx_next = logits.argmax(dim=-1, keepdim=True)  # shape: (bs, 1)
                # We need idx_next to have the same number of dimensions as idx so we can
                # concatenate them along dim=1. This is why keepdim=True.
            else:
                """
                Perform temperature sampling:
                1) identify the logits at the final step.
                2) scale (divide) these logits by the given temperature.
                3) normalize the scaled logits with a softmax to obtain scaled probabilities.
                4) sample from the scaled probability distribution.

                Note that we are not using top-k sampling/nucleus sampling in this procedure.
                """
                probabilities = F.softmax(
                    logits / temperature, dim=-1
                )  # shape: (bs, vocab_size)
                idx_next = torch.multinomial(
                    probabilities, num_samples=1
                )  # shape: (bs, 1)

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


def load_pretrained(checkpoint):
    device = (
        "cuda" if torch.cuda.is_available() else "cpu"
    )  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
    # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
    dtype = "float32"

    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = (
        "cuda" if "cuda" in device else "cpu"
    )  # for later use in torch.autocast
    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    )

    # init from a model saved in a specific directory
    checkpoint_dict = torch.load(checkpoint, map_location=device)
    config = LlamaConfig(**checkpoint_dict["model_args"])
    model = Llama(config)
    state_dict = checkpoint_dict["model"]
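    # checkpoints saved from a torch.compile()'d model prefix every parameter key with
    # "_orig_mod."; strip it so the keys match this (uncompiled) module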
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)
    return model