Multi-Head Attention: A Deep Dive

Explore the mathematical foundations and implementation details of multi-head attention, the core mechanism that powers modern transformer models.

25 min read · Attention · Transformer · Deep Learning · PyTorch

Multi-head attention is the beating heart of transformer architectures. While you've seen self-attention before, this lesson dives deep into the mathematical foundations and implementation details that make multi-head attention so powerful.

Why Multiple Heads?

Single-head attention computes one set of attention weights. Multi-head attention computes multiple sets in parallel, allowing the model to attend to different aspects of the input simultaneously.

The Intuition

Consider the sentence: "The bank by the river was flooded."

Different attention heads might focus on:

  • Head 1 (Syntax): "bank" → "was" (subject-verb relationship)
  • Head 2 (Semantics): "bank" → "river" (physical location meaning)
  • Head 3 (Context): "flooded" → "bank", "river" (water-related context)

Each head learns to attend to different linguistic phenomena.

Multiple Perspectives:

Multi-head attention is like having multiple experts analyzing the same text. One expert focuses on grammar, another on meaning, another on relationships. The final output combines all their insights.

Mathematical Foundations

Single-Head Attention Recap

The scaled dot-product attention formula:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Where:

  • Q (Query): What we're looking for (seq_len × d_k)
  • K (Key): What each position contains (seq_len × d_k)
  • V (Value): The actual information (seq_len × d_v)

Multi-Head Attention Formula

Instead of a single attention function over d_model-dimensional vectors, we run h attention heads in parallel, each operating on a d_k-dimensional projection:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O

where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

Projection matrices:

  • W_i^Q ∈ ℝ^(d_model × d_k)
  • W_i^K ∈ ℝ^(d_model × d_k)
  • W_i^V ∈ ℝ^(d_model × d_v)
  • W^O ∈ ℝ^(h·d_v × d_model)

Typically: d_k = d_v = d_model / h

Dimension Consistency:

The key constraint is that h × d_k = d_model. For example, with d_model=512 and h=8 heads, each head gets d_k=64 dimensions. This maintains the same total parameter count as single-head attention with full dimensions.
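
A two-line check makes this bookkeeping concrete (using the d_model=512, h=8 example from above):

python
d_model, h = 512, 8
d_k = d_model // h  # 64 dimensions per head
# Q/K/V projections: h heads × 3 matrices × (d_model × d_k) parameters each
print(h * 3 * d_model * d_k)   # 786,432
print(3 * d_model * d_model)   # 786,432 (same as 3 × d_model², single-head)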

Implementation from Scratch

Step 1: Scaled Dot-Product Attention

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Queries (batch, heads, seq_len, d_k)
        K: Keys (batch, heads, seq_len, d_k)
        V: Values (batch, heads, seq_len, d_v)
        mask: Optional mask (batch, 1, seq_len, seq_len)

    Returns:
        output: Weighted values (batch, heads, seq_len, d_v)
        attention_weights: Attention distribution (batch, heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Compute attention scores: Q·K^T / √d_k
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Shape: (batch, heads, seq_len_q, seq_len_k)

    # Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    # Shape: (batch, heads, seq_len_q, seq_len_k)

    # Apply attention to values
    output = torch.matmul(attention_weights, V)
    # Shape: (batch, heads, seq_len_q, d_v)

    return output, attention_weights


# Test the function
batch_size, num_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)
V = torch.randn(batch_size, num_heads, seq_len, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 8, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # (2, 8, 10, 10)
print(f"Weights sum to 1: {weights[0, 0, 0].sum():.4f}")  # Should be ~1.0

Step 2: Multi-Head Attention Module

python
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism with complete implementation.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Args:
            d_model: Model dimension (e.g., 512)
            num_heads: Number of attention heads (e.g., 8)
            dropout: Dropout rate applied to the attention output
        """
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Linear projections for Q, K, V (all heads combined)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # For visualization/analysis
        self.attention_weights = None

    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, d_k).

        Args:
            x: (batch_size, seq_len, d_model)

        Returns:
            x: (batch_size, num_heads, seq_len, d_k)
        """
        batch_size, seq_len, d_model = x.size()
        # Reshape: (batch, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        # Transpose: (batch, num_heads, seq_len, d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        """
        Inverse of split_heads: combine heads back into d_model.

        Args:
            x: (batch_size, num_heads, seq_len, d_k)

        Returns:
            x: (batch_size, seq_len, d_model)
        """
        batch_size, num_heads, seq_len, d_k = x.size()
        # Transpose: (batch, seq_len, num_heads, d_k)
        x = x.transpose(1, 2).contiguous()
        # Reshape: (batch, seq_len, d_model)
        return x.view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        """
        Forward pass of multi-head attention.

        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_v, d_model)
            mask: (batch_size, 1, seq_len_q, seq_len_k) or broadcastable

        Returns:
            output: (batch_size, seq_len_q, d_model)
        """
        batch_size = query.size(0)

        # 1. Linear projections
        Q = self.W_q(query)  # (batch, seq_len_q, d_model)
        K = self.W_k(key)    # (batch, seq_len_k, d_model)
        V = self.W_v(value)  # (batch, seq_len_v, d_model)

        # 2. Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len_q, d_k)
        K = self.split_heads(K)  # (batch, num_heads, seq_len_k, d_k)
        V = self.split_heads(V)  # (batch, num_heads, seq_len_v, d_k)

        # 3. Apply scaled dot-product attention
        attn_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask
        )
        # attn_output: (batch, num_heads, seq_len_q, d_k)
        # attention_weights: (batch, num_heads, seq_len_q, seq_len_k)

        # Store attention weights for visualization
        self.attention_weights = attention_weights.detach()

        # 4. Apply dropout (applied to the attended values here; many
        #    implementations instead drop out the attention weights)
        attn_output = self.dropout(attn_output)

        # 5. Combine heads
        attn_output = self.combine_heads(attn_output)
        # Shape: (batch, seq_len_q, d_model)

        # 6. Final linear projection
        output = self.W_o(attn_output)
        # Shape: (batch, seq_len_q, d_model)

        return output


# Example usage
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10

mha = MultiHeadAttention(d_model, num_heads)

# Self-attention: Q = K = V
x = torch.randn(batch_size, seq_len, d_model)
output = mha(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in mha.parameters()):,}")

Step 3: Visualizing Attention Patterns

python
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens, head_idx=0):
    """
    Visualize attention weights for a specific head.

    Args:
        attention_weights: (num_heads, seq_len, seq_len)
        tokens: List of token strings
        head_idx: Which attention head to visualize
    """
    # Get attention weights for specified head
    weights = attention_weights[head_idx].cpu().numpy()

    # Create heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='viridis',
        annot=True,
        fmt='.2f',
        cbar=True
    )
    plt.title(f'Attention Weights - Head {head_idx}')
    plt.xlabel('Keys (attending to)')
    plt.ylabel('Queries (from)')
    plt.tight_layout()
    plt.show()


# Example: Create attention visualization
tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat']
seq_len = len(tokens)

# Generate sample input
x = torch.randn(1, seq_len, d_model)

# Forward pass
mha = MultiHeadAttention(d_model, num_heads=8)
output = mha(x, x, x)

# Get attention weights from first batch item
attn = mha.attention_weights[0]  # Shape: (num_heads, seq_len, seq_len)

# Visualize different heads
visualize_attention(attn, tokens, head_idx=0)
visualize_attention(attn, tokens, head_idx=1)

Interpreting Attention Patterns:

  • Diagonal patterns: Tokens attending to themselves (quantified in the snippet after this list)
  • Vertical stripes: One token receiving high attention from many others (important word)
  • Horizontal stripes: One token attending to many others (gathering context)
  • Block patterns: Phrase-level attention (multi-word expressions)
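
To quantify the first of these patterns, here is a small diagnostic over the attn tensor from the example above (an illustrative heuristic, not a standard metric):

python
# Average attention mass each head places on the diagonal (self-attention)
diag_strength = attn.diagonal(dim1=-2, dim2=-1).mean(dim=-1)
print(diag_strength)  # One value per head; higher means a more "diagonal" head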

Advanced Concepts

1. Relative Positional Attention

Attention itself is position-agnostic; standard transformers inject absolute positions through positional encodings. Relative position attention instead scores each query-key pair based on the offset between their positions:

python
class RelativeMultiHeadAttention(nn.Module):
    """Multi-head attention with relative position encodings."""

    def __init__(self, d_model, num_heads, max_relative_position=128):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.max_relative_position = max_relative_position

        # Standard Q, K, V projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # Relative position embeddings
        self.relative_positions_embeddings = nn.Embedding(
            2 * max_relative_position + 1,
            self.d_k
        )

    def get_relative_positions(self, seq_len):
        """Compute relative position matrix."""
        range_vec = torch.arange(seq_len)
        range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)
        distance_mat = range_mat - range_mat.transpose(0, 1)

        # Clip to max relative position
        distance_mat_clipped = torch.clamp(
            distance_mat,
            -self.max_relative_position,
            self.max_relative_position
        )

        # Shift to positive indices
        final_mat = distance_mat_clipped + self.max_relative_position
        return final_mat

    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, _ = query.size()

        # Project and split into heads: (batch, heads, seq_len, d_k)
        Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Relative position embeddings: (seq_len, seq_len, d_k)
        rel_pos_indices = self.get_relative_positions(seq_len).to(query.device)
        rel_pos_embeddings = self.relative_positions_embeddings(rel_pos_indices)

        # Scores = content term + relative-position term
        # (one common formulation, following Shaw et al., 2018)
        content_scores = torch.matmul(Q, K.transpose(-2, -1))
        position_scores = torch.einsum('bhid,ijd->bhij', Q, rel_pos_embeddings)
        scores = (content_scores + position_scores) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads and project
        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output)
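
A quick shape check (the hyperparameters below are illustrative):

python
rel_mha = RelativeMultiHeadAttention(d_model=512, num_heads=8, max_relative_position=128)
x = torch.randn(2, 10, 512)
print(rel_mha(x, x, x).shape)  # torch.Size([2, 10, 512])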

2. Grouped Query Attention (GQA)

Used in modern models (e.g., the larger LLaMA 2 variants) to reduce KV cache size:

python
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention: Multiple query heads share K/V heads.
    Used in LLaMA 2 and other modern models for efficiency.
    """

    def __init__(self, d_model, num_query_heads, num_kv_heads, dropout=0.1):
        super().__init__()

        assert num_query_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_query_heads // num_kv_heads
        self.d_k = d_model // num_query_heads

        # Q has full heads, K/V have fewer heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # Project Q, K, V
        Q = self.W_q(x).view(batch_size, seq_len, self.num_query_heads, self.d_k)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k)

        # Repeat K and V to match number of query heads
        K = K.repeat_interleave(self.num_groups, dim=2)
        V = V.repeat_interleave(self.num_groups, dim=2)

        # Transpose for attention: (batch, heads, seq_len, d_k)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Apply attention
        output, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads and project
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output


# Example: 32 query heads sharing 8 KV heads (4 query heads per group)
gqa = GroupedQueryAttention(d_model=4096, num_query_heads=32, num_kv_heads=8)
x = torch.randn(2, 10, 4096)
output = gqa(x)
print(f"GQA output shape: {output.shape}")

Why Grouped Query Attention?

GQA shrinks the KV cache by 4x in this example (8 KV heads instead of 32). During inference with large batch sizes or long contexts, the KV cache becomes a major memory bottleneck, so this reduction directly buys larger batches or longer contexts while maintaining quality.
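
A back-of-the-envelope estimate makes the savings concrete. The layer count, batch size, and context length below are illustrative, not taken from any specific model:

python
def kv_cache_bytes(batch, seq_len, num_layers, num_kv_heads, d_k, bytes_per_elem=2):
    """Approximate KV cache size: one K and one V tensor per layer, fp16."""
    return 2 * batch * seq_len * num_layers * num_kv_heads * d_k * bytes_per_elem

# 32 transformer layers, batch 8, 4k context, d_k = 128
mha_cache = kv_cache_bytes(8, 4096, 32, num_kv_heads=32, d_k=128)
gqa_cache = kv_cache_bytes(8, 4096, 32, num_kv_heads=8, d_k=128)
print(f"MHA KV cache: {mha_cache / 1e9:.1f} GB")  # 17.2 GB
print(f"GQA KV cache: {gqa_cache / 1e9:.1f} GB")  # 4.3 GB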

3. Flash Attention Integration

Modern efficient attention computation:

python
def efficient_attention(Q, K, V, mask=None, use_flash=True):
    """
    Efficient attention using Flash Attention when available.
    Falls back to the standard implementation otherwise.

    Expects Q, K, V as (batch, heads, seq_len, d_k).
    """
    # flash-attn supports fp16/bf16 CUDA tensors and takes a `causal` flag
    # rather than an arbitrary mask, so only use it when no mask is needed
    if use_flash and mask is None and Q.is_cuda and Q.dtype in (torch.float16, torch.bfloat16):
        try:
            from flash_attn import flash_attn_func
            # flash_attn_func expects (batch, seq_len, num_heads, d_k)
            output = flash_attn_func(
                Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
            )
            return output.transpose(1, 2), None
        except ImportError:
            pass

    # Standard attention fallback
    return scaled_dot_product_attention(Q, K, V, mask)
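
On CPU, or when flash-attn is not installed, the call transparently takes the fallback path (reusing the Q, K, V tensors from Step 1):

python
out, _ = efficient_attention(Q, K, V)  # falls back to the pure-PyTorch path here
print(out.shape)  # torch.Size([2, 8, 10, 64])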

Key Insights

Why Multi-Head Works

1. Representation Subspaces: Each head learns different aspects in a lower-dimensional subspace (d_k < d_model).

2. Ensemble Effect: Multiple heads provide multiple "views" that are combined, similar to ensemble learning.

3. Computational Efficiency: h heads of dimension d_k = d_model/h have the same cost as 1 head of dimension d_model, but provide richer representations.

Parameter Count Analysis

python
def count_mha_parameters(d_model, num_heads):
    """Calculate parameter count for multi-head attention."""
    # W_q, W_k, W_v: each is d_model × d_model
    qkv_params = 3 * (d_model * d_model)

    # W_o: d_model × d_model
    output_params = d_model * d_model

    total = qkv_params + output_params
    return total

# GPT-3 small configuration
params = count_mha_parameters(d_model=768, num_heads=12)
print(f"MHA parameters (d=768, h=12): {params:,}")  # 2,359,296

# GPT-3 (175B) configuration
params = count_mha_parameters(d_model=12288, num_heads=96)
print(f"MHA parameters (d=12288, h=96): {params:,}")  # 603,979,776

Summary

Multi-head attention is the core innovation enabling transformer models:

  1. Parallel Attention: Multiple heads attend to different representation subspaces simultaneously
  2. Rich Representations: Each head can specialize in different patterns (syntax, semantics, position)
  3. Efficient Architecture: Same computational cost as single-head with full dimensions
  4. Scalable Design: Works across model sizes from 6 layers to 96+ layers

Modern variants (Grouped Query Attention, Flash Attention) build on this foundation to improve efficiency while maintaining the core multi-head design.