Understanding Self-Attention
Self-attention is the core innovation that makes transformers powerful. Unlike RNNs that process sequences sequentially, self-attention allows each position to attend to all positions in a single operation. This lesson provides a technical deep-dive into how it works.
The Attention Intuition
Consider translating: "The bank of the river was flooded."
The word "bank" is ambiguous:
- Financial institution?
- Side of a river?
Humans use context ("river", "flooded") to disambiguate. Self-attention does the same computationally.
How Self-Attention Resolves Ambiguity
Context weights for "bank":
- "The" : 0.05
- "bank" : 0.10
- "of" : 0.05
- "the" : 0.03
- "river" : 0.45 ← High attention!
- "was" : 0.02
- "flooded": 0.30 ← High attention!
The model learns that "river" and "flooded" provide crucial context for understanding "bank".
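As a toy sketch of how such a distribution arises, the snippet below applies softmax to hand-picked raw similarity scores. The numbers are made up purely to roughly reproduce the weights above; in a real model they come from query-key dot products.
import torch
import torch.nn.functional as F
tokens = ["The", "bank", "of", "the", "river", "was", "flooded"]
# Hand-picked raw scores for the query "bank" (illustrative only)
raw_scores = torch.tensor([1.0, 1.7, 1.0, 0.5, 3.2, 0.1, 2.8])
weights = F.softmax(raw_scores, dim=-1)
for token, w in zip(tokens, weights):
    print(f"{token:>8}: {w.item():.2f}")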
The Query-Key-Value Framework
Query-Key-Value (QKV) Framework: A retrieval-inspired mechanism where Query matrices represent "what to look for," Key matrices represent "what is offered," and Value matrices contain "the actual information to retrieve." Attention compares queries with keys to weight values.
Self-attention uses three learned transformations of the input: Query (Q), Key (K), and Value (V).
Analogy: Library Search
Think of attention like searching a library:
- Query (Q): Your search request ("I need information about transformers")
- Keys (K): Book titles/indexes that describe what each book contains
- Values (V): The actual content of the books
The attention mechanism:
- Compares your query against all keys (which books match?)
- Assigns scores (relevance scores)
- Retrieves and combines values (actual information) based on scores
Mathematical Formulation
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Let's break this down step by step.
Step-by-Step Attention Computation
Setup: Input Embeddings
import torch
import torch.nn as nn
import torch.nn.functional as F
# Example: Simple sentence with 3 words
# "The cat sat"
seq_len = 3
d_model = 4 # Small for illustration
# Random embeddings (in practice, these come from an embedding layer)
X = torch.randn(seq_len, d_model)
print("Input embeddings shape:", X.shape) # (3, 4)
print("\nInput embeddings:")
print(X)
Step 1: Create Q, K, V Matrices
# Learned weight matrices
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
# Apply transformations
Q = W_q(X) # Query: "What am I looking for?"
K = W_k(X) # Key: "What do I contain?"
V = W_v(X) # Value: "What information do I output?"
print("Q shape:", Q.shape) # (3, 4)
print("K shape:", K.shape) # (3, 4)
print("V shape:", V.shape) # (3, 4)
Why three separate matrices?
Different transformations allow the model to separate:
- What each position is looking for (Q)
- What each position offers (K)
- What each position contains (V)
This separation provides more expressive power than using the same matrix for all three.
Step 2: Compute Attention Scores
Attention Scores: Numerical values computed by taking the dot product of query and key vectors, indicating how much each position should "attend to" (focus on) every other position in the sequence.
# Dot product between queries and keys
# Q: (seq_len, d_model) = (3, 4)
# K^T: (d_model, seq_len) = (4, 3)
# Scores: (seq_len, seq_len) = (3, 3)
scores = torch.matmul(Q, K.transpose(-2, -1))
print("Raw scores shape:", scores.shape) # (3, 3)
print("\nRaw scores:")
print(scores)
The score matrix tells us:
scores[i, j] = similarity between query_i and key_j
Interpretation:
- scores[0, 0]: How much does "The" attend to "The"?
- scores[0, 1]: How much does "The" attend to "cat"?
- scores[0, 2]: How much does "The" attend to "sat"?
Step 3: Scale Scores
d_k = Q.size(-1) # Dimension of queries/keys
scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
print("Scaled scores:")
print(scaled_scores)
Why scaling matters:
Without scaling, the magnitude of the dot products grows with d_k, which pushes the softmax toward a near one-hot distribution and makes the gradients vanish.
Example with d_k=64:
- Dot product might be ~25 (large)
- After softmax: nearly all weight on one position
- Gradients: ~0 (vanishing)
Scaling by √d_k keeps values reasonable.
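A quick sketch of this effect using random vectors (the names d_k_demo, q_demo, and k_demo are just for this illustration, and the exact numbers vary from run to run):
# Effect of scaling on softmax sharpness for a larger dimension
d_k_demo = 64
q_demo = torch.randn(d_k_demo)
k_demo = torch.randn(5, d_k_demo)
demo_scores = k_demo @ q_demo  # unscaled dot products, std roughly sqrt(d_k_demo)
print("Unscaled softmax:", F.softmax(demo_scores, dim=-1))
print("Scaled softmax:  ", F.softmax(demo_scores / d_k_demo ** 0.5, dim=-1))
# The unscaled version typically concentrates almost all weight on one key;
# the scaled version stays smoother, which keeps gradients usable.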
Step 4: Apply Softmax
attention_weights = F.softmax(scaled_scores, dim=-1)
print("Attention weights (sum to 1 per row):")
print(attention_weights)
print("\nRow sums (should be 1.0):")
print(attention_weights.sum(dim=-1))
Softmax ensures:
- All weights are positive (0 to 1)
- Weights for each query sum to 1 (probability distribution)
Step 5: Weighted Sum of Values
output = torch.matmul(attention_weights, V)
print("Output shape:", output.shape) # (3, 4)
print("\nOutput:")
print(output)
Each output position is a weighted combination of all value vectors:
output[0] = attention_weights[0,0] * V[0]
+ attention_weights[0,1] * V[1]
+ attention_weights[0,2] * V[2]
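We can verify this identity directly with the tensors from the steps above:
# Check: output[0] equals the attention-weighted sum of the value vectors
manual = (attention_weights[0, 0] * V[0]
          + attention_weights[0, 1] * V[1]
          + attention_weights[0, 2] * V[2])
print(torch.allclose(manual, output[0], atol=1e-6))  # True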
Complete Self-Attention Implementation
def self_attention(X, W_q, W_k, W_v, mask=None):
"""
Complete self-attention mechanism
Args:
X: Input embeddings (seq_len, d_model)
W_q, W_k, W_v: Query, Key, Value weight matrices
mask: Optional attention mask
Returns:
output: Context-aware representations (seq_len, d_model)
attention_weights: Attention distribution (seq_len, seq_len)
"""
# 1. Compute Q, K, V
Q = W_q(X)
K = W_k(X)
V = W_v(X)
# 2. Compute scaled scores
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# 3. Apply mask (optional)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 4. Apply softmax
attention_weights = F.softmax(scores, dim=-1)
# 5. Weighted sum of values
output = torch.matmul(attention_weights, V)
return output, attention_weights
# Test the implementation
seq_len = 5
d_model = 8
X = torch.randn(seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
output, weights = self_attention(X, W_q, W_k, W_v)
print("Input shape:", X.shape)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)
Batch Processing
In practice, we process multiple sequences simultaneously:
def batched_self_attention(Q, K, V, mask=None):
"""
Batched self-attention
Args:
Q: (batch_size, seq_len, d_k)
K: (batch_size, seq_len, d_k)
V: (batch_size, seq_len, d_v)
mask: (batch_size, seq_len, seq_len) or broadcastable
Returns:
output: (batch_size, seq_len, d_v)
attention_weights: (batch_size, seq_len, seq_len)
"""
d_k = Q.size(-1)
# Compute attention scores
# Q @ K^T: (batch, seq_len, d_k) @ (batch, d_k, seq_len) -> (batch, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# Apply mask
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Softmax
attention_weights = F.softmax(scores, dim=-1)
# Weighted sum
output = torch.matmul(attention_weights, V)
return output, attention_weights
# Example with batches
batch_size = 2
seq_len = 4
d_model = 6
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = batched_self_attention(Q, K, V)
print("Batch size:", batch_size)
print("Sequence length:", seq_len)
print("Output shape:", output.shape) # (2, 4, 6)
print("Weights shape:", weights.shape) # (2, 4, 4)
Attention Masks
Attention Mask: A binary or numerical matrix that selectively prevents certain positions from attending to others, typically used to ignore padding tokens or enforce causal ordering in decoders.
Masks control which positions are allowed to attend to which others.
Padding Mask
Ignore padding tokens in variable-length sequences:
def create_padding_mask(seq, pad_token_id=0):
"""
Create mask for padding tokens
Args:
seq: (batch_size, seq_len) token IDs
pad_token_id: ID used for padding
Returns:
mask: (batch_size, 1, seq_len) - 1 for real tokens, 0 for padding
"""
mask = (seq != pad_token_id).unsqueeze(1)
return mask
# Example
batch_size = 2
seq_len = 5
# Sequence 1: [1, 2, 3, 0, 0] (padding: 0)
# Sequence 2: [4, 5, 6, 7, 0]
sequences = torch.tensor([
[1, 2, 3, 0, 0],
[4, 5, 6, 7, 0]
])
mask = create_padding_mask(sequences)
print("Padding mask shape:", mask.shape) # (2, 1, 5)
print("\nPadding mask:")
print(mask)
Causal (Look-Ahead) Mask
Prevent attending to future positions (for autoregressive models):
def create_causal_mask(seq_len):
"""
Create causal mask to prevent attending to future positions
Args:
seq_len: Sequence length
Returns:
mask: (seq_len, seq_len) lower triangular matrix
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask
# Example
seq_len = 5
causal_mask = create_causal_mask(seq_len)
print("Causal mask:")
print(causal_mask)
print("\nInterpretation:")
print("1 = can attend, 0 = cannot attend")
print("Position 0 can only attend to position 0")
print("Position 2 can attend to positions 0, 1, 2")
When to use each mask:
- Padding mask: Always use when you have variable-length sequences
- Causal mask: Use in decoder or autoregressive generation
- Combined: Often both masks are combined (AND operation), as sketched below
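A minimal sketch of combining the two masks, reusing the helper functions and the sequences tensor from above (the padding mask broadcasts across query positions, the causal mask across the batch):
padding_mask = create_padding_mask(sequences)        # (batch, 1, seq_len)
causal_b = create_causal_mask(seq_len).unsqueeze(0)  # (1, seq_len, seq_len)
combined_mask = padding_mask * causal_b              # broadcasts to (batch, seq_len, seq_len)
print("Combined mask shape:", combined_mask.shape)   # (2, 5, 5)
print(combined_mask[0])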
Visualizing Attention Patterns
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attention_weights, tokens):
"""
Visualize attention weights as heatmap
Args:
attention_weights: (seq_len, seq_len) attention matrix
tokens: List of token strings
"""
plt.figure(figsize=(10, 8))
sns.heatmap(
attention_weights.detach().numpy(),
annot=True,
fmt='.2f',
cmap='Blues',
xticklabels=tokens,
yticklabels=tokens,
cbar_kws={'label': 'Attention Weight'}
)
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Self-Attention Weights')
plt.tight_layout()
plt.show()
# Example
tokens = ["The", "cat", "sat", "on", "mat"]
seq_len = len(tokens)
d_model = 8
# Create dummy data
X = torch.randn(seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
output, weights = self_attention(X, W_q, W_k, W_v)
# Visualize
visualize_attention(weights, tokens)
Self-Attention vs Cross-Attention
Cross-Attention: An attention mechanism where queries come from one sequence while keys and values come from a different sequence, enabling interaction between two sequences (e.g., decoder attending to encoder in translation).
Self-Attention
All Q, K, V come from the same sequence:
# Self-attention: X -> Q, K, V
Q = W_q(X)
K = W_k(X)
V = W_v(X)
Use case: Understanding relationships within a single sequence
Cross-Attention
Q comes from one sequence, K and V from another:
# Cross-attention: decoder attends to encoder
Q = W_q(decoder_input) # From decoder
K = W_k(encoder_output) # From encoder
V = W_v(encoder_output) # From encoder
Use case: Translation, where decoder attends to source sentence
def cross_attention(decoder_input, encoder_output, W_q, W_k, W_v):
"""
Cross-attention mechanism
Args:
decoder_input: (batch, tgt_len, d_model)
encoder_output: (batch, src_len, d_model)
W_q, W_k, W_v: Weight matrices
Returns:
output: (batch, tgt_len, d_model)
"""
Q = W_q(decoder_input) # Query from decoder
K = W_k(encoder_output) # Key from encoder
V = W_v(encoder_output) # Value from encoder
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights
Computational Complexity
Time Complexity
For a sequence of length n and model dimension d:
- Q, K, V projections: O(n × d²)
- QK^T computation: O(n² × d)
- Softmax: O(n²)
- Weighted sum: O(n² × d)
Total: O(n²d + nd²)
For long sequences, n² dominates!
The Quadratic Bottleneck:
Self-attention's O(n²) complexity limits sequence length:
- Sequence of 512 tokens: 512² = 262K attention scores
- Sequence of 2048 tokens: 2048² = 4.2M attention scores
This is why transformers traditionally limit context length. Recent advances (sparse attention, linear attention) aim to reduce this.
Space Complexity
Attention weights: O(n²)
For large batches and long sequences, this becomes the memory bottleneck.
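As a rough back-of-the-envelope sketch (assuming float32 and a batch size of 32; multi-head attention multiplies these figures further by the number of heads), using a small illustrative helper:
def attention_matrix_bytes(batch_size, seq_len, bytes_per_elem=4):
    """Memory for the attention weight matrices alone (illustrative estimate)."""
    return batch_size * seq_len * seq_len * bytes_per_elem

for n in [512, 2048, 8192]:
    mb = attention_matrix_bytes(batch_size=32, seq_len=n) / 1e6
    print(f"seq_len={n:5d}: ~{mb:,.1f} MB")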
Key Properties of Self-Attention
1. Permutation Equivariant
Self-attention treats input as a set (order doesn't matter without positional encodings):
# Without positional encodings
X1 = torch.randn(3, 4)
X2 = X1[[2, 0, 1], :] # Permuted order
# Outputs will be correspondingly permuted
# but relationships preserved
This is why positional encodings are necessary!
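A quick check of this property, reusing the self_attention function from earlier (the tensor and weight names here are fresh, illustrative ones; the tolerance absorbs floating-point noise):
d = 4
X_perm_test = torch.randn(3, d)
Wq = nn.Linear(d, d, bias=False)
Wk = nn.Linear(d, d, bias=False)
Wv = nn.Linear(d, d, bias=False)
perm = [2, 0, 1]
out_orig, _ = self_attention(X_perm_test, Wq, Wk, Wv)
out_perm, _ = self_attention(X_perm_test[perm, :], Wq, Wk, Wv)
# Permuting the input rows permutes the output rows in the same way
print(torch.allclose(out_orig[perm, :], out_perm, atol=1e-5))  # True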
2. Parallel Computation
All positions computed simultaneously (unlike RNNs):
# Self-attention: all in one operation
output = attention(Q, K, V) # Parallel
# RNN: sequential
for t in range(seq_len):
hidden[t] = rnn(input[t], hidden[t-1]) # Sequential
3. Direct Access to All Positions
Every position can directly attend to every other:
- RNN: Information from position 0 to 100 passes through 100 steps
- Self-Attention: Direct connection in one step
Practical Implementation Tips
class SelfAttention(nn.Module):
"""Production-ready self-attention module"""
def __init__(self, d_model, dropout=0.1):
super(SelfAttention, self).__init__()
self.d_model = d_model
self.d_k = d_model
# Combined QKV projection for efficiency
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
# Output projection
self.out_proj = nn.Linear(d_model, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: (batch, seq_len, d_model)
mask: (batch, seq_len, seq_len)
Returns:
output: (batch, seq_len, d_model)
"""
batch_size, seq_len, d_model = x.size()
# Compute Q, K, V in one projection
qkv = self.qkv_proj(x) # (batch, seq_len, 3*d_model)
# Split into Q, K, V
Q, K, V = qkv.chunk(3, dim=-1)
# Scaled dot-product attention
d_k = self.d_k
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Weighted sum
output = torch.matmul(attention_weights, V)
# Output projection
output = self.out_proj(output)
return output
# Usage
batch_size = 2
seq_len = 10
d_model = 128
attention = SelfAttention(d_model)
x = torch.randn(batch_size, seq_len, d_model)
output = attention(x)
print("Output shape:", output.shape) # (2, 10, 128)
Optimization techniques:
- Fused QKV projection: Compute Q, K, V in one matrix multiply
- Flash Attention: Memory-efficient attention algorithm (see the sketch after this list)
- Gradient checkpointing: Trade compute for memory
- Mixed precision: Use float16 for faster computation
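As one concrete example, PyTorch 2.0+ ships F.scaled_dot_product_attention, a fused kernel that can dispatch to Flash Attention or memory-efficient implementations when the hardware supports them; a minimal sketch (shapes are illustrative):
q = torch.randn(2, 10, 128)
k = torch.randn(2, 10, 128)
v = torch.randn(2, 10, 128)
# Fused scaled dot-product attention with a causal mask applied internally
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 10, 128])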
Summary
Self-attention is the mechanism that allows transformers to:
- Process sequences in parallel (unlike RNNs)
- Capture long-range dependencies directly
- Compute context-aware representations using Q, K, V
Key Formula:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Steps:
- Project input to Q, K, V
- Compute similarity scores (Q × K^T)
- Scale by √d_k
- Apply softmax for weights
- Weighted sum of values
This mechanism is the foundation for all transformer-based models including BERT, GPT, and modern LLMs.