Rotary Position Embeddings (RoPE)
Rotary Position Embeddings (RoPE) are an elegant way to encode positional information in transformers. Unlike absolute position embeddings, RoPE encodes relative positions through rotation, enabling better length generalization and improved performance.
The Position Encoding Problem
Why We Need Position Information
Transformers process all tokens in parallel, so they need explicit position information:
import torch
import numpy as np
# Without position info, these are identical to the model:
sequence1 = ["the", "cat", "sat"]
sequence2 = ["sat", "the", "cat"]
# Embeddings (without position)
vocab = {"the": 0, "cat": 1, "sat": 2}
embeddings = torch.randn(3, 512) # vocab_size=3, dim=512
# Same bag of embeddings!
emb1 = torch.stack([embeddings[vocab[w]] for w in sequence1])
emb2 = torch.stack([embeddings[vocab[w]] for w in sequence2])
# The sum is the same (order doesn't matter)
print(f"Sum of sequence 1: {emb1.sum(dim=0)[:5]}")
print(f"Sum of sequence 2: {emb2.sum(dim=0)[:5]}")
print(f"Are sums equal? {torch.allclose(emb1.sum(dim=0), emb2.sum(dim=0))}")
Traditional Approaches
1. Absolute Position Embeddings (Original Transformer):
def sinusoidal_position_encoding(max_len, d_model):
"""Original transformer position encoding."""
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
pos_enc = torch.zeros(max_len, d_model)
pos_enc[:, 0::2] = torch.sin(position * div_term)
pos_enc[:, 1::2] = torch.cos(position * div_term)
return pos_enc
# Add to embeddings
pos_enc = sinusoidal_position_encoding(max_len=100, d_model=512)
embeddings_with_pos = emb1 + pos_enc[:len(sequence1)]
Problem: The position signal is added to the token embedding as an absolute offset, so attention scores do not directly encode relative distances.
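To make this concrete, here is a minimal sketch (not from the original text; the names q_tok, k_tok, and additive_score are illustrative) showing that with additive sinusoidal encodings the raw query-key dot product changes when a token pair is shifted, even though the offset between the two tokens stays the same:
torch.manual_seed(0)
q_tok, k_tok = torch.randn(512), torch.randn(512)
pe = sinusoidal_position_encoding(max_len=100, d_model=512)

def additive_score(m, n):
    # Dot product of (token + position) vectors, as plain additive PE would produce
    return torch.dot(q_tok + pe[m], k_tok + pe[n])

print(additive_score(5, 3))    # positions (5, 3), offset 2
print(additive_score(55, 53))  # same offset 2, different absolute positions -> different score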
2. Learned Position Embeddings (BERT, GPT):
class LearnedPositionEmbedding(torch.nn.Module):
def __init__(self, max_len, d_model):
super().__init__()
self.pos_embedding = torch.nn.Embedding(max_len, d_model)
def forward(self, x):
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device)
return x + self.pos_embedding(positions)
Problem: Doesn't generalize to sequences longer than max_len.
Length Generalization Challenge:
Models with absolute position embeddings struggle with sequences longer than training length. A model trained on 2048 tokens performs poorly on 4096 tokens, even though the content is similar.
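A quick sketch of the hard failure mode with the learned table above (numbers chosen for illustration): positions beyond max_len simply have no embedding to look up.
pos_emb = LearnedPositionEmbedding(max_len=2048, d_model=512)
x_short = torch.randn(1, 2048, 512)
x_long = torch.randn(1, 4096, 512)

print(pos_emb(x_short).shape)   # torch.Size([1, 2048, 512]) -- within the table
try:
    pos_emb(x_long)             # positions 2048..4095 are out of range
except IndexError as err:
    print(f"Fails beyond max_len: {err}")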
RoPE: The Rotary Solution
Core Idea
Instead of adding position information, rotate the query and key vectors by an angle proportional to their position.
Key insight: The dot product between rotated vectors naturally encodes relative position!
q_m = R(m) q                     (rotate the query at position m)
k_n = R(n) k                     (rotate the key at position n)
q_m · k_n = (R(m) q) · (R(n) k)
          = q^T R(m)^T R(n) k
          = q^T R(n - m) k       (depends only on the relative position n - m!)
Geometric Intuition:
Imagine vectors on a unit circle. Rotating q clockwise by angle θ_m and k clockwise by angle θ_n makes their dot product depend on the relative angle (θ_m - θ_n), which encodes relative position!
Mathematical Foundation
2D Rotation Matrix
In 2D, rotating a vector by angle θ:
R(θ) = [cos(θ) -sin(θ)]
[sin(θ) cos(θ)]
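A small numeric check of the relative-position claim (a sketch, assuming nothing beyond the 2D rotation formula above; rot2d and rotated_score are names introduced here for illustration): rotate the same two vectors by angles m·θ and n·θ and confirm the dot product depends only on m - n.
import math
import torch

def rot2d(theta):
    """2D rotation matrix R(theta)."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]])

theta = 0.1
q2d = torch.tensor([1.0, 2.0])
k2d = torch.tensor([0.5, -1.0])

def rotated_score(m, n):
    return torch.dot(rot2d(m * theta) @ q2d, rot2d(n * theta) @ k2d)

# Same relative offset (n - m = -2) at very different absolute positions -> same score
print(rotated_score(5, 3), rotated_score(9, 7), rotated_score(105, 103))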
RoPE for Higher Dimensions
For d-dimensional vectors, apply 2D rotations to pairs of dimensions:
import torch
import math
def rope_rotation_matrix(seq_len, dim, base=10000):
"""
Generate RoPE rotation matrices.
Args:
seq_len: Sequence length
dim: Embedding dimension (must be even)
base: Base for frequency (10000 in paper)
Returns:
cos, sin: (seq_len, dim) - cosine and sine values
"""
# Frequency for each dimension pair
# θ_i = base^(-2i/d) for i = 0, 1, ..., d/2-1
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# Shape: (dim // 2,)
# Position indices
positions = torch.arange(seq_len).float()
# Shape: (seq_len,)
# Compute angles: outer product of positions and frequencies
freqs = torch.einsum('i,j->ij', positions, inv_freq)
# Shape: (seq_len, dim // 2)
# Duplicate for pairs
emb = torch.cat([freqs, freqs], dim=-1)
# Shape: (seq_len, dim)
# Compute cos and sin
cos = emb.cos()
sin = emb.sin()
return cos, sin
# Visualize RoPE frequencies
seq_len = 100
dim = 64
cos, sin = rope_rotation_matrix(seq_len, dim)
import matplotlib.pyplot as plt
plt.figure(figsize=(14, 5))
plt.subplot(1, 2, 1)
plt.imshow(cos.numpy(), aspect='auto', cmap='RdBu')
plt.colorbar()
plt.title('RoPE Cosine Values')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.subplot(1, 2, 2)
plt.imshow(sin.numpy(), aspect='auto', cmap='RdBu')
plt.colorbar()
plt.title('RoPE Sine Values')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.tight_layout()
plt.show()
# Recompute the per-pair frequencies here (inv_freq above is local to the function)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
print(f"Frequency range: {inv_freq[0]:.6f} to {inv_freq[-1]:.6f}")
Applying RoPE
def apply_rope(x, cos, sin):
"""
Apply rotary embeddings to input.
Args:
x: Input tensor (batch, seq_len, dim) or (batch, heads, seq_len, dim)
cos, sin: Rotation values (seq_len, dim)
Returns:
Rotated tensor (same shape as x)
"""
    # Handle both (batch, seq_len, dim) and (batch, heads, seq_len, dim)
    if x.dim() == 4:
        # Broadcast over batch and heads: (1, 1, seq_len, dim)
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
    else:
        # Broadcast over batch: (1, seq_len, dim)
        cos = cos[None, :, :]
        sin = sin[None, :, :]
# Split into pairs and rotate
x1, x2 = x.chunk(2, dim=-1)
# Apply rotation
# x_rotated = x * cos + rotate_half(x) * sin
rotated_x = torch.cat([
x1 * cos[..., :x1.shape[-1]] - x2 * sin[..., x2.shape[-1]:],
x2 * cos[..., x2.shape[-1]:] + x1 * sin[..., :x1.shape[-1]]
], dim=-1)
return rotated_x
# Alternative: More efficient implementation
def rotate_half(x):
"""Rotate half the dimensions of x."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
def apply_rope_efficient(x, cos, sin):
"""Efficient RoPE application."""
    # Truncate the cached tables to the current sequence length
    # (seq_len is the second-to-last axis for both supported layouts)
    seq_len = x.shape[-2]
    cos = cos[:seq_len, :]
    sin = sin[:seq_len, :]
    # Match dimensions for broadcasting
    if x.dim() == 4:
        # (batch, heads, seq_len, dim) input -> (1, 1, seq_len, dim) tables
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
    else:
        # (batch, seq_len, dim) input -> (1, seq_len, dim) tables
        cos = cos[None, :, :]
        sin = sin[None, :, :]
return (x * cos) + (rotate_half(x) * sin)
# Test RoPE
batch, seq_len, dim = 2, 10, 64
x = torch.randn(batch, seq_len, dim)
cos, sin = rope_rotation_matrix(seq_len, dim)
x_rotated = apply_rope_efficient(x, cos, sin)
print(f"Input shape: {x.shape}")
print(f"Output shape: {x_rotated.shape}")
# Verify rotation preserves norm
norm_before = torch.norm(x, dim=-1).mean()
norm_after = torch.norm(x_rotated, dim=-1).mean()
print(f"\nNorm before: {norm_before:.4f}")
print(f"Norm after: {norm_after:.4f}")
print(f"Norm preserved: {torch.allclose(norm_before, norm_after, atol=1e-5)}")
Why Rotation Preserves Norms:
Rotation matrices are orthogonal: R^T R = I, so rotating a vector changes its direction but not its length. This matters because RoPE should inject position into the attention scores without rescaling the queries and keys; the scale of the logits stays the same no matter where a token sits in the sequence.
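As a sanity check, here is a sketch that builds the block-diagonal rotation matrix for one position and verifies R^T R = I. It uses the consecutive-pair convention for readability; the rotate_half implementation above pairs dimension i with i + d/2 instead, but both layouts are orthogonal. The names dim_demo and pos_demo are introduced here for illustration.
dim_demo, pos_demo, base = 8, 7, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim_demo, 2).float() / dim_demo))
angles = pos_demo * inv_freq                  # one rotation angle per 2D pair

R = torch.zeros(dim_demo, dim_demo)
for i, a in enumerate(angles):
    c, s = torch.cos(a), torch.sin(a)
    R[2 * i, 2 * i], R[2 * i, 2 * i + 1] = c, -s
    R[2 * i + 1, 2 * i], R[2 * i + 1, 2 * i + 1] = s, c

print(torch.allclose(R.T @ R, torch.eye(dim_demo), atol=1e-6))  # True: orthogonal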
Complete RoPE Implementation
class RotaryPositionEmbedding(torch.nn.Module):
"""
Rotary Position Embeddings.
    Used in LLaMA, GPT-NeoX, GPT-J, PaLM, and many modern LLMs.
"""
def __init__(self, dim, max_seq_len=2048, base=10000):
"""
Args:
dim: Dimension per attention head
max_seq_len: Maximum sequence length to precompute
base: Base for frequency computation
"""
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# Precompute frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# Precompute cos and sin for all positions
self._set_cos_sin_cache(max_seq_len)
def _set_cos_sin_cache(self, seq_len):
"""Precompute and cache cos/sin values."""
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos_cached', emb.cos(), persistent=False)
self.register_buffer('sin_cached', emb.sin(), persistent=False)
def forward(self, q, k, seq_len=None):
"""
Apply RoPE to query and key tensors.
Args:
q: Query (batch, heads, seq_len, dim)
k: Key (batch, heads, seq_len, dim)
seq_len: Sequence length (optional, inferred from q if None)
Returns:
q_rotated, k_rotated: Rotated query and key
"""
if seq_len is None:
seq_len = q.shape[-2]
# Extend cache if needed
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
# Get cached values
cos = self.cos_cached[:seq_len, :]
sin = self.sin_cached[:seq_len, :]
# Add dimensions for batch and heads
cos = cos[None, None, :, :] # (1, 1, seq_len, dim)
sin = sin[None, None, :, :]
# Apply rotation
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
q_rotated = (q * cos) + (rotate_half(q) * sin)
k_rotated = (k * cos) + (rotate_half(k) * sin)
return q_rotated, k_rotated
# Example usage
batch, heads, seq_len, head_dim = 2, 8, 16, 64
Q = torch.randn(batch, heads, seq_len, head_dim)
K = torch.randn(batch, heads, seq_len, head_dim)
rope = RotaryPositionEmbedding(dim=head_dim)
Q_rot, K_rot = rope(Q, K)
print(f"Q shape: {Q.shape}")
print(f"Q_rot shape: {Q_rot.shape}")
# Test attention with and without RoPE
def attention(q, k, v):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
attn = torch.softmax(scores, dim=-1)
return torch.matmul(attn, v)
V = torch.randn(batch, heads, seq_len, head_dim)
# Without RoPE
output_no_rope = attention(Q, K, V)
# With RoPE
output_with_rope = attention(Q_rot, K_rot, V)
print(f"\nOutput (no RoPE): {output_no_rope.shape}")
print(f"Output (with RoPE): {output_with_rope.shape}")
Why RoPE Works
Relative Position Encoding
The magic of RoPE: when you compute attention, relative position naturally emerges:
def demonstrate_relative_position():
"""
Show how RoPE encodes relative position in attention.
"""
dim = 64
base = 10000
# Single head, simple case
q = torch.randn(1, dim)
k = torch.randn(1, dim)
# Positions
pos_q = 5
pos_k = 3
# Compute rotation angles
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# Angles for each position
angle_q = pos_q * inv_freq
angle_k = pos_k * inv_freq
# Relative angle
angle_relative = (pos_q - pos_k) * inv_freq
print("Rotation angles:")
print(f" Position {pos_q}: {angle_q[:3]}")
print(f" Position {pos_k}: {angle_k[:3]}")
print(f" Relative (q-k): {angle_relative[:3]}")
print(f" Difference matches: {torch.allclose(angle_q - angle_k, angle_relative)}")
# The key insight: q_rot · k_rot depends only on (pos_q - pos_k)
# Not on absolute positions!
demonstrate_relative_position()
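The check above compares angles; we can go one step further and verify the attention score itself is shift-invariant, reusing rope_rotation_matrix and rotate_half from earlier. This is a sketch added here (the helpers rope_at and demonstrate_shift_invariance are ours, not part of the original walkthrough):
def demonstrate_shift_invariance():
    """Score between positions (m, n) equals the score at (m + s, n + s)."""
    torch.manual_seed(0)
    dim = 64
    q = torch.randn(dim)
    k = torch.randn(dim)
    cos, sin = rope_rotation_matrix(seq_len=200, dim=dim)

    def rope_at(v, pos):
        # Rotate a single vector as if it sat at position `pos`
        return v * cos[pos] + rotate_half(v) * sin[pos]

    def score(m, n):
        return torch.dot(rope_at(q, m), rope_at(k, n))

    s1 = score(10, 7)      # relative offset -3
    s2 = score(110, 107)   # same offset, both positions shifted by 100
    print(f"score(10, 7)    = {s1:.6f}")
    print(f"score(110, 107) = {s2:.6f}")
    # Equal up to float32 rounding
    print(f"Shift-invariant: {torch.allclose(s1, s2, atol=1e-3)}")

demonstrate_shift_invariance()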
Length Extrapolation
RoPE generalizes better to unseen lengths:
def test_length_extrapolation():
"""
Test RoPE on sequences longer than training length.
"""
# Train on seq_len=512
rope_train = RotaryPositionEmbedding(dim=64, max_seq_len=512)
# Test on seq_len=1024 (2x longer)
batch, heads, seq_len_test, dim = 1, 1, 1024, 64
Q = torch.randn(batch, heads, seq_len_test, dim)
K = torch.randn(batch, heads, seq_len_test, dim)
# RoPE automatically extends
Q_rot, K_rot = rope_train(Q, K, seq_len=seq_len_test)
print(f"Trained on seq_len: 512")
print(f"Testing on seq_len: {seq_len_test}")
print(f"Output shape: {Q_rot.shape}")
print("RoPE successfully extrapolated to 2x length!")
test_length_extrapolation()
Length Generalization:
Because RoPE is defined by continuous trigonometric functions of position, the same rotation formula applies at any position, so nothing breaks mechanically beyond the training length. In practice, quality still degrades on sequences much longer than those seen in training, because the model has never attended over the larger relative distances. This is why context-extension tricks such as frequency (base) scaling and position interpolation are used, for example to stretch a model trained on 2048 tokens to 4096+ tokens.
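One popular fix, position interpolation, squeezes the new position range back into the range seen during training before the angles are computed. A minimal sketch (the function below is illustrative, not taken from any specific library):
def interpolated_rope_tables(seq_len, dim, train_len=2048, base=10000):
    """cos/sin tables with positions rescaled into the trained range."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    scale = min(1.0, train_len / seq_len)     # < 1 when extending the context
    positions = torch.arange(seq_len).float() * scale
    freqs = torch.einsum('i,j->ij', positions, inv_freq)
    emb = torch.cat([freqs, freqs], dim=-1)
    return emb.cos(), emb.sin()

# Positions 0..4095 now map into the angle range a 2048-token model has seen
cos_i, sin_i = interpolated_rope_tables(seq_len=4096, dim=64)
print(cos_i.shape)  # torch.Size([4096, 64])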
RoPE Variants
1. Partial RoPE
Apply RoPE to only part of the dimensions:
class PartialRoPE(torch.nn.Module):
"""Apply RoPE to only a fraction of dimensions."""
def __init__(self, dim, rope_fraction=0.5, max_seq_len=2048):
super().__init__()
rope_dim = int(dim * rope_fraction)
self.rope_dim = rope_dim
self.rope = RotaryPositionEmbedding(rope_dim, max_seq_len)
def forward(self, q, k):
# Split into RoPE and non-RoPE parts
q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
# Apply RoPE to first part
q_rope, k_rope = self.rope(q_rope, k_rope)
# Concatenate
q_out = torch.cat([q_rope, q_pass], dim=-1)
k_out = torch.cat([k_rope, k_pass], dim=-1)
return q_out, k_out
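A quick usage check for the class above (shapes only, assuming the (batch, heads, seq_len, head_dim) layout used elsewhere in this post):
partial_rope = PartialRoPE(dim=64, rope_fraction=0.5)
q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
q_out, k_out = partial_rope(q, k)
print(q_out.shape)  # torch.Size([2, 8, 16, 64]); only the first 32 dims were rotated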
2. Dynamic RoPE Scaling
Scale base frequency for longer contexts:
def scaled_rope(dim, max_seq_len, scale_factor=1.0, base=10000):
"""
RoPE with scaled frequencies for longer contexts.
Args:
scale_factor: Multiply base by this (>1 for longer contexts)
"""
scaled_base = base * scale_factor
return RotaryPositionEmbedding(dim, max_seq_len, base=scaled_base)
# Example: 2x longer context with scaled RoPE
rope_standard = RotaryPositionEmbedding(64, max_seq_len=2048)
rope_scaled = scaled_rope(64, max_seq_len=4096, scale_factor=2.0)
print("Standard RoPE: trained on 2048, struggles at 4096")
print("Scaled RoPE: uses scale_factor=2.0 for 4096 tokens")
Models Using RoPE
models_with_rope = {
'LLaMA': 'All sizes (7B-65B)',
'LLaMA 2': 'All sizes (7B-70B)',
    'GPT-NeoX': '20B',
'GPT-J': '6B',
'PaLM': 'All sizes (8B-540B)',
'CodeGen': 'All sizes',
'Mistral': '7B',
'Mixtral': '8x7B',
}
print("Models using Rotary Position Embeddings:\n")
for model, sizes in models_with_rope.items():
print(f" {model:15s}: {sizes}")
Summary
Rotary Position Embeddings encode position through rotation:
Key Advantages:
- Relative position: Naturally encodes relative distances in attention
- Length extrapolation: Degrades more gracefully on longer sequences than learned absolute embeddings, and stretches further with frequency scaling or position interpolation
- Parameter-free: No learned parameters, purely algorithmic
- Norm-preserving: Rotation doesn't change vector magnitudes
- Efficient: Can be precomputed and cached
How it Works:
- Rotate query and key by angle proportional to position
- Attention score depends on relative angle (relative position)
- Uses different frequencies for different dimension pairs
When to Use:
- Default choice for modern decoder-only models
- Especially good for models that need length generalization
- Works best with causal (autoregressive) attention
RoPE has become the standard position encoding for modern LLMs, replacing both sinusoidal and learned absolute position embeddings in most new architectures.