import torch
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""
Helper function to reshape frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep the sequence dimension (dim 1) and the last dimension; set every other
    # dimension to 1 so that the result broadcasts against x.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)
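

# Note on shapes (illustrative, using the shapes assumed by apply_rotary_emb below):
# for x of shape (batch_size, seqlen, n_heads, head_dim) and freqs_cis of shape
# (seqlen, head_dim), reshape_for_broadcast returns a view of shape
# (1, seqlen, 1, head_dim), which broadcasts against x in element-wise products.
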
def apply_rotary_emb(
query: torch.Tensor,
key: torch.Tensor,
head_dim: int,
max_seq_len: int,
theta: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query and key tensors. The rotation to each token
embedding is a function of that token's position in the sequence, head_dim, and theta.
The input tensors are reshaped as complex numbers to simplify your implementation.
Args:
query (torch.Tensor): Query tensor to apply rotary embeddings.
Shape: (batch_size, seqlen, n_local_heads, self.head_dim)
key (torch.Tensor): Key tensor to apply rotary embeddings.
Shape: (batch_size, seqlen, n_local_kv_heads, self.head_dim)
head_dim (int): Dimension of each attention head.
max_seq_len (int): Maximum sequence length supported by model.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
_, seqlen, _, _ = query.shape
device = query.device
dtype = query.dtype
# Please refer to slide 22 in https://phontron.com/class/anlp2024/assets/slides/anlp-05-transformers.pdf
# and Section 3 in https://arxiv.org/abs/2104.09864.
    # Split query and key into the real and imaginary components of their
    # complex representation.
query_real, query_imag = (
query.float().reshape(query.shape[:-1] + (-1, 2)).unbind(-1)
)
key_real, key_imag = key.float().reshape(key.shape[:-1] + (-1, 2)).unbind(-1)
    # This separates each query/key vector into its odd- and even-indexed entries
    # (assuming *one-indexing*): query_real contains q_1, q_3, q_5, ... and
    # query_imag contains q_2, q_4, q_6, ...
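    # For example, with head_dim = 4 a single head vector [q_1, q_2, q_3, q_4]
    # yields query_real = [q_1, q_3] and query_imag = [q_2, q_4], so each pair
    # (q_{2i-1}, q_{2i}) is treated as the complex number q_{2i-1} + i*q_{2i}.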
    # Below, the trigonometric values in the second and fourth columns of slide 22
    # (linked above) are computed and then combined with query_real, query_imag,
    # key_real, and key_imag.
# Compute thetas, where theta varies along head_dim.
assert head_dim % 2 == 0
theta_exps = torch.arange(
0, head_dim // 2, device=device, dtype=dtype, requires_grad=False
) * (-2 / head_dim)
theta_seq_half = torch.pow(
torch.tensor(theta, device=device, dtype=dtype, requires_grad=False), theta_exps
) # shape (head_dim // 2)
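    # Concretely, theta_seq_half[i] = theta ** (-2 * i / head_dim) for
    # i = 0, ..., head_dim // 2 - 1, the 0-indexed form of the theta_i
    # definition in the RoPE paper referenced above.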
# Compute the m-theta matrix, where m varies along seqlen and theta varies along head_dim.
m_seq = torch.arange(
0, seqlen, device=device, dtype=dtype, requires_grad=False
) # shape (seqlen)
m_theta_half = torch.matmul(
m_seq.unsqueeze(-1), theta_seq_half.unsqueeze(0)
) # shape (seqlen, head_dim // 2)
# Repeat interleave the columns so each theta appears twice.
# Do this after matmul to avoid unnecessary computation.
m_theta = torch.repeat_interleave(
m_theta_half, 2, dim=-1
) # shape (seqlen, head_dim)
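    # After repeating, m_theta[m, 2 * i] = m_theta[m, 2 * i + 1] = m * theta ** (-2 * i / head_dim),
    # i.e. both entries of the i-th pair share the same rotation angle at position m.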
# Compute the sin and cos values.
cos_m_theta = torch.cos(m_theta) # shape (seqlen, head_dim)
sin_m_theta = torch.sin(m_theta) # shape (seqlen, head_dim)
    broadcast_cos_m_theta = reshape_for_broadcast(
        cos_m_theta, query
    )  # shape (1, seqlen, 1, head_dim)
    broadcast_sin_m_theta = reshape_for_broadcast(
        sin_m_theta, query
    )  # shape (1, seqlen, 1, head_dim)
    # Interleave (-imag, real) so each pair becomes (-q_2, q_1), (-q_4, q_3), ...;
    # this is each 2D pair rotated by 90 degrees, to be scaled by sin below.
    query_interleaved = interleave(
        -1 * query_imag, query_real
    )  # shape (batch_size, seqlen, n_local_heads, head_dim)
    key_interleaved = interleave(
        -1 * key_imag, key_real
    )  # shape (batch_size, seqlen, n_local_kv_heads, head_dim)
# Compute the rotary embeddings for the query and key tensors.
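    # Per pair (0-indexed entries 2*i and 2*i + 1) this computes the 2D rotation
    #   out[2*i]     = q[2*i]     * cos(m * theta_i) - q[2*i + 1] * sin(m * theta_i)
    #   out[2*i + 1] = q[2*i + 1] * cos(m * theta_i) + q[2*i]     * sin(m * theta_i)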
query_out = (query * broadcast_cos_m_theta) + (
query_interleaved * broadcast_sin_m_theta
)
key_out = (key * broadcast_cos_m_theta) + (key_interleaved * broadcast_sin_m_theta)
return query_out, key_out
def interleave(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Interleave two tensors of identical shape along the last dimension, producing
    [x_0, y_0, x_1, y_1, ...]; the last dimension doubles in size.

    Args:
        x (torch.Tensor): First tensor to interleave.
        y (torch.Tensor): Second tensor to interleave.

    Returns:
        torch.Tensor: Interleaved tensor of shape (*x.shape[:-1], 2 * x.shape[-1]).
    """
    return torch.stack((x, y), dim=-1).reshape(*x.shape[:-1], -1)
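

# Illustrative smoke test: the batch size, sequence length, and head counts below
# are arbitrary example values chosen only to exercise apply_rotary_emb and check
# output shapes.
if __name__ == "__main__":
    batch_size, seqlen, n_heads, n_kv_heads, head_dim = 2, 16, 4, 2, 8
    query = torch.randn(batch_size, seqlen, n_heads, head_dim)
    key = torch.randn(batch_size, seqlen, n_kv_heads, head_dim)
    q_rot, k_rot = apply_rotary_emb(query, key, head_dim=head_dim, max_seq_len=seqlen)
    # Shapes are preserved.
    assert q_rot.shape == query.shape and k_rot.shape == key.shape
    # Position 0 is rotated by angle 0, so its embeddings are unchanged.
    assert torch.allclose(q_rot[:, 0], query[:, 0], atol=1e-5)
    print("apply_rotary_emb smoke test passed:", tuple(q_rot.shape), tuple(k_rot.shape))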