LLaMA Architecture Deep Dive

Explore the architectural innovations in LLaMA: RMSNorm, SwiGLU activation, rotary positional embeddings (RoPE), and grouped-query attention, with a complete PyTorch implementation.

LLaMA (Large Language Model Meta AI) combines several changes to the original transformer, each adopted from earlier work, that improve efficiency and model quality. Let's build LLaMA from scratch to understand each one.

Architecture Overview

LLaMA is a decoder-only transformer with four key changes:

  1. RMSNorm instead of LayerNorm (applied before each sub-layer, i.e. pre-normalization)
  2. SwiGLU activation instead of ReLU
  3. Rotary Positional Embeddings (RoPE) instead of absolute positions
  4. Grouped-Query Attention (GQA) in the larger LLaMA 2 models (34B and 70B)

Why These Changes?

Each modification addresses a specific limitation:

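  • RMSNorm: Cheaper than LayerNorm (no mean centering, no bias) with comparable quality
  • SwiGLU: Smooth, gated activation that improves quality over ReLU in the feed-forward layers
  • RoPE: Encodes relative position directly in the query-key scores, with no learned parameters
  • GQA: Shrinks the key-value cache at inference time while keeping quality close to full multi-head attention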
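
To put a number on the GQA saving, here is a back-of-the-envelope KV-cache calculation. It assumes the published LLaMA 2 70B shape (80 layers, 64 query heads, 8 KV heads, head dimension 128), a 4096-token context and fp16 storage (2 bytes per value); the helper name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values for every layer (the leading 2 is K plus V)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

# Assumed LLaMA 2 70B shape: 80 layers, 64 query heads, head_dim = 8192 / 64 = 128
layers, q_heads, head_dim = 80, 64, 128
context = 4096

mha = kv_cache_bytes(layers, q_heads, head_dim, context)  # one KV head per query head
gqa = kv_cache_bytes(layers, 8, head_dim, context)        # 8 KV heads shared by 64 query heads

print(f"MHA KV cache: {mha / 2**30:.2f} GiB per sequence")
print(f"GQA KV cache: {gqa / 2**30:.2f} GiB per sequence")
print(f"Reduction: {mha // gqa}x")
```

Under these assumptions the cache drops from 10 GiB to 1.25 GiB per sequence, an 8x reduction equal to the 64/8 head ratio; the saving grows linearly with batch size and context length.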

1. RMSNorm (Root Mean Square Normalization)

RMSNorm simplifies LayerNorm by removing mean centering.

Mathematical Foundation

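LayerNorm:

LN(x) = γ × (x - μ) / √(σ² + ε) + β

RMSNorm:

RMSNorm(x) = γ × x / √(mean(x²) + ε)

Here μ and σ² are the mean and variance of x over the model dimension, and mean(x²) is the mean of the squared entries over that same dimension. RMSNorm has no mean subtraction and no bias parameter β.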
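
A quick numeric check makes the difference concrete: on a vector with nonzero mean the two normalizations disagree, but once the vector is mean-centered they coincide (with γ = 1 and β = 0), because the variance is then exactly the mean of the squares. This sketch uses `torch.nn.functional.layer_norm` without affine parameters; `rms_norm` is a small helper written for illustration.

```python
import torch
import torch.nn.functional as F

EPS = 1e-6

def rms_norm(x, eps=EPS):
    # Divide by the root mean square over the last dimension (no centering, no bias)
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rms_out = rms_norm(x)
ln_out = F.layer_norm(x, normalized_shape=(4,), eps=EPS)
print("RMSNorm:  ", rms_out)   # x / sqrt(7.5), ignoring eps
print("LayerNorm:", ln_out)    # (x - 2.5) / sqrt(1.25), ignoring eps

# Center the input: mean(x^2) now equals the variance, so both give the same result
x_centered = x - x.mean(dim=-1, keepdim=True)
print(torch.allclose(rms_norm(x_centered), F.layer_norm(x_centered, (4,), eps=EPS)))  # True
```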

Implementation

python
import torch
import torch.nn as nn
import math

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Used in LLaMA for efficiency over standard LayerNorm.
    """

    def __init__(self, dim, eps=1e-6):
        """
        Args:
            dim: Model dimension
            eps: Small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Only weight parameter (no bias)
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply RMS normalization.

        Args:
            x: (batch, seq_len, dim)
        """
        # RMS = sqrt(mean(x^2))
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, dim)

        Returns:
            Normalized and scaled tensor
        """
        # Normalize and scale
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


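# Test RMSNorm
batch, seq_len, dim = 2, 10, 512
x = torch.randn(batch, seq_len, dim) * 10  # Large variance

rms_norm = RMSNorm(dim)
output = rms_norm(x)

print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {output.mean():.4f}, std: {output.std():.4f}")
# Each token's RMS is ~1 after normalization (weight starts at ones)
print(f"RMS of output: {torch.sqrt(torch.mean(output ** 2)):.4f}")

# Compare speed with LayerNorm (needs a CUDA GPU)
import time

def benchmark(module, x, iterations=100, warmup=10):
    """Average seconds per forward call, synchronizing so all GPU work is timed."""
    with torch.no_grad():
        for _ in range(warmup):  # warm up kernels and the caching allocator
            module(x)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iterations):
            module(x)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iterations

if torch.cuda.is_available():
    # (8, 2048, 4096) in fp32 is ~268 MB; much larger inputs can exhaust GPU memory
    x = torch.randn(8, 2048, 4096, device="cuda")
    rms_norm = RMSNorm(4096).cuda()
    layer_norm = nn.LayerNorm(4096).cuda()

    ln_time = benchmark(layer_norm, x)
    rms_time = benchmark(rms_norm, x)

    # This RMSNorm runs several unfused ops while nn.LayerNorm is one fused kernel,
    # so the measured "speedup" can be below 1
    print(f"\nLayerNorm time: {ln_time * 1e3:.3f} ms")
    print(f"RMSNorm time: {rms_time * 1e3:.3f} ms")
    print(f"Speedup: {ln_time / rms_time:.2f}x")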

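Why RMSNorm is Faster:

By removing mean centering and the bias parameter, RMSNorm needs:

  1. Fewer operations: one reduction (the mean of squares) instead of two (the mean, then the variance), and no subtraction
  2. Slightly less memory traffic: no bias vector to load
  3. Simpler gradient computation

The RMSNorm paper reports running-time reductions of roughly 7% to 64% across the models it tested. Real gains depend on the kernel: the unfused Python version above launches several small operations, so it can be slower than nn.LayerNorm's single fused kernel unless it is compiled or fused.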

2. SwiGLU Activation

SwiGLU combines the Swish activation with a gated linear unit (GLU).

The Formula

SwiGLU(x, W, V, b, c) = Swish(xW + b) ⊗ (xV + c)

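where:
- Swish(x) = x × σ(x) = x × sigmoid(x) (Swish with β = 1, also known as SiLU)
- ⊗ is element-wise multiplication

In the transformer feed-forward layer the biases are dropped and a third matrix W2 projects back to the model dimension: FFN(x) = (Swish(xW) ⊗ xV) W2.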
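
A tiny hand-sized example makes the formula concrete. This sketch drops the biases b and c (as LLaMA does), uses arbitrary weight values, and checks the Swish computation against PyTorch's built-in `torch.nn.functional.silu`.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, -2.0]])        # one token, dim = 2
W = torch.tensor([[1.0, 0.5, -1.0],    # dim x hidden (2 x 3), arbitrary values
                  [0.0, 1.0, 2.0]])
V = torch.tensor([[2.0, -1.0, 0.5],
                  [1.0, 1.0, 1.0]])

xw = x @ W                              # gate pre-activation: [1.0, -1.5, -5.0]
swish = xw * torch.sigmoid(xw)          # Swish(z) = z * sigmoid(z)
out = swish * (x @ V)                   # element-wise gate with the linear branch xV

print("Swish(xW):", swish)
print("SwiGLU:   ", out)
print(torch.allclose(swish, F.silu(xw)))  # True
```

Note how the linear branch acts as a gate: wherever xV is zero (the first hidden unit here), the output is zero regardless of the Swish value.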

Implementation

python
class SwiGLU(nn.Module):
    """
    SwiGLU activation function used in LLaMA.

    Combines Swish activation with Gated Linear Unit.
    """

    def __init__(self, dim, hidden_dim=None, bias=False):
        """
        Args:
            dim: Input dimension
            hidden_dim: Hidden dimension (default: 4 * dim; LLaMA itself uses about 8/3 * dim)
            bias: Whether to use bias (LLaMA uses False)
        """
        super().__init__()

        if hidden_dim is None:
            hidden_dim = 4 * dim

        # SwiGLU requires 2 separate linear projections
        # W and V in the formula above
        self.w = nn.Linear(dim, hidden_dim, bias=bias)
        self.v = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, dim)

        Returns:
            (batch, seq_len, dim)
        """
        # SwiGLU(x) = Swish(xW) ⊗ (xV), then project back with W2
        # Swish(z) = z * sigmoid(z); compute xW once and reuse it
        xw = self.w(x)
        swish_out = xw * torch.sigmoid(xw)
        gated = swish_out * self.v(x)
        return self.w2(gated)


# Alternative: Combined implementation (more efficient)
class SwiGLUEfficient(nn.Module):
    """
    More efficient SwiGLU using a single linear layer.
    """

    def __init__(self, dim, hidden_dim=None, bias=False):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = 4 * dim

        # Combine W and V into single projection (2 * hidden_dim output)
        self.w = nn.Linear(dim, 2 * hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, dim)
        """
        # Project to 2 * hidden_dim then split
        x_proj = self.w(x)
        x1, x2 = x_proj.chunk(2, dim=-1)

        # SwiGLU: Swish(x1) * x2
        swish = x1 * torch.sigmoid(x1)
        gated = swish * x2

        return self.w2(gated)


# Test SwiGLU
x = torch.randn(2, 10, 512)

swiglu = SwiGLUEfficient(dim=512, hidden_dim=2048)
output = swiglu(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Compare with standard FFN
class StandardFFN(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

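standard_ffn = StandardFFN(512, 2048)

# At the same hidden_dim, SwiGLU has three weight matrices instead of two (~1.5x the parameters)
print(f"Standard FFN params: {sum(p.numel() for p in standard_ffn.parameters()):,}")
print(f"SwiGLU FFN params:   {sum(p.numel() for p in swiglu.parameters()):,}")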

# Visualize activation patterns
x_test = torch.linspace(-5, 5, 1000)

relu = torch.relu(x_test)
swish = x_test * torch.sigmoid(x_test)

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(x_test.numpy(), relu.numpy(), label='ReLU', linewidth=2)
plt.plot(x_test.numpy(), swish.numpy(), label='Swish (SiLU)', linewidth=2)
plt.xlabel('x')
plt.ylabel('Activation')
plt.title('ReLU vs Swish Activation')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
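
The fused `SwiGLUEfficient` layout computes the same function as the two-projection `SwiGLU`: stacking the rows of W and V into one weight matrix and chunking the output recovers both branches. The sketch below checks this with plain `torch.nn.functional.linear` calls, so it does not depend on either class above; the dimensions and random weights are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden = 8, 16
x = torch.randn(2, 5, dim)

# nn.Linear stores weights as (out_features, in_features)
w = torch.randn(hidden, dim)
v = torch.randn(hidden, dim)
w2 = torch.randn(dim, hidden)

# Two separate projections: Swish(xW) * xV, then W2
separate = F.linear(F.silu(F.linear(x, w)) * F.linear(x, v), w2)

# One fused projection: rows of W first, then rows of V
fused_weight = torch.cat([w, v], dim=0)               # (2 * hidden, dim)
gate, up = F.linear(x, fused_weight).chunk(2, dim=-1)
fused = F.linear(F.silu(gate) * up, w2)

print(torch.allclose(separate, fused, rtol=1e-4, atol=1e-4))  # True
```

Fusing replaces two matmuls with one larger one, which usually uses the GPU better. Meta's reference code keeps the two projections as separate layers (`w1` and `w3`), so both layouts appear in practice.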

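Why SwiGLU?

In Shazeer's "GLU Variants Improve Transformer" experiments (T5-style pretraining), SwiGLU feed-forward layers outperformed ReLU and GELU ones:

  1. Smooth gradients: unlike ReLU, Swish is differentiable everywhere
  2. Gating mechanism: the xV branch scales each Swish output element-wise, giving input-dependent feature selection
  3. Better quality: lower perplexity than ReLU/GELU feed-forward layers at a matched parameter count

LLaMA uses hidden_dim = (8/3) × dim so that the three SwiGLU matrices hold as many weights as a standard FFN with a 4x expansion: 3 × d × (8d/3) = 8d² = 2 × d × 4d.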
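
Applying that rule, together with the round-up-to-a-multiple-of-256 step that `LLaMABlock` uses further down, reproduces the FFN widths of the released LLaMA 1 checkpoints (11008 for 7B, 13824 for 13B, 17920 for 33B, 22016 for 65B). `ffn_hidden_dim` is a helper name introduced here.

```python
def ffn_hidden_dim(dim, multiple_of=256):
    """SwiGLU width: 8/3 * dim, rounded up to the next multiple of `multiple_of`."""
    hidden = int(8 * dim / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

for name, dim in [("7B", 4096), ("13B", 5120), ("33B", 6656), ("65B", 8192)]:
    h = ffn_hidden_dim(dim)
    swiglu_params = 3 * dim * h          # W, V and W2
    standard_params = 2 * dim * 4 * dim  # fc1 and fc2 of a bias-free 4x FFN
    print(f"{name}: dim={dim}, hidden={h}, SwiGLU/standard params = {swiglu_params / standard_params:.3f}")
```

The ratio stays slightly above 1 because rounding up to a multiple of 256 adds a few extra hidden units.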

3. Rotary Position Embeddings (RoPE)

RoPE encodes relative position information directly into attention through rotation.

Mathematical Intuition

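Instead of adding position embeddings, RoPE rotates query and key vectors:

q_m = R_m × q    (rotate query by position m)
k_n = R_n × k    (rotate key by position n)

R_m rotates each pair of dimensions i by the angle m × θ_i, where θ_i = base^(-2i/d) and d is the head dimension. Rotations compose, so R_m^T × R_n = R_(n-m), and therefore:

q_m^T × k_n = q^T × R_(n-m) × k

The attention score depends on the two positions only through their offset n - m.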
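
The relative-position property is easiest to check in two dimensions, where R_m is an ordinary 2 × 2 rotation by m × θ. This sketch uses a single frequency θ (an arbitrary 0.3) and shows that two (m, n) pairs with the same offset give the same score; `rotation` and `score` are helpers written for illustration.

```python
import math
import torch

def rotation(angle):
    """2x2 rotation matrix for the given angle in radians."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s], [s, c]], dtype=torch.float64)

theta = 0.3
q = torch.tensor([1.0, 2.0], dtype=torch.float64)
k = torch.tensor([-0.5, 1.5], dtype=torch.float64)

def score(m, n):
    # Rotate q by position m and k by position n, then take the dot product
    return (rotation(m * theta) @ q) @ (rotation(n * theta) @ k)

print(score(5, 2), score(13, 10))        # same offset n - m = -3, same score
print(q @ rotation((2 - 5) * theta) @ k)  # closed form q^T R_(n-m) k
```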

Implementation

python
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Encodes position information through rotation matrices.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        """
        Args:
            dim: Dimension per attention head (d_k)
            max_seq_len: Maximum sequence length
            base: Base for frequency computation (10000 in paper)
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequency bands
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute rotation matrices for all positions
        self._set_cos_sin_cache(max_seq_len)

    def _set_cos_sin_cache(self, seq_len):
        """Precompute cos and sin values for all positions."""
        self.max_seq_len_cached = seq_len

        # Position indices (on inv_freq's device, so extending the cache also works on GPU)
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)

        # Compute frequencies for each position
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Shape: (seq_len, dim // 2)

        # Duplicate the frequencies so dims i and i + dim/2 share an angle (matches rotate_half)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Shape: (seq_len, dim)

        # Precompute cos and sin
        self.register_buffer('cos_cached', emb.cos()[None, :, None, :])
        self.register_buffer('sin_cached', emb.sin()[None, :, None, :])

    def rotate_half(self, x):
        """
        Rotate half the hidden dims of the input.

        This creates the complex number rotation effect.
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, seq_len=None):
        """
        Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor (batch, heads, seq_len, dim)
            k: Key tensor (batch, heads, seq_len, dim)
            seq_len: Sequence length (if None, use q.shape[2])

        Returns:
            Rotated (q, k)
        """
        if seq_len is None:
            seq_len = q.shape[2]

        # Extend cache if needed
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

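        # Cached tables are (1, max_seq_len, 1, dim): slice to this length and move the
        # sequence axis next to dim so they broadcast against (batch, heads, seq_len, dim)
        cos = self.cos_cached[:, :seq_len].transpose(1, 2)
        sin = self.sin_cached[:, :seq_len].transpose(1, 2)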

        # Apply rotation
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)

        return q_embed, k_embed


# Test RoPE
batch_size, num_heads, seq_len, head_dim = 2, 8, 10, 64

q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)

rope = RotaryEmbedding(dim=head_dim)
q_rot, k_rot = rope(q, k)

print(f"Original Q shape: {q.shape}")
print(f"Rotated Q shape: {q_rot.shape}")

# Verify rotation preserves norm
q_norm_before = torch.norm(q, dim=-1).mean()
q_norm_after = torch.norm(q_rot, dim=-1).mean()
print(f"\nNorm before RoPE: {q_norm_before:.4f}")
print(f"Norm after RoPE: {q_norm_after:.4f}")
print("Rotation preserves norm!" if abs(q_norm_before - q_norm_after) < 0.01 else "Norm changed!")

RoPE Advantages:

  1. Relative positions: query-key scores depend only on the offset between positions
  2. No fixed length limit: rotations can be computed for any position, although quality usually degrades beyond the training context without extensions such as position interpolation
  3. No learned parameters: purely algorithmic
  4. Preserves norms: rotation doesn't change vector magnitude

Used in: LLaMA, GPT-NeoX, GPT-J, PaLM
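
The `rotate_half` trick above pairs dimension i with dimension i + d/2. Meta's reference LLaMA code instead pairs adjacent dimensions (2i, 2i+1) and applies the rotation as a complex multiplication; the two layouts are equivalent up to a fixed permutation of the head dimensions. The sketch below follows the complex-number formulation (the helper names `rope_angles` and `apply_rope_complex` are ours, not Meta's API) and checks the relative-position claim: shifting every position by the same offset leaves all attention scores unchanged.

```python
import torch

def rope_angles(head_dim, positions, base=10000.0):
    """Unit complex numbers e^{i * pos * theta_j} for each position and dimension pair."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    angles = torch.outer(positions.to(torch.float64), inv_freq)   # (seq, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)           # complex128

def apply_rope_complex(x, freqs):
    """Rotate adjacent pairs (x[2j], x[2j+1]) of x: (seq, head_dim) by freqs: (seq, head_dim // 2)."""
    x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_complex * freqs).flatten(-2)

torch.manual_seed(0)
seq_len, head_dim = 6, 8
q = torch.randn(seq_len, head_dim, dtype=torch.float64)
k = torch.randn(seq_len, head_dim, dtype=torch.float64)
pos = torch.arange(seq_len)

def scores(offset):
    # Place the same tokens at positions offset .. offset + seq_len - 1
    freqs = rope_angles(head_dim, pos + offset)
    return apply_rope_complex(q, freqs) @ apply_rope_complex(k, freqs).T

print(torch.allclose(scores(0), scores(100)))  # True
```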

4. Complete LLaMA Architecture

Now let's combine everything into a full LLaMA model:

python
class LLaMAAttention(nn.Module):
    """Multi-head attention with RoPE and optional GQA."""

    def __init__(self, dim, num_heads, num_kv_heads=None, max_seq_len=2048):
        """
        Args:
            dim: Model dimension
            num_heads: Number of query heads
            num_kv_heads: Number of KV heads (for GQA). If None, use num_heads (MHA)
            max_seq_len: Initial RoPE cache length (extended automatically if exceeded)
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        assert num_heads % self.num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads"
        self.num_groups = num_heads // self.num_kv_heads
        self.head_dim = dim // num_heads

        # Q, K, V projections
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)

        # RoPE
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len=max_seq_len)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, dim)
            mask: Attention mask broadcastable to (batch, heads, seq_len, seq_len); positions where mask == 0 are blocked

        Returns:
            (batch, seq_len, dim)
        """
        batch_size, seq_len, _ = x.shape

        # Project Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # Transpose to (batch, heads, seq_len, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Apply RoPE
        q, k = self.rope(q, k, seq_len=seq_len)

        # Expand K, V for grouped-query attention
        if self.num_groups > 1:
            k = k.repeat_interleave(self.num_groups, dim=1)
            v = v.repeat_interleave(self.num_groups, dim=1)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        # Reshape and project
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output


class LLaMABlock(nn.Module):
    """Single LLaMA transformer block."""

    def __init__(self, dim, num_heads, num_kv_heads=None, hidden_dim=None):
        """
        Args:
            dim: Model dimension
            num_heads: Number of attention heads
            num_kv_heads: Number of KV heads (for GQA)
            hidden_dim: FFN hidden dimension (default: 8/3 * dim, rounded up to a multiple of 256)
        """
        super().__init__()

        # Pre-normalization
        self.attention_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

        # Attention
        self.attention = LLaMAAttention(dim, num_heads, num_kv_heads)

        # Feed-forward with SwiGLU
        if hidden_dim is None:
            # LLaMA uses 8/3 * dim so the three SwiGLU matrices match a 4x FFN's parameter count
            hidden_dim = int(8 * dim / 3)
            # Round up to the next multiple of 256 for hardware efficiency
            hidden_dim = 256 * ((hidden_dim + 255) // 256)

        self.feed_forward = SwiGLUEfficient(dim, hidden_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, dim)
            mask: Attention mask

        Returns:
            (batch, seq_len, dim)
        """
        # Pre-norm attention with residual
        h = x + self.attention(self.attention_norm(x), mask)

        # Pre-norm FFN with residual
        out = h + self.feed_forward(self.ffn_norm(h))

        return out


class LLaMA(nn.Module):
    """Complete LLaMA model."""

    def __init__(
        self,
        vocab_size,
        dim,
        num_layers,
        num_heads,
        num_kv_heads=None,
        max_seq_len=2048
    ):
        """
        Args:
            vocab_size: Vocabulary size
            dim: Model dimension
            num_layers: Number of transformer layers
            num_heads: Number of attention heads
            num_kv_heads: Number of KV heads (for GQA in LLaMA 2)
            max_seq_len: Maximum sequence length
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.num_layers = num_layers

        # Token embeddings
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

        # Transformer layers
        self.layers = nn.ModuleList([
            LLaMABlock(dim, num_heads, num_kv_heads)
            for _ in range(num_layers)
        ])

        # Output normalization and projection
        self.norm = RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

        # LLaMA does not tie the output projection to the token embeddings;
        # the two stay separate matrices

    def forward(self, tokens, mask=None):
        """
        Args:
            tokens: (batch, seq_len) token indices
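            mask: Attention mask (1/True = may attend). Defaults to a causal mask.

        Returns:
            (batch, seq_len, vocab_size) logits
        """
        # Embed tokens
        h = self.tok_embeddings(tokens)

        # Decoder-only model: without an explicit mask, each position may attend
        # only to itself and earlier positions
        if mask is None:
            seq_len = tokens.shape[1]
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device))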

        # Apply transformer layers
        for layer in self.layers:
            h = layer(h, mask)

        # Final norm and projection
        h = self.norm(h)
        logits = self.output(h)

        return logits


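# Example: LLaMA-7B configuration.
# A 7B model in fp32 needs ~27 GB of memory, so build it on the "meta" device
# (shapes only, no storage) just to count parameters. Requires PyTorch >= 2.0.
with torch.device("meta"):
    llama_7b = LLaMA(
        vocab_size=32000,
        dim=4096,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,  # LLaMA 1: same as num_heads (MHA)
        max_seq_len=2048
    )

    # LLaMA 2-7B keeps the same shape and still uses MHA; only the context grows.
    # GQA is used in the larger LLaMA 2 models (34B and 70B).
    llama2_7b = LLaMA(
        vocab_size=32000,
        dim=4096,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        max_seq_len=4096  # LLaMA 2: longer context
    )

# Count parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"LLaMA 7B parameters: {count_parameters(llama_7b) / 1e9:.2f}B")
print(f"LLaMA 2 7B parameters: {count_parameters(llama2_7b) / 1e9:.2f}B")

# Test a forward pass on a small GQA model with real weights (8 query heads share 2 KV heads)
tiny_llama = LLaMA(vocab_size=32000, dim=256, num_layers=2, num_heads=8, num_kv_heads=2)
tokens = torch.randint(0, 32000, (2, 128))  # (batch=2, seq_len=128)
logits = tiny_llama(tokens)
print(f"\nInput shape: {tokens.shape}")
print(f"Output shape: {logits.shape}")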
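
`repeat_interleave` in `LLaMAAttention` decides which KV head each query head reads: with `num_groups = num_heads // num_kv_heads`, query head h uses KV head h // num_groups. The sketch below checks that mapping on tiny tensors and confirms it matches an expand-then-reshape formulation (the approach Meta's reference code takes in its own helper); the shapes are arbitrary and the code is written inline for illustration.

```python
import torch

batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 8, 2, 3, 4
num_groups = num_heads // num_kv_heads   # 4 query heads share each KV head

# Fill each KV head with its own index so the mapping is easy to read
k = torch.arange(num_kv_heads, dtype=torch.float32).view(1, num_kv_heads, 1, 1)
k = k.expand(batch, num_kv_heads, seq_len, head_dim)

k_rep = k.repeat_interleave(num_groups, dim=1)   # (batch, num_heads, seq_len, head_dim)
print(k_rep[0, :, 0, 0].tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]

# Same result via expand + reshape: add a group axis right after the KV-head axis
k_exp = k[:, :, None].expand(batch, num_kv_heads, num_groups, seq_len, head_dim)
k_exp = k_exp.reshape(batch, num_heads, seq_len, head_dim)
print(torch.equal(k_rep, k_exp))  # True
```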

LLaMA Model Sizes

python
# LLaMA 1 model family configurations
llama_configs = {
    '7B': {'dim': 4096, 'num_layers': 32, 'num_heads': 32},
    '13B': {'dim': 5120, 'num_layers': 40, 'num_heads': 40},
    '33B': {'dim': 6656, 'num_layers': 60, 'num_heads': 52},
    '65B': {'dim': 8192, 'num_layers': 80, 'num_heads': 64},
}

for name, config in llama_configs.items():
    # Build on the meta device: counts parameters without allocating hundreds of GB
    with torch.device("meta"):
        model = LLaMA(vocab_size=32000, **config)
    params = count_parameters(model) / 1e9
    print(f"LLaMA {name}: {params:.1f}B parameters")
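
As a cross-check on `count_parameters`, the total can also be written in closed form. LLaMA keeps the input embedding and output projection as separate matrices and uses bias-free projections, so each layer holds the four attention projections, three SwiGLU matrices (with the 256-rounded width used above) and two RMSNorm weight vectors. `llama_param_count` is a helper introduced here.

```python
def llama_param_count(vocab_size, dim, num_layers, num_heads, num_kv_heads=None, multiple_of=256):
    """Closed-form parameter count for the LLaMA layout (no biases, untied output layer)."""
    num_kv_heads = num_kv_heads or num_heads
    head_dim = dim // num_heads
    hidden = int(8 * dim / 3)
    hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)

    attention = 2 * dim * dim + 2 * dim * num_kv_heads * head_dim  # q, o plus k, v
    ffn = 3 * dim * hidden                                          # W, V, W2
    norms = 2 * dim                                                 # two RMSNorm weights
    per_layer = attention + ffn + norms
    embeddings = 2 * vocab_size * dim                               # input embedding + output
    return num_layers * per_layer + embeddings + dim                # + final RMSNorm

configs = {
    '7B': dict(dim=4096, num_layers=32, num_heads=32),
    '13B': dict(dim=5120, num_layers=40, num_heads=40),
    '33B': dict(dim=6656, num_layers=60, num_heads=52),
    '65B': dict(dim=8192, num_layers=80, num_heads=64),
}
for name, cfg in configs.items():
    print(f"LLaMA {name}: {llama_param_count(32000, **cfg):,} parameters")
```

For the 7B configuration this gives 6,738,415,616 parameters, which rounds to the 6.7B usually quoted for LLaMA-7B. Passing `num_kv_heads` smaller than `num_heads` shows how GQA also trims the k and v projection weights.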

Summary

LLaMA's architectural innovations:

  1. RMSNorm: Faster normalization without mean centering
  2. SwiGLU: Improved activation with gating mechanism
  3. RoPE: Relative position encoding through rotation
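  4. GQA (LLaMA 2 34B and 70B): Memory-efficient attention with shared KV heads

These modifications make LLaMA more efficient than a GPT-style transformer. Combined with training on far more tokens than earlier models of similar size (1T to 1.4T), they let LLaMA reach strong performance at smaller scales: the LLaMA paper reports LLaMA-13B outperforming GPT-3 (175B) on most benchmarks.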