# mini-llama/src/rope.py
import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Return a view of a 2-D frequency tensor that broadcasts against `x`.

    `freqs_cis` must have shape (x.shape[1], x.shape[-1]). The result keeps those
    two sizes at dimensions 1 and -1 of `x` and uses size 1 everywhere else. For
    example, x of shape (B, S, H, D) gives a view of shape (1, S, 1, D).

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor the result must broadcast against.

    Returns:
        torch.Tensor: A view of `freqs_cis` with the same rank as `x`.

    Raises:
        AssertionError: If `x` has fewer than two dimensions.
        AssertionError: If `freqs_cis` is not shaped (x.shape[1], x.shape[-1]).
    """
    rank = x.ndim
    # Dimension 1 must exist before x.shape[1] can be read below.
    assert rank > 1
    expected = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected

    # Start with all singleton dims, then restore the two dims that carry data.
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(target)


def apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embeddings (RoPE) to the query and key tensors.

    Each head vector is split into consecutive pairs (x[2i], x[2i+1]). Each pair is
    treated as the real and imaginary parts of a complex number. The pair at
    sequence position m is rotated by the angle m * theta ** (-2i / head_dim).
    See slide 22 in
    https://phontron.com/class/anlp2024/assets/slides/anlp-05-transformers.pdf
    and Section 3 in https://arxiv.org/abs/2104.09864.

    Angles and rotations are computed in at least float32 precision, whatever the
    input dtype, and each result is cast back to its input's dtype. This matters
    for two reasons:
    - Low-precision dtypes such as bfloat16 cannot represent large position
      indices exactly, which would corrupt the angles.
    - Mixing float32 intermediates into the output would silently change the
      returned dtype.

    Args:
        query (torch.Tensor): Query tensor to apply rotary embeddings.
            Shape: (batch_size, seqlen, n_local_heads, head_dim)
        key (torch.Tensor): Key tensor to apply rotary embeddings.
            Shape: (batch_size, seqlen, n_local_kv_heads, head_dim)
        head_dim (int): Dimension of each attention head; must be even and equal
            to query.shape[-1].
        max_seq_len (int): Maximum sequence length supported by the model.
            Currently unused; kept for interface compatibility.
        theta (float): Base of the geometric progression of rotation frequencies.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors, with the
        same shapes and dtypes as `query` and `key` respectively.

    Raises:
        AssertionError: If head_dim is odd or does not match query.shape[-1].
    """
    _, seqlen, _, _ = query.shape
    assert head_dim % 2 == 0
    assert query.shape[-1] == head_dim
    device = query.device
    half = head_dim // 2

    # Never compute angles below float32 precision (but keep float64 if given).
    compute_dtype = torch.promote_types(query.dtype, torch.float32)

    # Frequencies theta ** (-2i / head_dim) for i = 0 .. head_dim/2 - 1.
    exponents = torch.arange(0, half, device=device, dtype=compute_dtype) * (
        -2 / head_dim
    )
    freqs = torch.pow(
        torch.tensor(theta, device=device, dtype=compute_dtype), exponents
    )  # shape (head_dim // 2)

    # Angle m * theta_i for every position m and frequency i. An elementwise
    # outer product is exact and avoids matmul paths such as TF32.
    positions = torch.arange(0, seqlen, device=device, dtype=compute_dtype)
    angles = positions[:, None] * freqs[None, :]  # shape (seqlen, head_dim // 2)

    # Shape (1, seqlen, 1, head_dim // 2) so they broadcast over batch and heads.
    cos = torch.cos(angles).view(1, seqlen, 1, half)
    sin = torch.sin(angles).view(1, seqlen, 1, half)

    query_out = _rotate_pairs(query, cos, sin)
    key_out = _rotate_pairs(key, cos, sin)
    return query_out, key_out


def _rotate_pairs(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate each pair (x[..., 2i], x[..., 2i+1]) by the angle given via cos/sin.

    The math runs in cos.dtype, and the result is returned in x.dtype with
    x's shape.
    """
    x_c = x.to(cos.dtype)
    # even holds x[..., 0::2] (real parts); odd holds x[..., 1::2] (imaginary parts).
    even, odd = x_c.reshape(x_c.shape[:-1] + (-1, 2)).unbind(-1)
    # Complex multiply (even + i*odd) * (cos + i*sin).
    rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
    return rotated.flatten(-2).to(x.dtype)


def interleave(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Merge two same-shaped tensors so their last-axis elements alternate.

    For last-axis values x = [x0, x1, ...] and y = [y0, y1, ...], the result's
    last axis is [x0, y0, x1, y1, ...]. All leading dimensions are unchanged.

    Args:
        x (torch.Tensor): Supplies the elements at even positions of the last axis.
        y (torch.Tensor): Supplies the elements at odd positions; must match x's shape.

    Returns:
        torch.Tensor: Tensor whose last dimension is twice as long as x's.
    """
    # Pair matching elements along a new trailing axis: (..., n) -> (..., n, 2).
    pairs = torch.stack([x, y], dim=-1)
    # Fold the pair axis into the last axis so x and y values alternate.
    leading = tuple(x.shape[:-1])
    return pairs.reshape(leading + (-1,))