LLaMA Architecture Deep Dive
LLaMA (Large Language Model Meta AI) introduced several architectural improvements over the original transformer that make it more efficient and performant. Let's build LLaMA from scratch to understand each innovation.
Architecture Overview
LLaMA is a decoder-only transformer with four key innovations:
- RMSNorm instead of LayerNorm
- SwiGLU activation instead of ReLU
- Rotary Positional Embeddings (RoPE) instead of absolute positions
- Grouped-Query Attention (GQA) in the larger LLaMA 2 models
Why These Changes?
Each modification addresses a specific limitation:
- RMSNorm: Faster than LayerNorm, simpler computation
- SwiGLU: Better gradient flow, improved performance
- RoPE: Relative position encoding, better length generalization
- GQA: Reduced memory for inference, maintains quality
1. RMSNorm (Root Mean Square Normalization)
RMSNorm simplifies LayerNorm by removing mean centering.
Mathematical Foundation
LayerNorm:
LN(x) = γ × (x - μ) / √(σ² + ε) + β
RMSNorm:
RMSNorm(x) = γ × x / √(mean(x²) + ε)
No mean subtraction, no bias parameter β.
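For example, with x = [3, 4] and γ = 1: mean(x²) = (9 + 16) / 2 = 12.5, so √12.5 ≈ 3.54 and RMSNorm(x) ≈ [0.85, 1.13]. Note that the output is rescaled but not zero-centered.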
Implementation
import torch
import torch.nn as nn
import math
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Used in LLaMA for efficiency over standard LayerNorm.
"""
def __init__(self, dim, eps=1e-6):
"""
Args:
dim: Model dimension
eps: Small constant for numerical stability
"""
super().__init__()
self.eps = eps
# Only weight parameter (no bias)
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
Apply RMS normalization.
Args:
x: (batch, seq_len, dim)
"""
# RMS = sqrt(mean(x^2))
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x / rms
def forward(self, x):
"""
Args:
x: (batch, seq_len, dim)
Returns:
Normalized and scaled tensor
"""
# Normalize and scale
output = self._norm(x.float()).type_as(x)
return output * self.weight
# Test RMSNorm
batch, seq_len, dim = 2, 10, 512
x = torch.randn(batch, seq_len, dim) * 10 # Large variance
rms_norm = RMSNorm(dim)
output = rms_norm(x)
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {output.mean():.4f}, std: {output.std():.4f}")
print(f"RMS of output: {torch.sqrt(torch.mean(output ** 2)):.4f}")
# Compare speed with LayerNorm (requires a CUDA GPU)
import time
x = torch.randn(100, 2048, 4096).cuda()
rms_norm = RMSNorm(4096).cuda()
layer_norm = nn.LayerNorm(4096).cuda()
# Warm up so one-time kernel/allocator overhead doesn't skew the timing
for _ in range(10):
    _ = layer_norm(x)
    _ = rms_norm(x)
torch.cuda.synchronize()
# Benchmark
iterations = 100
start = time.time()
for _ in range(iterations):
    _ = layer_norm(x)
torch.cuda.synchronize()
ln_time = time.time() - start
start = time.time()
for _ in range(iterations):
    _ = rms_norm(x)
torch.cuda.synchronize()
rms_time = time.time() - start
print(f"\nLayerNorm time: {ln_time:.4f}s")
print(f"RMSNorm time: {rms_time:.4f}s")
print(f"Speedup: {ln_time/rms_time:.2f}x")
Why RMSNorm is Faster:
By removing mean computation and the bias parameter:
- Fewer operations (no mean subtraction)
- Less memory bandwidth (no bias to load)
- Simpler gradient computation
Typically 10-20% faster than LayerNorm.
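To see what dropping mean centering actually changes, here is a quick illustrative sketch (reusing the RMSNorm class above): feed both norms an input with a large constant offset.
x_shifted = torch.randn(2, 4, 512) + 5.0  # input with a large positive mean
rms_out = RMSNorm(512)(x_shifted)
ln_out = nn.LayerNorm(512)(x_shifted)
print(f"RMSNorm output mean: {rms_out.mean():.3f}")    # stays well above zero
print(f"LayerNorm output mean: {ln_out.mean():.3f}")   # approximately zero
RMSNorm only rescales, so a constant offset survives normalization; empirically this does not hurt transformer quality, which is why LLaMA can use the cheaper norm.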
2. SwiGLU Activation
SwiGLU combines the Swish activation with a gated linear unit (GLU).
The Formula
SwiGLU(x, W, V, b, c) = Swish(xW + b) ⊗ (xV + c)
where:
- Swish(x) = x × σ(x) = x × sigmoid(x)
- ⊗ is element-wise multiplication
Implementation
class SwiGLU(nn.Module):
"""
SwiGLU activation function used in LLaMA.
Combines Swish activation with Gated Linear Unit.
"""
def __init__(self, dim, hidden_dim=None, bias=False):
"""
Args:
dim: Input dimension
            hidden_dim: Hidden dimension (default: 4 * dim; LLaMA itself uses roughly 8/3 * dim, see below)
bias: Whether to use bias (LLaMA uses False)
"""
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
# SwiGLU requires 2 separate linear projections
# W and V in the formula above
self.w = nn.Linear(dim, hidden_dim, bias=bias)
self.v = nn.Linear(dim, hidden_dim, bias=bias)
self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
def forward(self, x):
"""
Args:
x: (batch, seq_len, dim)
Returns:
(batch, seq_len, dim)
"""
        # SwiGLU(x) = Swish(xW) ⊗ (xV), where Swish(x) = x * sigmoid(x)
        # Compute the W projection once and reuse it
        xw = self.w(x)
        swish_out = xw * torch.sigmoid(xw)
        gated = swish_out * self.v(x)
return self.w2(gated)
# Alternative: Combined implementation (more efficient)
class SwiGLUEfficient(nn.Module):
"""
More efficient SwiGLU using a single linear layer.
"""
def __init__(self, dim, hidden_dim=None, bias=False):
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
# Combine W and V into single projection (2 * hidden_dim output)
self.w = nn.Linear(dim, 2 * hidden_dim, bias=bias)
self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
def forward(self, x):
"""
Args:
x: (batch, seq_len, dim)
"""
# Project to 2 * hidden_dim then split
x_proj = self.w(x)
x1, x2 = x_proj.chunk(2, dim=-1)
# SwiGLU: Swish(x1) * x2
swish = x1 * torch.sigmoid(x1)
gated = swish * x2
return self.w2(gated)
# Test SwiGLU
x = torch.randn(2, 10, 512)
swiglu = SwiGLUEfficient(dim=512, hidden_dim=2048)
output = swiglu(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
# Compare with standard FFN
class StandardFFN(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
def forward(self, x):
return self.fc2(torch.relu(self.fc1(x)))
standard_ffn = StandardFFN(512, 2048)
# Visualize activation patterns
x_test = torch.linspace(-5, 5, 1000)
relu = torch.relu(x_test)
swish = x_test * torch.sigmoid(x_test)
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.plot(x_test.numpy(), relu.numpy(), label='ReLU', linewidth=2)
plt.plot(x_test.numpy(), swish.numpy(), label='Swish (SiLU)', linewidth=2)
plt.xlabel('x')
plt.ylabel('Activation')
plt.title('ReLU vs Swish Activation')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Why SwiGLU?
Empirically, SwiGLU outperforms ReLU and GELU in language models:
- Smooth gradients: Unlike ReLU, differentiable everywhere
- Gating mechanism: Adaptive feature selection
- Better performance: Consistently improves perplexity
LLaMA uses hidden_dim ≈ (8/3) * dim so that SwiGLU's three weight matrices roughly match the parameter count of a standard FFN's two matrices with 4x expansion.
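As a rough sanity check of that claim, here is a small sketch comparing the two FFN variants defined above at LLaMA-7B width (the n_params helper is just for this comparison; the counts come out nearly equal):
def n_params(module):
    return sum(p.numel() for p in module.parameters())
dim = 4096
hidden_swiglu = 256 * ((int(8 * dim / 3) + 255) // 256)  # -> 11008, as used in LLaMA 7B
print(f"Standard 4x FFN:  {n_params(StandardFFN(dim, 4 * dim)) / 1e6:.1f}M params")
print(f"SwiGLU ~8/3x FFN: {n_params(SwiGLUEfficient(dim, hidden_swiglu)) / 1e6:.1f}M params")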
3. Rotary Position Embeddings (RoPE)
RoPE encodes relative position information directly into attention through rotation.
Mathematical Intuition
Instead of adding position embeddings, RoPE rotates query and key vectors:
q_m = R_m × q (rotate query by position m)
k_n = R_n × k (rotate key by position n)
Then: q_m^T × k_n captures relative position (m - n)
Implementation
class RotaryEmbedding(nn.Module):
"""
Rotary Position Embeddings (RoPE).
Encodes position information through rotation matrices.
"""
def __init__(self, dim, max_seq_len=2048, base=10000):
"""
Args:
dim: Dimension per attention head (d_k)
max_seq_len: Maximum sequence length
base: Base for frequency computation (10000 in paper)
"""
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# Precompute frequency bands
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# Precompute rotation matrices for all positions
self._set_cos_sin_cache(max_seq_len)
def _set_cos_sin_cache(self, seq_len):
"""Precompute cos and sin values for all positions."""
self.max_seq_len_cached = seq_len
        # Position indices (created on the same device as the frequency buffer)
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
# Compute frequencies for each position
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Shape: (seq_len, dim // 2)
# Combine frequencies for complex number representation
emb = torch.cat([freqs, freqs], dim=-1)
# Shape: (seq_len, dim)
        # Precompute cos and sin, shaped (1, 1, seq_len, dim) so they broadcast
        # against (batch, heads, seq_len, dim) query/key tensors
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])
def rotate_half(self, x):
"""
Rotate half the hidden dims of the input.
This creates the complex number rotation effect.
"""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
def forward(self, q, k, seq_len=None):
"""
Apply rotary embeddings to queries and keys.
Args:
q: Query tensor (batch, heads, seq_len, dim)
k: Key tensor (batch, heads, seq_len, dim)
seq_len: Sequence length (if None, use q.shape[2])
Returns:
Rotated (q, k)
"""
if seq_len is None:
seq_len = q.shape[2]
# Extend cache if needed
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
        # Get cached cos/sin for this sequence length
        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]
# Apply rotation
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
# Test RoPE
batch_size, num_heads, seq_len, head_dim = 2, 8, 10, 64
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
rope = RotaryEmbedding(dim=head_dim)
q_rot, k_rot = rope(q, k)
print(f"Original Q shape: {q.shape}")
print(f"Rotated Q shape: {q_rot.shape}")
# Verify rotation preserves norm
q_norm_before = torch.norm(q, dim=-1).mean()
q_norm_after = torch.norm(q_rot, dim=-1).mean()
print(f"\nNorm before RoPE: {q_norm_before:.4f}")
print(f"Norm after RoPE: {q_norm_after:.4f}")
print("Rotation preserves norm!" if abs(q_norm_before - q_norm_after) < 0.01 else "Norm changed!")
RoPE Advantages:
- Relative positions: Attention naturally captures relative position
- Length generalization: Can extrapolate to longer sequences
- No learned parameters: Purely algorithmic
- Preserves norms: Rotation doesn't change vector magnitude
Used in: LLaMA, GPT-NeoX, GPT-J, PaLM
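To make the "relative positions" point concrete, here is a quick sketch using the RotaryEmbedding class above (score_at is a hypothetical helper just for this check): it places the same query/key vectors at different absolute positions with the same offset and shows the rotated dot product is unchanged.
head_dim = 64
rope_check = RotaryEmbedding(dim=head_dim)
q_vec = torch.randn(head_dim)
k_vec = torch.randn(head_dim)
def score_at(m, n):
    """Dot product of q_vec at position m with k_vec at position n after RoPE."""
    L = max(m, n) + 1
    q = torch.zeros(1, 1, L, head_dim)
    k = torch.zeros(1, 1, L, head_dim)
    q[0, 0, m] = q_vec
    k[0, 0, n] = k_vec
    q_rot, k_rot = rope_check(q, k)
    return (q_rot[0, 0, m] @ k_rot[0, 0, n]).item()
print(f"Score at positions (2, 5):  {score_at(2, 5):.4f}")
print(f"Score at positions (9, 12): {score_at(9, 12):.4f}")
Both calls should print (approximately) the same value: only the offset m - n matters to the attention score.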
4. Complete LLaMA Architecture
Now let's combine everything into a full LLaMA model:
class LLaMAAttention(nn.Module):
"""Multi-head attention with RoPE and optional GQA."""
def __init__(self, dim, num_heads, num_kv_heads=None, max_seq_len=2048):
"""
Args:
dim: Model dimension
num_heads: Number of query heads
            num_kv_heads: Number of KV heads (for GQA). If None, use num_heads (MHA)
            max_seq_len: Maximum sequence length precomputed for the RoPE cache
"""
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_groups = num_heads // self.num_kv_heads
self.head_dim = dim // num_heads
# Q, K, V projections
self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)
# RoPE
self.rope = RotaryEmbedding(self.head_dim, max_seq_len=max_seq_len)
def forward(self, x, mask=None):
"""
Args:
x: (batch, seq_len, dim)
mask: Attention mask
Returns:
(batch, seq_len, dim)
"""
batch_size, seq_len, _ = x.shape
# Project Q, K, V
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# Transpose to (batch, heads, seq_len, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Apply RoPE
q, k = self.rope(q, k, seq_len=seq_len)
# Expand K, V for grouped-query attention
if self.num_groups > 1:
k = k.repeat_interleave(self.num_groups, dim=1)
v = v.repeat_interleave(self.num_groups, dim=1)
# Scaled dot-product attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
# Reshape and project
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = self.o_proj(output)
return output
class LLaMABlock(nn.Module):
"""Single LLaMA transformer block."""
    def __init__(self, dim, num_heads, num_kv_heads=None, hidden_dim=None, max_seq_len=2048):
"""
Args:
dim: Model dimension
num_heads: Number of attention heads
num_kv_heads: Number of KV heads (for GQA)
            hidden_dim: FFN hidden dimension (default: ~8/3 * dim, rounded to a multiple of 256)
            max_seq_len: Maximum sequence length for the RoPE cache
"""
super().__init__()
# Pre-normalization
self.attention_norm = RMSNorm(dim)
self.ffn_norm = RMSNorm(dim)
# Attention
        self.attention = LLaMAAttention(dim, num_heads, num_kv_heads, max_seq_len=max_seq_len)
# Feed-forward with SwiGLU
if hidden_dim is None:
# LLaMA uses 8/3 * dim for parameter efficiency
hidden_dim = int(8 * dim / 3)
# Round to nearest multiple of 256 for efficiency
hidden_dim = 256 * ((hidden_dim + 255) // 256)
self.feed_forward = SwiGLUEfficient(dim, hidden_dim)
def forward(self, x, mask=None):
"""
Args:
x: (batch, seq_len, dim)
mask: Attention mask
Returns:
(batch, seq_len, dim)
"""
# Pre-norm attention with residual
h = x + self.attention(self.attention_norm(x), mask)
# Pre-norm FFN with residual
out = h + self.feed_forward(self.ffn_norm(h))
return out
class LLaMA(nn.Module):
"""Complete LLaMA model."""
def __init__(
self,
vocab_size,
dim,
num_layers,
num_heads,
num_kv_heads=None,
max_seq_len=2048
):
"""
Args:
vocab_size: Vocabulary size
dim: Model dimension
num_layers: Number of transformer layers
num_heads: Number of attention heads
num_kv_heads: Number of KV heads (for GQA in LLaMA 2)
max_seq_len: Maximum sequence length
"""
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.num_layers = num_layers
# Token embeddings
self.tok_embeddings = nn.Embedding(vocab_size, dim)
# Transformer layers
self.layers = nn.ModuleList([
            LLaMABlock(dim, num_heads, num_kv_heads, max_seq_len=max_seq_len)
for _ in range(num_layers)
])
# Output normalization and projection
self.norm = RMSNorm(dim)
self.output = nn.Linear(dim, vocab_size, bias=False)
        # Note: unlike some GPT variants, LLaMA does not tie the output
        # projection to the token embedding; the two matrices are learned separately.
def forward(self, tokens, mask=None):
"""
Args:
tokens: (batch, seq_len) token indices
mask: Attention mask
Returns:
(batch, seq_len, vocab_size) logits
"""
# Embed tokens
h = self.tok_embeddings(tokens)
# Apply transformer layers
for layer in self.layers:
h = layer(h, mask)
# Final norm and projection
h = self.norm(h)
logits = self.output(h)
return logits
# Example: LLaMA-7B configuration
# (note: building a full 7B model in fp32 takes roughly 27 GB of RAM)
llama_7b = LLaMA(
vocab_size=32000,
dim=4096,
num_layers=32,
num_heads=32,
num_kv_heads=32, # LLaMA 1: same as num_heads (MHA)
max_seq_len=2048
)
# Example: LLaMA 2-style 7B configuration with GQA
# (LLaMA 2 enables GQA only in its 34B/70B models; the 7B/13B models keep MHA)
llama2_7b = LLaMA(
vocab_size=32000,
dim=4096,
num_layers=32,
num_heads=32,
    num_kv_heads=8,     # GQA with 8 KV heads (as used in LLaMA 2 70B)
max_seq_len=4096 # LLaMA 2: longer context
)
# Count parameters
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"LLaMA 7B parameters: {count_parameters(llama_7b) / 1e9:.2f}B")
print(f"LLaMA 2 7B parameters: {count_parameters(llama2_7b) / 1e9:.2f}B")
# Test forward pass with a causal (lower-triangular) attention mask
tokens = torch.randint(0, 32000, (2, 128))  # (batch=2, seq_len=128)
causal_mask = torch.tril(torch.ones(128, 128))  # 1 = attend, 0 = masked out
logits = llama2_7b(tokens, mask=causal_mask)
print(f"\nInput shape: {tokens.shape}")
print(f"Output shape: {logits.shape}")
LLaMA Model Sizes
# LLaMA model family configurations
llama_configs = {
'7B': {'dim': 4096, 'num_layers': 32, 'num_heads': 32},
'13B': {'dim': 5120, 'num_layers': 40, 'num_heads': 40},
'33B': {'dim': 6656, 'num_layers': 60, 'num_heads': 52},
'65B': {'dim': 8192, 'num_layers': 80, 'num_heads': 64},
}
# Build each model on the meta device (PyTorch 2.0+) so no real memory is
# allocated -- a 65B model would otherwise need hundreds of GB of RAM
for name, config in llama_configs.items():
    with torch.device('meta'):
        model = LLaMA(vocab_size=32000, **config)
    params = count_parameters(model) / 1e9
    print(f"LLaMA {name}: {params:.1f}B parameters")
Summary
LLaMA's architectural innovations:
- RMSNorm: Faster normalization without mean centering
- SwiGLU: Improved activation with gating mechanism
- RoPE: Relative position encoding through rotation
- GQA (larger LLaMA 2 models): Memory-efficient attention with shared KV heads
Together, these modifications make LLaMA more efficient than earlier GPT-style transformers at comparable scale, enabling strong performance from smaller models.