import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Helper function to reshape the frequency tensor so that it can be broadcast against
    the target tensor 'x' during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to the query and key tensors.

    The rotation applied to each token embedding is a function of that token's position
    in the sequence, head_dim, and theta. The input tensors are reshaped into interleaved
    (real, imaginary) pairs to simplify the implementation.

    Args:
        query (torch.Tensor): Query tensor to apply rotary embeddings.
            Shape: (batch_size, seqlen, n_local_heads, head_dim)
        key (torch.Tensor): Key tensor to apply rotary embeddings.
            Shape: (batch_size, seqlen, n_local_kv_heads, head_dim)
        head_dim (int): Dimension of each attention head.
        max_seq_len (int): Maximum sequence length supported by the model.
        theta (float): Base used for the rotary frequencies.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor
        with rotary embeddings applied.
    """
    _, seqlen, _, _ = query.shape
    device = query.device
    dtype = query.dtype
    # Please refer to slide 22 in https://phontron.com/class/anlp2024/assets/slides/anlp-05-transformers.pdf
    # and Section 3 in https://arxiv.org/abs/2104.09864.

    # Reshape query and key to match the complex representation.
    query_real, query_imag = (
        query.float().reshape(query.shape[:-1] + (-1, 2)).unbind(-1)
    )
    key_real, key_imag = key.float().reshape(key.shape[:-1] + (-1, 2)).unbind(-1)
    # This separates each query/key vector into its odd and even indices (assuming *one-indexing*):
    # query_real contains q_1, q_3, q_5, ... and query_imag contains q_2, q_4, q_6, ...

    # First, compute the trigonometric values in the second and fourth columns in
    # slide 22 (linked above). Then, combine these trigonometric values with the
    # tensors query_real, query_imag, key_real, and key_imag.

    # Compute thetas, where theta varies along head_dim: theta_i = theta ** (-2 * i / head_dim).
    assert head_dim % 2 == 0
    theta_exps = torch.arange(
        0, head_dim // 2, device=device, dtype=dtype, requires_grad=False
    ) * (-2 / head_dim)
    theta_seq_half = torch.pow(
        torch.tensor(theta, device=device, dtype=dtype, requires_grad=False), theta_exps
    )  # shape (head_dim // 2)

    # Compute the m-theta matrix, where m varies along seqlen and theta varies along head_dim.
    m_seq = torch.arange(
        0, seqlen, device=device, dtype=dtype, requires_grad=False
    )  # shape (seqlen)
    m_theta_half = torch.matmul(
        m_seq.unsqueeze(-1), theta_seq_half.unsqueeze(0)
    )  # shape (seqlen, head_dim // 2)
    # Repeat-interleave the columns so each theta appears twice.
    # Do this after the matmul to avoid unnecessary computation.
    m_theta = torch.repeat_interleave(
        m_theta_half, 2, dim=-1
    )  # shape (seqlen, head_dim)

    # Compute the sin and cos values.
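    # For each (real, imag) pair at sequence position m and frequency index i, the
    # rotation implemented below is:
    #   real' = real * cos(m * theta_i) - imag * sin(m * theta_i)
    #   imag' = imag * cos(m * theta_i) + real * sin(m * theta_i)
    # The interleave() helper defined at the bottom of the file arranges (-imag, real)
    # so this reduces to two element-wise multiplies and an add.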
    cos_m_theta = torch.cos(m_theta)  # shape (seqlen, head_dim)
    sin_m_theta = torch.sin(m_theta)  # shape (seqlen, head_dim)
    broadcast_cos_m_theta = reshape_for_broadcast(
        cos_m_theta, query
    )  # shape (1, seqlen, 1, head_dim)
    broadcast_sin_m_theta = reshape_for_broadcast(
        sin_m_theta, query
    )  # shape (1, seqlen, 1, head_dim)

    # Interleave the real and imaginary parts of the query and key tensors.
    query_interleaved = interleave(
        -1 * query_imag, query_real
    )  # shape (batch_size, seqlen, n_local_heads, head_dim)
    key_interleaved = interleave(
        -1 * key_imag, key_real
    )  # shape (batch_size, seqlen, n_local_kv_heads, head_dim)

    # Compute the rotary embeddings for the query and key tensors.
    query_out = (query * broadcast_cos_m_theta) + (
        query_interleaved * broadcast_sin_m_theta
    )
    key_out = (key * broadcast_cos_m_theta) + (key_interleaved * broadcast_sin_m_theta)
    return query_out, key_out


def interleave(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Interleave two tensors along the last dimension.

    Args:
        x (torch.Tensor): First tensor to interleave.
        y (torch.Tensor): Second tensor to interleave.

    Returns:
        torch.Tensor: Interleaved tensor whose last dimension alternates elements of x and y.
    """
    return torch.stack((x, y), dim=-1).reshape(*x.shape[:-1], -1)
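

# A minimal usage sketch, assuming query/key shapes of (batch_size, seqlen, n_heads, head_dim).
# The sizes below are illustrative only and not part of the original module.
if __name__ == "__main__":
    batch_size, seqlen, n_heads, head_dim = 2, 8, 4, 16

    query = torch.randn(batch_size, seqlen, n_heads, head_dim)
    key = torch.randn(batch_size, seqlen, n_heads, head_dim)

    q_rot, k_rot = apply_rotary_emb(query, key, head_dim=head_dim, max_seq_len=seqlen)

    # Shapes are preserved.
    assert q_rot.shape == query.shape and k_rot.shape == key.shape
    # Position 0 is rotated by angle 0, so it is left unchanged.
    assert torch.allclose(q_rot[:, 0], query[:, 0], atol=1e-6)
    # Each 2D rotation preserves the norm of its (real, imag) pair, so per-head norms match.
    assert torch.allclose(q_rot.norm(dim=-1), query.norm(dim=-1), atol=1e-4)
    print("rotary embedding sanity checks passed")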