Multi-Head Attention: A Deep Dive
Multi-head attention is the beating heart of transformer architectures. While you've seen self-attention before, this lesson dives deep into the mathematical foundations and implementation details that make multi-head attention so powerful.
Why Multiple Heads?
Single-head attention computes one set of attention weights. Multi-head attention computes multiple sets in parallel, allowing the model to attend to different aspects of the input simultaneously.
The Intuition
Consider the sentence: "The bank by the river was flooded."
Different attention heads might focus on:
- Head 1 (Syntax): "bank" → "was" (subject-verb relationship)
- Head 2 (Semantics): "bank" → "river" (physical location meaning)
- Head 3 (Context): "flooded" → "bank", "river" (water-related context)
Each head learns to attend to different linguistic phenomena.
Multiple Perspectives:
Multi-head attention is like having multiple experts analyzing the same text. One expert focuses on grammar, another on meaning, another on relationships. The final output combines all their insights.
Mathematical Foundations
Single-Head Attention Recap
The scaled dot-product attention formula:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Where:
- Q (Query): What we're looking for (seq_len × d_k)
- K (Key): What each position contains (seq_len × d_k)
- V (Value): The actual information (seq_len × d_v)
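To make the shapes concrete, here is a minimal sketch that evaluates the formula for a toy two-token sequence (all numbers are random and purely illustrative):
import torch
import torch.nn.functional as F
import math

# Toy single-head example: seq_len = 2, d_k = d_v = 4
Q = torch.randn(2, 4)                   # one query row per position
K = torch.randn(2, 4)                   # one key row per position
V = torch.randn(2, 4)                   # one value row per position

scores = Q @ K.T / math.sqrt(4)         # (2, 2) similarity matrix, scaled by sqrt(d_k)
weights = F.softmax(scores, dim=-1)     # each row is a probability distribution
output = weights @ V                    # (2, 4) weighted mixture of values
print(weights.sum(dim=-1))              # tensor([1., 1.])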
Multi-Head Attention Formula
Instead of one attention function over the full d_model dimensions, we run h attention heads in parallel, each operating on d_k-dimensional projections:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
Projection matrices:
- W_i^Q ∈ ℝ^(d_model × d_k)
- W_i^K ∈ ℝ^(d_model × d_k)
- W_i^V ∈ ℝ^(d_model × d_v)
- W^O ∈ ℝ^(h·d_v × d_model)
Typically: d_k = d_v = d_model / h
Dimension Consistency:
The key constraint is that h × d_k = d_model. For example, with d_model=512 and h=8 heads, each head gets d_k=64 dimensions. This maintains the same total parameter count as single-head attention with full dimensions.
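A quick sanity check of that arithmetic, using the example numbers above:
d_model, h = 512, 8
d_k = d_model // h
assert h * d_k == d_model            # each head gets a 64-dimensional slice
print(d_k)                           # 64

# The four projections (W_q, W_k, W_v, W_o) together hold 4 * d_model^2 weights,
# regardless of how many heads the d_model dimensions are split into.
print(4 * d_model * d_model)         # 1048576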
Implementation from Scratch
Step 1: Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Compute scaled dot-product attention.
Args:
Q: Queries (batch, heads, seq_len, d_k)
K: Keys (batch, heads, seq_len, d_k)
V: Values (batch, heads, seq_len, d_v)
mask: Optional mask (batch, 1, seq_len, seq_len)
Returns:
output: Weighted values (batch, heads, seq_len, d_v)
attention_weights: Attention distribution (batch, heads, seq_len, seq_len)
"""
d_k = Q.size(-1)
# Compute attention scores: Q·K^T / √d_k
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# Shape: (batch, heads, seq_len_q, seq_len_k)
# Apply mask (if provided)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Apply softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
# Shape: (batch, heads, seq_len_q, seq_len_k)
# Apply attention to values
output = torch.matmul(attention_weights, V)
# Shape: (batch, heads, seq_len_q, d_v)
return output, attention_weights
# Test the function
batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)
V = torch.randn(batch_size, num_heads, seq_len, d_k)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}") # (2, 8, 10, 64)
print(f"Attention weights shape: {weights.shape}") # (2, 8, 10, 10)
print(f"Weights sum to 1: {weights[0, 0, 0].sum():.4f}") # Should be ~1.0
Step 2: Multi-Head Attention Module
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention mechanism with complete implementation.
"""
def __init__(self, d_model, num_heads, dropout=0.1):
"""
Args:
d_model: Model dimension (e.g., 512)
num_heads: Number of attention heads (e.g., 8)
            dropout: Dropout rate applied to the attention output
"""
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
# Linear projections for Q, K, V (all heads combined)
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
# Output projection
self.W_o = nn.Linear(d_model, d_model, bias=False)
# Dropout
self.dropout = nn.Dropout(dropout)
# For visualization/analysis
self.attention_weights = None
def split_heads(self, x):
"""
Split the last dimension into (num_heads, d_k).
Args:
x: (batch_size, seq_len, d_model)
Returns:
x: (batch_size, num_heads, seq_len, d_k)
"""
batch_size, seq_len, d_model = x.size()
# Reshape: (batch, seq_len, num_heads, d_k)
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
# Transpose: (batch, num_heads, seq_len, d_k)
return x.transpose(1, 2)
def combine_heads(self, x):
"""
Inverse of split_heads: combine heads back into d_model.
Args:
x: (batch_size, num_heads, seq_len, d_k)
Returns:
x: (batch_size, seq_len, d_model)
"""
batch_size, num_heads, seq_len, d_k = x.size()
# Transpose: (batch, seq_len, num_heads, d_k)
x = x.transpose(1, 2).contiguous()
# Reshape: (batch, seq_len, d_model)
return x.view(batch_size, seq_len, self.d_model)
def forward(self, query, key, value, mask=None):
"""
Forward pass of multi-head attention.
Args:
query: (batch_size, seq_len_q, d_model)
key: (batch_size, seq_len_k, d_model)
value: (batch_size, seq_len_v, d_model)
mask: (batch_size, 1, seq_len_q, seq_len_k) or broadcastable
Returns:
output: (batch_size, seq_len_q, d_model)
"""
batch_size = query.size(0)
# 1. Linear projections
Q = self.W_q(query) # (batch, seq_len_q, d_model)
K = self.W_k(key) # (batch, seq_len_k, d_model)
V = self.W_v(value) # (batch, seq_len_v, d_model)
# 2. Split into multiple heads
Q = self.split_heads(Q) # (batch, num_heads, seq_len_q, d_k)
K = self.split_heads(K) # (batch, num_heads, seq_len_k, d_k)
V = self.split_heads(V) # (batch, num_heads, seq_len_v, d_k)
# 3. Apply scaled dot-product attention
attn_output, attention_weights = scaled_dot_product_attention(
Q, K, V, mask
)
# attn_output: (batch, num_heads, seq_len_q, d_k)
# attention_weights: (batch, num_heads, seq_len_q, seq_len_k)
# Store attention weights for visualization
self.attention_weights = attention_weights.detach()
        # 4. Apply dropout to the attention output
attn_output = self.dropout(attn_output)
# 5. Combine heads
attn_output = self.combine_heads(attn_output)
# Shape: (batch, seq_len_q, d_model)
# 6. Final linear projection
output = self.W_o(attn_output)
# Shape: (batch, seq_len_q, d_model)
return output
# Example usage
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10
mha = MultiHeadAttention(d_model, num_heads)
# Self-attention: Q = K = V
x = torch.randn(batch_size, seq_len, d_model)
output = mha(x, x, x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in mha.parameters()):,}")
Step 3: Visualizing Attention Patterns
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attention_weights, tokens, head_idx=0):
"""
Visualize attention weights for a specific head.
Args:
attention_weights: (num_heads, seq_len, seq_len)
tokens: List of token strings
head_idx: Which attention head to visualize
"""
# Get attention weights for specified head
weights = attention_weights[head_idx].cpu().numpy()
# Create heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(
weights,
xticklabels=tokens,
yticklabels=tokens,
cmap='viridis',
annot=True,
fmt='.2f',
cbar=True
)
plt.title(f'Attention Weights - Head {head_idx}')
plt.xlabel('Keys (attending to)')
plt.ylabel('Queries (from)')
plt.tight_layout()
plt.show()
# Example: Create attention visualization
tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat']
seq_len = len(tokens)
# Generate sample input
x = torch.randn(1, seq_len, d_model)
# Forward pass
mha = MultiHeadAttention(d_model, num_heads=8)
output = mha(x, x, x)
# Get attention weights from first batch item
attn = mha.attention_weights[0] # Shape: (num_heads, seq_len, seq_len)
# Visualize different heads
visualize_attention(attn, tokens, head_idx=0)
visualize_attention(attn, tokens, head_idx=1)
Interpreting Attention Patterns:
- Diagonal patterns: Tokens attending to themselves
- Vertical stripes: One token receiving high attention from many others (important word)
- Horizontal stripes: One token attending to many others (gathering context)
- Block patterns: Phrase-level attention (multi-word expressions)
Advanced Concepts
1. Relative Positional Attention
Standard transformers encode absolute positions via positional encodings added to the input embeddings; relative position attention instead encodes the offset between query and key positions directly in the attention computation:
class RelativeMultiHeadAttention(nn.Module):
"""Multi-head attention with relative position encodings."""
def __init__(self, d_model, num_heads, max_relative_position=128):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.max_relative_position = max_relative_position
# Standard Q, K, V projections
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
# Relative position embeddings
self.relative_positions_embeddings = nn.Embedding(
2 * max_relative_position + 1,
self.d_k
)
def get_relative_positions(self, seq_len):
"""Compute relative position matrix."""
range_vec = torch.arange(seq_len)
range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)
# Clip to max relative position
distance_mat_clipped = torch.clamp(
distance_mat,
-self.max_relative_position,
self.max_relative_position
)
# Shift to positive indices
final_mat = distance_mat_clipped + self.max_relative_position
return final_mat
def forward(self, query, key, value, mask=None):
batch_size, seq_len, _ = query.size()
# Get relative position embeddings
rel_pos_indices = self.get_relative_positions(seq_len).to(query.device)
rel_pos_embeddings = self.relative_positions_embeddings(rel_pos_indices)
        # Standard attention computation with relative positions
        # (full implementation omitted for brevity; see the sketch below)
        # The relative embeddings add a position-dependent term to the
        # attention scores before the softmax.
        raise NotImplementedError("Score computation is sketched below")
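One way the omitted score computation can look, in the style of Shaw et al. (2018) relative position representations. Treat this as an illustrative sketch rather than the definitive formulation; the tensor names mirror the class above:
def relative_attention_scores(Q, K, rel_pos_embeddings, d_k):
    """
    Q, K:               (batch, num_heads, seq_len, d_k)
    rel_pos_embeddings: (seq_len, seq_len, d_k) -- embedding of the clipped
                        offset between each query position q and key position k
    Returns attention scores of shape (batch, num_heads, seq_len, seq_len).
    """
    # Content-content term, exactly as in standard attention
    content_scores = torch.matmul(Q, K.transpose(-2, -1))
    # Content-position term: each query also scores against the embedding of
    # its relative offset to every key position
    rel_scores = torch.einsum('bhqd,qkd->bhqk', Q, rel_pos_embeddings)
    return (content_scores + rel_scores) / math.sqrt(d_k)
The softmax and the weighted sum over values then proceed exactly as in scaled_dot_product_attention.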
2. Grouped Query Attention (GQA)
Used in modern models such as the larger LLaMA 2 variants and Mistral to reduce the KV cache size:
class GroupedQueryAttention(nn.Module):
"""
Grouped Query Attention: Multiple query heads share K/V heads.
    Used in models such as LLaMA 2 70B and Mistral for efficiency.
"""
def __init__(self, d_model, num_query_heads, num_kv_heads, dropout=0.1):
super().__init__()
assert num_query_heads % num_kv_heads == 0
self.d_model = d_model
self.num_query_heads = num_query_heads
self.num_kv_heads = num_kv_heads
self.num_groups = num_query_heads // num_kv_heads
self.d_k = d_model // num_query_heads
# Q has full heads, K/V have fewer heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.size()
# Project Q, K, V
Q = self.W_q(x).view(batch_size, seq_len, self.num_query_heads, self.d_k)
K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)
V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)
# Repeat K and V to match number of query heads
K = K.repeat_interleave(self.num_groups, dim=2)
V = V.repeat_interleave(self.num_groups, dim=2)
# Transpose for attention: (batch, heads, seq_len, d_k)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Apply attention
output, _ = scaled_dot_product_attention(Q, K, V, mask)
# Combine heads and project
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, self.d_model)
output = self.W_o(output)
return output
# Example: a Mistral-7B-style configuration -- 32 query heads sharing 8 KV heads
# (LLaMA 2 70B uses the same 8 KV heads, paired with 64 query heads at d_model=8192)
gqa = GroupedQueryAttention(d_model=4096, num_query_heads=32, num_kv_heads=8)
x = torch.randn(2, 10, 4096)
output = gqa(x)
print(f"GQA output shape: {output.shape}")
Why Grouped Query Attention?
GQA reduces the size of the KV cache by 4x in this example (8 KV heads instead of 32). During inference with large batch sizes or long contexts, the KV cache becomes a major memory bottleneck; GQA maintains quality while dramatically reducing that memory requirement.
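A back-of-the-envelope sketch of that saving, per token and per layer, using the example configuration above (fp16 storage assumed; the numbers are illustrative):
d_model = 4096
num_query_heads, num_kv_heads = 32, 8
d_k = d_model // num_query_heads     # 128
bytes_per_element = 2                # fp16

# The KV cache stores one key and one value vector per head, per token, per layer
mha_kv_bytes = 2 * num_query_heads * d_k * bytes_per_element   # 32 KV heads: 16384 bytes
gqa_kv_bytes = 2 * num_kv_heads * d_k * bytes_per_element      #  8 KV heads:  4096 bytes
print(mha_kv_bytes // gqa_kv_bytes)  # 4x smaller cache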
3. Flash Attention Integration
Modern efficient attention computation:
def efficient_attention(Q, K, V, mask=None, use_flash=True):
    """
    Efficient attention using Flash Attention when available.
    Falls back to the standard implementation otherwise.

    Note: flash_attn_func expects (batch, seq_len, num_heads, d_k) tensors in
    fp16/bf16 -- the transpose of the (batch, num_heads, seq_len, d_k) layout
    used elsewhere in this lesson -- and it does not accept arbitrary masks,
    so this wrapper simply treats any provided mask as a causal mask.
    """
    if use_flash:
        try:
            from flash_attn import flash_attn_func
            output = flash_attn_func(Q, K, V, causal=mask is not None)
            return output, None  # Flash Attention does not return the attention weights
        except ImportError:
            pass
    # Standard attention fallback
    return scaled_dot_product_attention(Q, K, V, mask)
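For comparison, recent PyTorch releases (2.0+) ship a fused scaled_dot_product_attention that dispatches to Flash-Attention-style kernels when the hardware and dtypes allow it. A minimal sketch, assuming PyTorch 2.0+ is available:
import torch
import torch.nn.functional as F

Q = torch.randn(2, 8, 10, 64)   # (batch, num_heads, seq_len, d_k), the layout used in this lesson
K = torch.randn(2, 8, 10, 64)
V = torch.randn(2, 8, 10, 64)

# Causal attention without materializing an explicit mask
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 10, 64])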
Key Insights
Why Multi-Head Works
1. Representation Subspaces: Each head learns different aspects in a lower-dimensional subspace (d_k < d_model).
2. Ensemble Effect: Multiple heads provide multiple "views" that are combined, similar to ensemble learning.
3. Computational Efficiency: h heads of dimension d_k = d_model/h have the same cost as 1 head of dimension d_model, but provide richer representations.
Parameter Count Analysis
def count_mha_parameters(d_model, num_heads):
"""Calculate parameter count for multi-head attention."""
# W_q, W_k, W_v: each is d_model × d_model
qkv_params = 3 * (d_model * d_model)
# W_o: d_model × d_model
output_params = d_model * d_model
total = qkv_params + output_params
return total
# GPT-3 small configuration
params = count_mha_parameters(d_model=768, num_heads=12)
print(f"MHA parameters (d=768, h=12): {params:,}") # 2,359,296
# GPT-3 175B configuration
params = count_mha_parameters(d_model=12288, num_heads=96)
print(f"MHA parameters (d=12288, h=96): {params:,}") # 603,979,776
Summary
Multi-head attention is the core innovation enabling transformer models:
- Parallel Attention: Multiple heads attend to different representation subspaces simultaneously
- Rich Representations: Each head can specialize in different patterns (syntax, semantics, position)
- Efficient Architecture: Same computational cost as single-head with full dimensions
- Scalable Design: Works across model sizes from 6 layers to 96+ layers
Modern variants (Grouped Query Attention, Flash Attention) build on this foundation to improve efficiency while maintaining the core multi-head design.