Foundation of Transformers

Understanding Self-Attention

Deep dive into the self-attention mechanism that powers transformers. Learn how Query, Key, and Value matrices work together to create context-aware representations.


Self-attention is the core innovation that makes transformers powerful. Unlike RNNs that process sequences sequentially, self-attention allows each position to attend to all positions in a single operation. This lesson provides a technical deep-dive into how it works.

The Attention Intuition

Consider translating: "The bank of the river was flooded."

The word "bank" is ambiguous:

  • Financial institution?
  • Side of a river?

Humans use context ("river", "flooded") to disambiguate. Self-attention does the same computationally.

How Self-Attention Resolves Ambiguity

Context weights for "bank":
- "The"    : 0.05
- "bank"   : 0.10
- "of"     : 0.05
- "the"    : 0.03
- "river"  : 0.45  ← High attention!
- "was"    : 0.02
- "flooded": 0.30  ← High attention!

The model learns that "river" and "flooded" provide crucial context for understanding "bank".
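To make the weighted-sum idea concrete, here is a minimal sketch using the (hypothetical) weights from the table above and random stand-in value vectors:

```python
import torch

# Hypothetical attention weights for "bank" (from the table above)
weights = torch.tensor([0.05, 0.10, 0.05, 0.03, 0.45, 0.02, 0.30])

# Toy value vectors for the 7 tokens (random stand-ins for learned values)
torch.manual_seed(0)
values = torch.randn(7, 4)

# The contextual representation of "bank" is a weighted sum of all values
bank_repr = weights @ values  # shape: (4,)

print(weights.sum())    # the weights form a distribution: 1.0
print(bank_repr.shape)  # torch.Size([4])
```

Because "river" and "flooded" carry most of the weight, their value vectors dominate the resulting representation of "bank".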

The Query-Key-Value Framework

Query-Key-Value (QKV) Framework: A retrieval-inspired mechanism where Query matrices represent "what to look for," Key matrices represent "what is offered," and Value matrices contain "the actual information to retrieve." Attention compares queries with keys to weight values.

Self-attention uses three learned transformations of the input: Query (Q), Key (K), and Value (V).

Think of attention like searching a library:

  1. Query (Q): Your search request ("I need information about transformers")
  2. Keys (K): Book titles/indexes that describe what each book contains
  3. Values (V): The actual content of the books

The attention mechanism:

  • Compares your query against all keys (which books match?)
  • Assigns scores (relevance scores)
  • Retrieves and combines values (actual information) based on scores

Mathematical Formulation

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Let's break this down step by step.

Step-by-Step Attention Computation

Setup: Input Embeddings

python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Example: Simple sentence with 3 words
# "The cat sat"
seq_len = 3
d_model = 4  # Small for illustration

# Random embeddings (in practice, these come from an embedding layer)
X = torch.randn(seq_len, d_model)
print("Input embeddings shape:", X.shape)  # (3, 4)
print("\nInput embeddings:")
print(X)

Step 1: Create Q, K, V Matrices

python
# Learned weight matrices
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Apply transformations
Q = W_q(X)  # Query: "What am I looking for?"
K = W_k(X)  # Key: "What do I contain?"
V = W_v(X)  # Value: "What information do I output?"

print("Q shape:", Q.shape)  # (3, 4)
print("K shape:", K.shape)  # (3, 4)
print("V shape:", V.shape)  # (3, 4)

Why three separate matrices?

Different transformations allow the model to separate:

  • What each position is looking for (Q)
  • What each position offers (K)
  • What each position contains (V)

This separation provides more expressive power than using the same matrix for all three.

Step 2: Compute Attention Scores

Attention Scores: Numerical values computed by taking the dot product of query and key vectors, indicating how much each position should "attend to" (focus on) every other position in the sequence.

python
# Dot product between queries and keys
# Q: (seq_len, d_model) = (3, 4)
# K^T: (d_model, seq_len) = (4, 3)
# Scores: (seq_len, seq_len) = (3, 3)

scores = torch.matmul(Q, K.transpose(-2, -1))
print("Raw scores shape:", scores.shape)  # (3, 3)
print("\nRaw scores:")
print(scores)

The score matrix tells us:

scores[i, j] = similarity between query_i and key_j

Interpretation:

  • scores[0, 0]: How much does "The" attend to "The"?
  • scores[0, 1]: How much does "The" attend to "cat"?
  • scores[0, 2]: How much does "The" attend to "sat"?

Step 3: Scale Scores

python
d_k = Q.size(-1)  # Dimension of queries/keys
scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

print("Scaled scores:")
print(scaled_scores)

Why scaling matters:

Without scaling, dot products grow large as d_k increases. This pushes softmax into regions where gradients are tiny, making training unstable.

Example with d_k=64:

  • Dot product might be ~25 (large)
  • After softmax: nearly all weight on one position
  • Gradients: ~0 (vanishing)

Scaling by √d_k keeps values reasonable.
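A quick sketch of the saturation effect (the exact numbers are illustrative; multiplying by √64 = 8 mimics the unscaled regime for d_k = 64):

```python
import torch
import torch.nn.functional as F

# Unscaled dot products grow with d_k; compare softmax behavior
small = torch.tensor([1.0, 2.0, 3.0])  # scaled-scores regime
large = small * 8.0                    # ~unscaled regime for d_k = 64 (sqrt(64) = 8)

print(F.softmax(small, dim=-1))  # ≈ tensor([0.0900, 0.2447, 0.6652])
print(F.softmax(large, dim=-1))  # nearly one-hot: almost all mass on the last entry
```

The unscaled distribution is effectively a hard argmax, which starves the other positions of gradient signal.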

Step 4: Apply Softmax

python
attention_weights = F.softmax(scaled_scores, dim=-1)
print("Attention weights (sum to 1 per row):")
print(attention_weights)
print("\nRow sums (should be 1.0):")
print(attention_weights.sum(dim=-1))

Softmax ensures:

  • All weights are positive (0 to 1)
  • Weights for each query sum to 1 (probability distribution)

Step 5: Weighted Sum of Values

python
output = torch.matmul(attention_weights, V)
print("Output shape:", output.shape)  # (3, 4)
print("\nOutput:")
print(output)

Each output position is a weighted combination of all value vectors:

output[0] = attention_weights[0,0] * V[0]
          + attention_weights[0,1] * V[1]
          + attention_weights[0,2] * V[2]
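This decomposition can be verified directly. The sketch below re-creates a small self-contained setup rather than reusing the earlier variables:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 3, 4
Q = torch.randn(seq_len, d_model)
K = torch.randn(seq_len, d_model)
V = torch.randn(seq_len, d_model)

weights = F.softmax(Q @ K.T / d_model ** 0.5, dim=-1)
output = weights @ V

# output[0] really is the weighted sum of all value vectors
manual = weights[0, 0] * V[0] + weights[0, 1] * V[1] + weights[0, 2] * V[2]
print(torch.allclose(output[0], manual))  # True
```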

Complete Self-Attention Implementation

python
def self_attention(X, W_q, W_k, W_v, mask=None):
    """
    Complete self-attention mechanism

    Args:
        X: Input embeddings (seq_len, d_model)
        W_q, W_k, W_v: Query, Key, Value weight matrices
        mask: Optional attention mask

    Returns:
        output: Context-aware representations (seq_len, d_model)
        attention_weights: Attention distribution (seq_len, seq_len)
    """
    # 1. Compute Q, K, V
    Q = W_q(X)
    K = W_k(X)
    V = W_v(X)

    # 2. Compute scaled scores
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

    # 3. Apply mask (optional)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # 4. Apply softmax
    attention_weights = F.softmax(scores, dim=-1)

    # 5. Weighted sum of values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights


# Test the implementation
seq_len = 5
d_model = 8

X = torch.randn(seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

output, weights = self_attention(X, W_q, W_k, W_v)

print("Input shape:", X.shape)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)

Batch Processing

In practice, we process multiple sequences simultaneously:

python
def batched_self_attention(Q, K, V, mask=None):
    """
    Batched self-attention

    Args:
        Q: (batch_size, seq_len, d_k)
        K: (batch_size, seq_len, d_k)
        V: (batch_size, seq_len, d_v)
        mask: (batch_size, seq_len, seq_len) or broadcastable

    Returns:
        output: (batch_size, seq_len, d_v)
        attention_weights: (batch_size, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Compute attention scores
    # Q @ K^T: (batch, seq_len, d_k) @ (batch, d_k, seq_len) -> (batch, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

    # Apply mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Softmax
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted sum
    output = torch.matmul(attention_weights, V)

    return output, attention_weights


# Example with batches
batch_size = 2
seq_len = 4
d_model = 6

Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = batched_self_attention(Q, K, V)

print("Batch size:", batch_size)
print("Sequence length:", seq_len)
print("Output shape:", output.shape)      # (2, 4, 6)
print("Weights shape:", weights.shape)    # (2, 4, 4)

Attention Masks

Attention Mask: A binary or numerical matrix that selectively prevents certain positions from attending to others, typically used to ignore padding tokens or enforce causal ordering in decoders.

Masks control which positions can attend to which others.

Padding Mask

Ignore padding tokens in variable-length sequences:

python
def create_padding_mask(seq, pad_token_id=0):
    """
    Create mask for padding tokens

    Args:
        seq: (batch_size, seq_len) token IDs
        pad_token_id: ID used for padding

    Returns:
        mask: (batch_size, 1, seq_len) - True for real tokens, False for padding
    """
    mask = (seq != pad_token_id).unsqueeze(1)
    return mask


# Example
batch_size = 2
seq_len = 5

# Sequence 1: [1, 2, 3, 0, 0] (padding: 0)
# Sequence 2: [4, 5, 6, 7, 0]
sequences = torch.tensor([
    [1, 2, 3, 0, 0],
    [4, 5, 6, 7, 0]
])

mask = create_padding_mask(sequences)
print("Padding mask shape:", mask.shape)  # (2, 1, 5)
print("\nPadding mask:")
print(mask)

Causal (Look-Ahead) Mask

Prevent attending to future positions (for autoregressive models):

python
def create_causal_mask(seq_len):
    """
    Create causal mask to prevent attending to future positions

    Args:
        seq_len: Sequence length

    Returns:
        mask: (seq_len, seq_len) lower triangular matrix
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask


# Example
seq_len = 5
causal_mask = create_causal_mask(seq_len)

print("Causal mask:")
print(causal_mask)
print("\nInterpretation:")
print("1 = can attend, 0 = cannot attend")
print("Position 0 can only attend to position 0")
print("Position 2 can attend to positions 0, 1, 2")

When to use each mask:

  • Padding mask: Always use when you have variable-length sequences
  • Causal mask: Use in decoder or autoregressive generation
  • Combined: Often both masks are combined (AND operation)
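A minimal sketch of combining both masks (assuming boolean masks shaped as above; True = can attend). Broadcasting expands the two masks to a common (batch, seq_len, seq_len) shape:

```python
import torch

def combine_masks(padding_mask, causal_mask):
    """Combine padding and causal masks with a logical AND.

    Args:
        padding_mask: (batch, 1, seq_len) bool
        causal_mask:  (seq_len, seq_len) bool

    Returns:
        (batch, seq_len, seq_len) bool - True where attention is allowed
    """
    return padding_mask & causal_mask  # broadcasts to (batch, seq_len, seq_len)

sequences = torch.tensor([[1, 2, 3, 0, 0],
                          [4, 5, 6, 7, 0]])
padding_mask = (sequences != 0).unsqueeze(1)       # (2, 1, 5)
causal_mask = torch.tril(torch.ones(5, 5)).bool()  # (5, 5)

combined = combine_masks(padding_mask, causal_mask)
print(combined.shape)  # torch.Size([2, 5, 5])
```

A position is attendable only if it is both a real (non-padding) token and not in the future.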

Visualizing Attention Patterns

python
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens):
    """
    Visualize attention weights as heatmap

    Args:
        attention_weights: (seq_len, seq_len) attention matrix
        tokens: List of token strings
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights.detach().numpy(),
        annot=True,
        fmt='.2f',
        cmap='Blues',
        xticklabels=tokens,
        yticklabels=tokens,
        cbar_kws={'label': 'Attention Weight'}
    )
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.title('Self-Attention Weights')
    plt.tight_layout()
    plt.show()


# Example
tokens = ["The", "cat", "sat", "on", "mat"]
seq_len = len(tokens)
d_model = 8

# Create dummy data
X = torch.randn(seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

output, weights = self_attention(X, W_q, W_k, W_v)

# Visualize
visualize_attention(weights, tokens)

Self-Attention vs Cross-Attention

Cross-Attention: An attention mechanism where queries come from one sequence while keys and values come from a different sequence, enabling interaction between two sequences (e.g., decoder attending to encoder in translation).

Self-Attention

All Q, K, V come from the same sequence:

python
# Self-attention: X -> Q, K, V
Q = W_q(X)
K = W_k(X)
V = W_v(X)

Use case: Understanding relationships within a single sequence

Cross-Attention

Q comes from one sequence, K and V from another:

python
# Cross-attention: decoder attends to encoder
Q = W_q(decoder_input)   # From decoder
K = W_k(encoder_output)  # From encoder
V = W_v(encoder_output)  # From encoder

Use case: Translation, where decoder attends to source sentence

python
def cross_attention(decoder_input, encoder_output, W_q, W_k, W_v):
    """
    Cross-attention mechanism

    Args:
        decoder_input: (batch, tgt_len, d_model)
        encoder_output: (batch, src_len, d_model)
        W_q, W_k, W_v: Weight matrices

    Returns:
        output: (batch, tgt_len, d_model)
    """
    Q = W_q(decoder_input)      # Query from decoder
    K = W_k(encoder_output)     # Key from encoder
    V = W_v(encoder_output)     # Value from encoder

    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

Computational Complexity

Time Complexity

For a sequence of length n and embedding dimension d:

  1. Q, K, V projections: O(n × d²)
  2. QK^T computation: O(n² × d)
  3. Softmax: O(n²)
  4. Weighted sum: O(n² × d)

Total: O(n²d + nd²)

For long sequences, n² dominates!

The Quadratic Bottleneck:

Self-attention's O(n²) complexity limits sequence length:

  • Sequence of 512 tokens: 512² = 262K attention scores
  • Sequence of 2048 tokens: 2048² = 4.2M attention scores

This is why transformers traditionally limit context length. Recent advances (sparse attention, linear attention) aim to reduce this.
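The memory cost is easy to estimate: the attention matrix alone needs seq_len² entries per head. A back-of-the-envelope sketch in float32:

```python
# Attention-matrix memory grows quadratically with sequence length
def attn_matrix_mb(seq_len, batch_size=1, n_heads=1, bytes_per_el=4):
    """Megabytes needed to store one float32 attention matrix."""
    return batch_size * n_heads * seq_len * seq_len * bytes_per_el / 1e6

print(attn_matrix_mb(512))   # ~1.05 MB per head
print(attn_matrix_mb(2048))  # ~16.78 MB per head
```

Multiply by batch size, number of heads, and number of layers, and the quadratic term quickly dominates GPU memory.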

Space Complexity

Attention weights: O(n²)

For large batches and long sequences, this becomes the memory bottleneck.

Key Properties of Self-Attention

1. Permutation Equivariant

Self-attention treats input as a set (order doesn't matter without positional encodings):

python
# Without positional encodings
X1 = torch.randn(3, 4)
X2 = X1[[2, 0, 1], :]  # Permuted order

# Outputs will be correspondingly permuted
# but relationships preserved

This is why positional encodings are necessary!
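The equivariance can be checked numerically. A minimal sketch (using one shared random projection as a stand-in for the learned Q/K/V weights):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(3, 4)
W = torch.randn(4, 4)  # shared stand-in for the learned projections

def attn(X):
    Q, K, V = X @ W, X @ W, X @ W
    weights = F.softmax(Q @ K.T / X.size(-1) ** 0.5, dim=-1)
    return weights @ V

perm = torch.tensor([2, 0, 1])
out_then_perm = attn(X)[perm]  # attend, then permute rows
perm_then_out = attn(X[perm])  # permute rows, then attend

print(torch.allclose(out_then_perm, perm_then_out, atol=1e-6))  # True
```

Permuting the input simply permutes the output rows, so without positional encodings the model cannot tell "cat sat" from "sat cat".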

2. Parallel Computation

All positions computed simultaneously (unlike RNNs):

python
# Self-attention: all in one operation
output = attention(Q, K, V)  # Parallel

# RNN: sequential
for t in range(seq_len):
    hidden[t] = rnn(input[t], hidden[t-1])  # Sequential

3. Direct Access to All Positions

Every position can directly attend to every other:

  • RNN: Information from position 0 to 100 passes through 100 steps
  • Self-Attention: Direct connection in one step

Practical Implementation Tips

python
class SelfAttention(nn.Module):
    """Production-ready self-attention module"""

    def __init__(self, d_model, dropout=0.1):
        super(SelfAttention, self).__init__()

        self.d_model = d_model
        self.d_k = d_model

        # Combined QKV projection for efficiency
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len, seq_len)

        Returns:
            output: (batch, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.size()

        # Compute Q, K, V in one projection
        qkv = self.qkv_proj(x)  # (batch, seq_len, 3*d_model)

        # Split into Q, K, V
        Q, K, V = qkv.chunk(3, dim=-1)

        # Scaled dot-product attention
        d_k = self.d_k
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum
        output = torch.matmul(attention_weights, V)

        # Output projection
        output = self.out_proj(output)

        return output


# Usage
batch_size = 2
seq_len = 10
d_model = 128

attention = SelfAttention(d_model)
x = torch.randn(batch_size, seq_len, d_model)
output = attention(x)

print("Output shape:", output.shape)  # (2, 10, 128)

Optimization techniques:

  1. Fused QKV projection: Compute Q, K, V in one matrix multiply
  2. Flash Attention: Memory-efficient attention algorithm
  3. Gradient checkpointing: Trade compute for memory
  4. Mixed precision: Use float16 for faster computation
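As one concrete example of the first two techniques, PyTorch 2.x ships a fused kernel, `F.scaled_dot_product_attention`, which can dispatch to Flash Attention on supported hardware (a minimal sketch; shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, d_k) layout expected by the fused kernel
batch, heads, seq_len, d_k = 2, 4, 10, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Scaling, masking, softmax, and the weighted sum happen inside one fused op
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 10, 16])
```

The fused op never materializes the full (seq_len, seq_len) weight matrix when a memory-efficient backend is available, which directly attacks the O(n²) memory bottleneck discussed above.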

Summary

Self-attention is the mechanism that allows transformers to:

  1. Process sequences in parallel (unlike RNNs)
  2. Capture long-range dependencies directly
  3. Compute context-aware representations using Q, K, V

Key Formula:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Steps:

  1. Project input to Q, K, V
  2. Compute similarity scores (Q × K^T)
  3. Scale by √d_k
  4. Apply softmax for weights
  5. Weighted sum of values

This mechanism is the foundation for all transformer-based models including BERT, GPT, and modern LLMs.