
Rotary Position Embeddings (RoPE)

Master rotary position embeddings: the elegant position encoding method that enables length generalization and powers modern LLMs like LLaMA, GPT-NeoX, and PaLM.

20 min read · RoPE · Position Encoding · Embeddings · LLaMA


Rotary Position Embeddings (RoPE) is an elegant method for encoding positional information in transformers. Unlike absolute position embeddings, RoPE encodes relative positions through rotation, enabling better length generalization and improved performance.

The Position Encoding Problem

Why We Need Position Information

Transformers process all tokens in parallel, so they need explicit position information:

python
import torch
import numpy as np

# Without position info, these are identical to the model:
sequence1 = ["the", "cat", "sat"]
sequence2 = ["sat", "the", "cat"]

# Embeddings (without position)
vocab = {"the": 0, "cat": 1, "sat": 2}
embeddings = torch.randn(3, 512)  # vocab_size=3, dim=512

# Same bag of embeddings!
emb1 = torch.stack([embeddings[vocab[w]] for w in sequence1])
emb2 = torch.stack([embeddings[vocab[w]] for w in sequence2])

# The sum is the same (order doesn't matter)
print(f"Sum of sequence 1: {emb1.sum(dim=0)[:5]}")
print(f"Sum of sequence 2: {emb2.sum(dim=0)[:5]}")
print(f"Are sums equal? {torch.allclose(emb1.sum(dim=0), emb2.sum(dim=0))}")

Traditional Approaches

1. Absolute Position Embeddings (Original Transformer):

python
def sinusoidal_position_encoding(max_len, d_model):
    """Original transformer position encoding."""
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))

    pos_enc = torch.zeros(max_len, d_model)
    pos_enc[:, 0::2] = torch.sin(position * div_term)
    pos_enc[:, 1::2] = torch.cos(position * div_term)

    return pos_enc

# Add to embeddings
pos_enc = sinusoidal_position_encoding(max_len=100, d_model=512)
embeddings_with_pos = emb1 + pos_enc[:len(sequence1)]

Problem: Position is added to the token embeddings up front, so attention scores mix content and position terms; relative distance is never exposed to attention directly.

2. Learned Position Embeddings (BERT, GPT):

python
class LearnedPositionEmbedding(torch.nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_embedding = torch.nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_embedding(positions)

Problem: Doesn't generalize to sequences longer than max_len seen during training.

Length Generalization Challenge:

Models with absolute position embeddings struggle with sequences longer than training length. A model trained on 2048 tokens performs poorly on 4096 tokens, even though the content is similar.
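
The failure is structural, not just statistical: a learned position table simply has no rows past max_len. A minimal sketch of this hard limit, reusing the LearnedPositionEmbedding class above:

python
# Positions beyond max_len simply don't exist in the embedding table
pos_emb = LearnedPositionEmbedding(max_len=2048, d_model=512)

x_short = torch.randn(1, 2048, 512)
print(pos_emb(x_short).shape)  # torch.Size([1, 2048, 512]) - works

x_long = torch.randn(1, 4096, 512)
try:
    pos_emb(x_long)
except IndexError as e:
    print(f"Fails at 4096 tokens: {e}")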

RoPE: The Rotary Solution

Core Idea

Instead of adding position information, rotate the query and key vectors by an angle proportional to their position.

Key insight: The dot product between rotated vectors naturally encodes relative position!

q_m = R(m) q    (rotate query at position m)
k_n = R(n) k    (rotate key at position n)

q_m · k_n = (R(m) q) · (R(n) k)
          = q^T R(m)^T R(n) k
          = q^T R(n-m) k    (depends only on the relative position!)

Geometric Intuition:

Imagine vectors in a 2D plane. Rotating q by angle θ_m and k by angle θ_n (both in the same direction) makes their dot product depend only on the relative angle θ_m - θ_n, which encodes relative position!

Mathematical Foundation

2D Rotation Matrix

In 2D, rotating a vector by angle θ:

R(θ) = [cos(θ)  -sin(θ)]
       [sin(θ)   cos(θ)]
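
As a quick sanity check of the key insight (a toy calculation, not part of any RoPE implementation), rotating two random 2D vectors by position-proportional angles gives a dot product that depends only on the position difference:

python
import math
import torch

def rot2d(theta):
    """2D rotation matrix R(theta)."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]])

torch.manual_seed(0)
q = torch.randn(2)
k = torch.randn(2)
theta = 0.3  # rotation per position step

# Two (m, n) pairs with the same offset m - n = 4
score_a = (rot2d(9 * theta) @ q) @ (rot2d(5 * theta) @ k)
score_b = (rot2d(54 * theta) @ q) @ (rot2d(50 * theta) @ k)

print(f"Score at (9, 5):   {score_a.item():.6f}")
print(f"Score at (54, 50): {score_b.item():.6f}")
print(f"Equal: {torch.allclose(score_a, score_b, atol=1e-5)}")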

RoPE for Higher Dimensions

For d-dimensional vectors, apply 2D rotations to pairs of dimensions:

python
import torch
import math

def rope_rotation_matrix(seq_len, dim, base=10000):
    """
    Generate RoPE rotation matrices.

    Args:
        seq_len: Sequence length
        dim: Embedding dimension (must be even)
        base: Base for frequency (10000 in paper)

    Returns:
        cos, sin: (seq_len, dim) - cosine and sine values
    """
    # Frequency for each dimension pair
    # θ_i = base^(-2i/d) for i = 0, 1, ..., d/2-1
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Shape: (dim // 2,)

    # Position indices
    positions = torch.arange(seq_len).float()
    # Shape: (seq_len,)

    # Compute angles: outer product of positions and frequencies
    freqs = torch.einsum('i,j->ij', positions, inv_freq)
    # Shape: (seq_len, dim // 2)

    # Duplicate for pairs
    emb = torch.cat([freqs, freqs], dim=-1)
    # Shape: (seq_len, dim)

    # Compute cos and sin
    cos = emb.cos()
    sin = emb.sin()

    return cos, sin


# Visualize RoPE frequencies
seq_len = 100
dim = 64

cos, sin = rope_rotation_matrix(seq_len, dim)

import matplotlib.pyplot as plt

plt.figure(figsize=(14, 5))

plt.subplot(1, 2, 1)
plt.imshow(cos.numpy(), aspect='auto', cmap='RdBu')
plt.colorbar()
plt.title('RoPE Cosine Values')
plt.xlabel('Dimension')
plt.ylabel('Position')

plt.subplot(1, 2, 2)
plt.imshow(sin.numpy(), aspect='auto', cmap='RdBu')
plt.colorbar()
plt.title('RoPE Sine Values')
plt.xlabel('Dimension')
plt.ylabel('Position')

plt.tight_layout()
plt.show()

print(f"Frequency range: {inv_freq[0]:.6f} to {inv_freq[-1]:.6f}")

Applying RoPE

python
def apply_rope(x, cos, sin):
    """
    Apply rotary embeddings to input.

    Args:
        x: Input tensor (batch, seq_len, dim) or (batch, heads, seq_len, dim)
        cos, sin: Rotation values (seq_len, dim)

    Returns:
        Rotated tensor (same shape as x)
    """
    # Handle both (batch, seq_len, dim) and (batch, heads, seq_len, dim)
    if x.dim() == 4:
        # Add batch and head dimensions so cos/sin broadcast over them
        cos = cos[None, None, :, :]  # (1, 1, seq_len, dim)
        sin = sin[None, None, :, :]
    else:
        cos = cos[None, :, :]  # (1, seq_len, dim)
        sin = sin[None, :, :]

    # Split into pairs and rotate
    x1, x2 = x.chunk(2, dim=-1)

    # Apply rotation
    # x_rotated = x * cos + rotate_half(x) * sin
    rotated_x = torch.cat([
        x1 * cos[..., :x1.shape[-1]] - x2 * sin[..., x2.shape[-1]:],
        x2 * cos[..., x2.shape[-1]:] + x1 * sin[..., :x1.shape[-1]]
    ], dim=-1)

    return rotated_x


# Alternative: More efficient implementation
def rotate_half(x):
    """Rotate half the dimensions of x."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope_efficient(x, cos, sin):
    """Efficient RoPE application."""
    # Match dimensions (cos/sin broadcast over batch and heads)
    if x.dim() == 4:
        cos = cos[None, None, :, :]  # (1, 1, seq_len, dim)
        sin = sin[None, None, :, :]
    else:
        cos = cos[None, :, :]
        sin = sin[None, :, :]

    # Truncate to sequence length
    seq_len = x.shape[-2] if x.dim() == 4 else x.shape[1]
    cos = cos[..., :seq_len, :]
    sin = sin[..., :seq_len, :]

    return (x * cos) + (rotate_half(x) * sin)


# Test RoPE
batch, seq_len, dim = 2, 10, 64
x = torch.randn(batch, seq_len, dim)

cos, sin = rope_rotation_matrix(seq_len, dim)
x_rotated = apply_rope_efficient(x, cos, sin)

print(f"Input shape: {x.shape}")
print(f"Output shape: {x_rotated.shape}")

# Verify rotation preserves norm
norm_before = torch.norm(x, dim=-1).mean()
norm_after = torch.norm(x_rotated, dim=-1).mean()
print(f"\nNorm before: {norm_before:.4f}")
print(f"Norm after: {norm_after:.4f}")
print(f"Norm preserved: {torch.allclose(norm_before, norm_after, atol=1e-5)}")

Why Rotation Preserves Norms:

Rotation matrices are orthogonal: R^T R = I. Rotating a vector changes its direction but never its length, so RoPE injects position purely through angles: a token's magnitude, and hence the scale of its attention scores, does not depend on where it sits in the sequence.
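
A one-line check of the orthogonality claim (a self-contained toy, matching the 2D matrix above):

python
import math
import torch

# R(θ)ᵀ R(θ) = I for any angle, so rotation never rescales a vector
theta = 0.7
c, s = math.cos(theta), math.sin(theta)
R = torch.tensor([[c, -s], [s, c]])
print(torch.allclose(R.T @ R, torch.eye(2), atol=1e-6))  # True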

Complete RoPE Implementation

python
class RotaryPositionEmbedding(torch.nn.Module):
    """
    Rotary Position Embeddings.

    Used in LLaMA, GPT-NeoX, GPT-J, PaLM, and many modern LLMs.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        """
        Args:
            dim: Dimension per attention head
            max_seq_len: Maximum sequence length to precompute
            base: Base for frequency computation
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos and sin for all positions
        self._set_cos_sin_cache(max_seq_len)

    def _set_cos_sin_cache(self, seq_len):
        """Precompute and cache cos/sin values."""
        self.max_seq_len_cached = seq_len

        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)

        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos(), persistent=False)
        self.register_buffer('sin_cached', emb.sin(), persistent=False)

    def forward(self, q, k, seq_len=None):
        """
        Apply RoPE to query and key tensors.

        Args:
            q: Query (batch, heads, seq_len, dim)
            k: Key (batch, heads, seq_len, dim)
            seq_len: Sequence length (optional, inferred from q if None)

        Returns:
            q_rotated, k_rotated: Rotated query and key
        """
        if seq_len is None:
            seq_len = q.shape[-2]

        # Extend cache if needed
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        # Get cached values
        cos = self.cos_cached[:seq_len, :]
        sin = self.sin_cached[:seq_len, :]

        # Add dimensions for batch and heads
        cos = cos[None, None, :, :]  # (1, 1, seq_len, dim)
        sin = sin[None, None, :, :]

        # Apply rotation
        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        q_rotated = (q * cos) + (rotate_half(q) * sin)
        k_rotated = (k * cos) + (rotate_half(k) * sin)

        return q_rotated, k_rotated


# Example usage
batch, heads, seq_len, head_dim = 2, 8, 16, 64

Q = torch.randn(batch, heads, seq_len, head_dim)
K = torch.randn(batch, heads, seq_len, head_dim)

rope = RotaryPositionEmbedding(dim=head_dim)
Q_rot, K_rot = rope(Q, K)

print(f"Q shape: {Q.shape}")
print(f"Q_rot shape: {Q_rot.shape}")

# Test attention with and without RoPE
def attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v)

V = torch.randn(batch, heads, seq_len, head_dim)

# Without RoPE
output_no_rope = attention(Q, K, V)

# With RoPE
output_with_rope = attention(Q_rot, K_rot, V)

print(f"\nOutput (no RoPE): {output_no_rope.shape}")
print(f"Output (with RoPE): {output_with_rope.shape}")

Why RoPE Works

Relative Position Encoding

The magic of RoPE: when you compute attention, relative position naturally emerges:

python
def demonstrate_relative_position():
    """
    Show how RoPE encodes relative position in attention.
    """
    dim = 64
    base = 10000

    # Frequencies for each dimension pair
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # Rotation angles for two positions
    pos_q, pos_k = 5, 3
    angle_q = pos_q * inv_freq
    angle_k = pos_k * inv_freq

    # Relative angle
    angle_relative = (pos_q - pos_k) * inv_freq

    print("Rotation angles:")
    print(f"  Position {pos_q}: {angle_q[:3]}")
    print(f"  Position {pos_k}: {angle_k[:3]}")
    print(f"  Relative (q-k): {angle_relative[:3]}")
    print(f"  Difference matches: {torch.allclose(angle_q - angle_k, angle_relative)}")

    # The key insight: q_rot · k_rot depends only on (pos_q - pos_k),
    # not on the absolute positions. Verify numerically:
    q = torch.randn(dim)
    k = torch.randn(dim)

    def rotate(x, pos):
        """Apply the rotate_half-style RoPE rotation at a single position."""
        angles = pos * inv_freq
        cos = torch.cat([angles.cos(), angles.cos()])
        sin = torch.cat([angles.sin(), angles.sin()])
        x1, x2 = x.chunk(2)
        return x * cos + torch.cat([-x2, x1]) * sin

    # Same relative offset (2), very different absolute positions
    score_a = rotate(q, 5) @ rotate(k, 3)
    score_b = rotate(q, 105) @ rotate(k, 103)
    print(f"\nScore at positions (5, 3):     {score_a.item():.6f}")
    print(f"Score at positions (105, 103): {score_b.item():.6f}")
    print(f"Only the offset matters: {torch.allclose(score_a, score_b, atol=1e-4)}")

demonstrate_relative_position()

Length Extrapolation

RoPE generalizes better to unseen lengths:

python
def test_length_extrapolation():
    """
    Test RoPE on sequences longer than training length.
    """
    # Train on seq_len=512
    rope_train = RotaryPositionEmbedding(dim=64, max_seq_len=512)

    # Test on seq_len=1024 (2x longer)
    batch, heads, seq_len_test, dim = 1, 1, 1024, 64

    Q = torch.randn(batch, heads, seq_len_test, dim)
    K = torch.randn(batch, heads, seq_len_test, dim)

    # RoPE automatically extends
    Q_rot, K_rot = rope_train(Q, K, seq_len=seq_len_test)

    print(f"Trained on seq_len: 512")
    print(f"Testing on seq_len: {seq_len_test}")
    print(f"Output shape: {Q_rot.shape}")
    print("RoPE successfully extrapolated to 2x length!")

test_length_extrapolation()

Length Generalization:

Because RoPE is built from continuous trigonometric functions, it can be evaluated at any position: the rotation patterns continue smoothly beyond the training length, with no lookup table to run out of. In practice, quality still degrades once relative distances move far outside the training range, which is why extending LLaMA-style models from 2048 to 4096+ tokens typically pairs RoPE with a scaling trick like the one shown below.

RoPE Variants

1. Partial RoPE

Apply RoPE to only part of the dimensions:

python
class PartialRoPE(torch.nn.Module):
    """Apply RoPE to only a fraction of dimensions."""

    def __init__(self, dim, rope_fraction=0.5, max_seq_len=2048):
        super().__init__()
        rope_dim = int(dim * rope_fraction)
        self.rope_dim = rope_dim
        self.rope = RotaryPositionEmbedding(rope_dim, max_seq_len)

    def forward(self, q, k):
        # Split into RoPE and non-RoPE parts
        q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
        k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]

        # Apply RoPE to first part
        q_rope, k_rope = self.rope(q_rope, k_rope)

        # Concatenate
        q_out = torch.cat([q_rope, q_pass], dim=-1)
        k_out = torch.cat([k_rope, k_pass], dim=-1)

        return q_out, k_out
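
A quick usage sketch (the 0.25 fraction mirrors GPT-J, which applies rotation to only the first quarter of each head's dimensions):

python
partial_rope = PartialRoPE(dim=64, rope_fraction=0.25)

q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, dim)
k = torch.randn(2, 8, 16, 64)
q_out, k_out = partial_rope(q, k)

print(q_out.shape)  # torch.Size([2, 8, 16, 64])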

2. Dynamic RoPE Scaling

Scale base frequency for longer contexts:

python
def scaled_rope(dim, max_seq_len, scale_factor=1.0, base=10000):
    """
    RoPE with a scaled base for longer contexts (a simplified form of
    NTK-aware scaling: a larger base slows every rotation down, so the
    same range of angles covers more positions).

    Args:
        scale_factor: Multiply base by this (>1 for longer contexts)
    """
    scaled_base = base * scale_factor
    return RotaryPositionEmbedding(dim, max_seq_len, base=scaled_base)


# Example: 2x longer context with scaled RoPE
rope_standard = RotaryPositionEmbedding(64, max_seq_len=2048)
rope_scaled = scaled_rope(64, max_seq_len=4096, scale_factor=2.0)

print("Standard RoPE: trained on 2048, struggles at 4096")
print("Scaled RoPE: uses scale_factor=2.0 for 4096 tokens")

Models Using RoPE

python
models_with_rope = {
    'LLaMA': 'All sizes (7B-65B)',
    'LLaMA 2': 'All sizes (7B-70B)',
    'GPT-NeoX': '20B',
    'GPT-J': '6B',
    'PaLM': 'All sizes (8B-540B)',
    'CodeGen': 'All sizes',
    'Mistral': '7B',
    'Mixtral': '8x7B',
}

print("Models using Rotary Position Embeddings:\n")
for model, sizes in models_with_rope.items():
    print(f"  {model:15s}: {sizes}")

Summary

Rotary Position Embeddings encode position through rotation:

Key Advantages:

  1. Relative position: Naturally encodes relative distances in attention
  2. Length extrapolation: Generalizes to longer sequences than training
  3. Parameter-free: No learned parameters, purely algorithmic
  4. Norm-preserving: Rotation doesn't change vector magnitudes
  5. Efficient: Can be precomputed and cached

How it Works:

  • Rotate query and key by angle proportional to position
  • Attention score depends on relative angle (relative position)
  • Uses different frequencies for different dimension pairs

When to Use:

  • Default choice for modern decoder-only models
  • Especially good for models that need length generalization
  • Works best with causal (autoregressive) attention

RoPE has become the standard position encoding for modern LLMs, replacing both absolute position embeddings and learned embeddings in most new architectures.