Paper: Attention Is All You Need (Simplified)
Attention Is All You Need
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
In 2017, Google researchers published a paper that would fundamentally change deep learning. Instead of using recurrent or convolutional layers, they proposed the Transformer - a model based entirely on attention mechanisms. This lesson breaks down the paper's key concepts in an accessible way.
The Problem with RNNs and LSTMs
Before transformers, sequence-to-sequence tasks (translation, summarization) relied on RNNs and LSTMs.
Sequential Processing Bottleneck
# RNN processes tokens one at a time (pseudo-code)
hidden_state = initial_state
for token in sequence:
hidden_state = rnn_cell(token, hidden_state)
# Cannot process next token until current one is done
Problems:
- Sequential dependency: Can't parallelize - must wait for t-1 to compute t
- Memory bottleneck: All information compressed into fixed-size hidden state
- Long-range dependencies: Information gets lost over long sequences
- Slow training: Processing 1000 tokens requires 1000 sequential steps
The Sequential Bottleneck:
For a sentence with 50 words, an RNN requires 50 sequential operations. With GPUs optimized for parallel computation, this is extremely inefficient. Transformers solve this by processing all tokens simultaneously.
The Core Idea: Self-Attention
Self-Attention: A mechanism that allows each position in a sequence to attend to all positions (including itself) in the same sequence, computing a weighted combination based on how relevant each position is to the current position.
The transformer's key insight is self-attention: every token can directly attend to every other token in one parallel operation.
Intuition: Reading Comprehension
Consider the sentence: "The animal didn't cross the street because it was too tired."
Question: What does "it" refer to?
A human instantly knows "it" = "animal" by attending to relevant context. Self-attention does exactly this:
Attention weights for "it":
- "The" : 0.01
- "animal" : 0.75 ← High attention!
- "didn't" : 0.02
- "cross" : 0.03
- "street" : 0.05
- "because" : 0.01
- "it" : 0.08
- "was" : 0.02
- "too" : 0.02
- "tired" : 0.01
Self-Attention Mechanism
The paper introduces Scaled Dot-Product Attention:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Where:
- Q (Query): "What am I looking for?"
- K (Key): "What do I contain?"
- V (Value): "What do I actually represent?"
Simplified Implementation
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Scaled Dot-Product Attention
Args:
Q: Query matrix (batch_size, seq_len, d_k)
K: Key matrix (batch_size, seq_len, d_k)
V: Value matrix (batch_size, seq_len, d_v)
mask: Optional mask to prevent attending to certain positions
Returns:
output: Weighted sum of values (batch_size, seq_len, d_v)
attention_weights: Attention distribution (batch_size, seq_len, seq_len)
"""
d_k = Q.size(-1)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# scores shape: (batch_size, seq_len, seq_len)
# Apply mask if provided (for padding or future tokens)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Apply softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
# Compute weighted sum of values
output = torch.matmul(attention_weights, V)
return output, attention_weights
# Example usage
batch_size = 2
seq_len = 5
d_model = 8
# Create random Q, K, V matrices
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)
print("\nSample attention weights (first sequence, first token):")
print(weights[0, 0, :]) # Shows which tokens this token attended to
Why Scale by √d_k?
Without scaling, dot products grow large with higher dimensions, pushing softmax into regions with extremely small gradients. Scaling by √d_k keeps values in a reasonable range for stable training.
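A quick numerical check makes this visible. Assuming the entries of Q and K are roughly unit-variance (as with random initialization), the dot product of two d_k-dimensional vectors has a standard deviation of about √d_k, and dividing by √d_k brings it back near 1:
import torch
# Standard deviation of raw vs. scaled dot products as d_k grows
for d_k in [16, 64, 256, 1024]:
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    dots = (q * k).sum(dim=-1)
    print(f"d_k={d_k:5d}  std(q.k)={dots.std().item():6.2f}  std(q.k / sqrt(d_k))={(dots / d_k**0.5).std().item():.2f}")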
Multi-Head Attention
Multi-Head Attention: A technique that runs multiple attention mechanisms (heads) in parallel, each learning to focus on different aspects of the input, then concatenates their outputs. This allows the model to capture diverse relationships simultaneously.
Instead of single attention, the paper uses multiple attention heads in parallel.
The Concept
Multiple heads allow the model to attend to different aspects simultaneously:
- Head 1: Might focus on syntactic relationships (subject-verb)
- Head 2: Might focus on semantic relationships (word meanings)
- Head 3: Might focus on positional patterns (consecutive words)
Implementation
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
"""
Multi-Head Attention Module
Args:
d_model: Model dimension (e.g., 512)
num_heads: Number of attention heads (e.g., 8)
"""
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
# Linear projections for Q, K, V
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# Output projection
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""Split the last dimension into (num_heads, d_k)"""
batch_size, seq_len, d_model = x.size()
return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
"""Inverse of split_heads"""
batch_size, num_heads, seq_len, d_k = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
def forward(self, Q, K, V, mask=None):
"""
Args:
Q, K, V: (batch_size, seq_len, d_model)
mask: Optional mask
Returns:
output: (batch_size, seq_len, d_model)
"""
# Linear projections
Q = self.W_q(Q) # (batch_size, seq_len, d_model)
K = self.W_k(K)
V = self.W_v(V)
# Split into multiple heads
Q = self.split_heads(Q) # (batch_size, num_heads, seq_len, d_k)
K = self.split_heads(K)
V = self.split_heads(V)
# Apply attention
output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
# output: (batch_size, num_heads, seq_len, d_k)
# Combine heads
output = self.combine_heads(output) # (batch_size, seq_len, d_model)
# Final linear projection
output = self.W_o(output)
return output
# Example usage
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10
mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)
output = mha(x, x, x) # Self-attention: Q=K=V
print("Output shape:", output.shape) # (2, 10, 512)
Paper's Configuration:
The original paper uses:
- d_model = 512 (model dimension)
- num_heads = 8 (attention heads)
- d_k = d_v = 64 (dimension per head = 512/8)
This allows parallel computation of 8 different attention patterns.
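As a quick sanity check of these numbers, the module above splits 512 model dimensions into 8 heads of 64 dimensions each (shapes only, weights are random):
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
q_heads = mha.split_heads(mha.W_q(x))
print(q_heads.shape)  # torch.Size([2, 8, 10, 64]) -> 8 parallel 64-dimensional heads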
The Complete Transformer Architecture
The paper's architecture has two main components: Encoder and Decoder.
Encoder Structure
Residual Connection: A shortcut connection that adds the input of a layer directly to its output, helping gradients flow through deep networks and preventing degradation. Mathematically: output = LayerFunction(input) + input.
Each encoder layer contains:
- Multi-Head Self-Attention
- Add & Norm (Residual connection + Layer Normalization)
- Feed-Forward Network
- Add & Norm
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
"""
Args:
d_model: Model dimension (512)
num_heads: Number of attention heads (8)
d_ff: Feed-forward dimension (2048)
dropout: Dropout rate
"""
super(EncoderLayer, self).__init__()
# Multi-head self-attention
self.self_attention = MultiHeadAttention(d_model, num_heads)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: Input (batch_size, seq_len, d_model)
mask: Attention mask
Returns:
output: (batch_size, seq_len, d_model)
"""
# Self-attention with residual connection and normalization
attn_output = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward with residual connection and normalization
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
Decoder Structure
Each decoder layer contains:
- Masked Multi-Head Self-Attention (can't see future tokens)
- Add & Norm
- Multi-Head Cross-Attention (attends to encoder output)
- Add & Norm
- Feed-Forward Network
- Add & Norm
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(DecoderLayer, self).__init__()
# Masked self-attention (for target sequence)
self.masked_self_attention = MultiHeadAttention(d_model, num_heads)
# Cross-attention (attending to encoder output)
self.cross_attention = MultiHeadAttention(d_model, num_heads)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: Decoder input (batch_size, tgt_len, d_model)
encoder_output: Encoder output (batch_size, src_len, d_model)
src_mask: Encoder mask
tgt_mask: Decoder mask (prevents looking at future tokens)
Returns:
output: (batch_size, tgt_len, d_model)
"""
# Masked self-attention on target sequence
self_attn_output = self.masked_self_attention(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(self_attn_output))
# Cross-attention: Query from decoder, Key & Value from encoder
cross_attn_output = self.cross_attention(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout(cross_attn_output))
# Feed-forward
ffn_output = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_output))
return x
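The decoder mask itself is not constructed anywhere in this lesson. Below is a minimal sketch of a causal (look-ahead) mask, assuming the convention used in scaled_dot_product_attention above (1 = may attend, 0 = blocked); the helper name make_causal_mask is not from the paper:
def make_causal_mask(seq_len):
    """Lower-triangular mask so position i can only attend to positions <= i."""
    # Shape (1, 1, seq_len, seq_len) broadcasts over (batch_size, num_heads, seq_len, seq_len)
    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
# Example: 5-token target sequence
tgt_mask = make_causal_mask(5)
print(tgt_mask[0, 0])
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])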
Position-wise Feed-Forward Networks
After attention, each position passes through the same feed-forward network independently.
FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff):
"""
Args:
d_model: Input/output dimension (512)
d_ff: Hidden layer dimension (2048)
"""
super(PositionWiseFeedForward, self).__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
# x: (batch_size, seq_len, d_model)
return self.fc2(self.relu(self.fc1(x)))
Paper Details:
- Inner layer dimension: d_ff = 2048
- Output dimension: d_model = 512
- This 4x expansion (2048 = 4 × 512) adds representational capacity at each position
Positional Encoding
Positional Encoding: A technique that injects information about token positions into the input embeddings using fixed mathematical functions (like sine and cosine), compensating for the transformer's lack of inherent sequential ordering.
Since transformers process all tokens in parallel, they need explicit position information.
Sinusoidal Positional Encoding
The paper uses sine and cosine functions:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
import numpy as np
def get_positional_encoding(max_seq_len, d_model):
"""
Generate sinusoidal positional encodings
Args:
max_seq_len: Maximum sequence length
d_model: Model dimension
Returns:
pos_encoding: (max_seq_len, d_model)
"""
pos_encoding = np.zeros((max_seq_len, d_model))
for pos in range(max_seq_len):
for i in range(0, d_model, 2):
# Apply sine to even indices
pos_encoding[pos, i] = np.sin(pos / (10000 ** (i / d_model)))
# Apply cosine to odd indices
if i + 1 < d_model:
pos_encoding[pos, i + 1] = np.cos(pos / (10000 ** (i / d_model)))
return torch.FloatTensor(pos_encoding)
# Visualize positional encodings
import matplotlib.pyplot as plt
pe = get_positional_encoding(100, 512)
plt.figure(figsize=(12, 6))
plt.imshow(pe.numpy(), cmap='RdBu', aspect='auto')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.colorbar()
plt.title('Sinusoidal Positional Encoding')
plt.show()
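In the model itself, these encodings are added to the token embeddings at the input of the encoder and decoder; the paper also multiplies the embeddings by √d_model before the sum. A minimal sketch of that combination (variable names here are illustrative):
import torch.nn as nn
# Combine token embeddings with positional encodings
vocab_size, d_model, max_seq_len = 10000, 512, 100
embedding = nn.Embedding(vocab_size, d_model)
pos_encoding = get_positional_encoding(max_seq_len, d_model)  # (max_seq_len, d_model)
token_ids = torch.randint(0, vocab_size, (2, 20))              # (batch_size, seq_len)
x = embedding(token_ids) * (d_model ** 0.5)                    # scale embeddings by sqrt(d_model)
x = x + pos_encoding[:token_ids.size(1)]                       # add position information
print(x.shape)  # torch.Size([2, 20, 512])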
Training Details from the Paper
Optimizer: Adam with Custom Learning Rate Schedule
class NoamOptimizer:
"""Learning rate schedule from the paper"""
def __init__(self, optimizer, d_model, warmup_steps=4000):
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self):
self.step_num += 1
lr = self.get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.optimizer.step()
def get_lr(self):
"""
Learning rate formula from paper:
lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
"""
return (self.d_model ** -0.5) * min(
self.step_num ** -0.5,
self.step_num * (self.warmup_steps ** -1.5)
)
# Usage
optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = NoamOptimizer(optimizer, d_model=512, warmup_steps=4000)
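To see the warmup-then-decay shape, one can print the schedule at a few steps (the throwaway Adam instance below exists only to satisfy the constructor):
# LR rises linearly for warmup_steps, then decays proportionally to 1/sqrt(step)
sched = NoamOptimizer(torch.optim.Adam([nn.Parameter(torch.zeros(1))], lr=0), d_model=512)
for step in [100, 1000, 4000, 10000, 100000]:
    sched.step_num = step
    print(f"step {step:6d}: lr = {sched.get_lr():.2e}")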
Regularization Techniques
1. Residual Dropout
dropout_rate = 0.1 # Applied to each sub-layer's output before it is added to the residual (sub-layer input) and normalized
2. Label Smoothing
class LabelSmoothingLoss(nn.Module):
def __init__(self, vocab_size, smoothing=0.1):
super(LabelSmoothingLoss, self).__init__()
self.criterion = nn.KLDivLoss(reduction='batchmean')
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.vocab_size = vocab_size
def forward(self, pred, target):
"""
Args:
pred: (batch_size, vocab_size) - log probabilities
target: (batch_size,) - true labels
"""
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (self.vocab_size - 1))
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
return self.criterion(pred, true_dist)
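A short usage sketch showing the expected shapes; note that nn.KLDivLoss expects log-probabilities, so the model output should go through log_softmax first:
# Usage sketch (random data, shapes only)
vocab_size = 1000
criterion = LabelSmoothingLoss(vocab_size, smoothing=0.1)
logits = torch.randn(4, vocab_size)           # (batch_size, vocab_size)
log_probs = F.log_softmax(logits, dim=-1)     # KLDivLoss expects log-probabilities
targets = torch.randint(0, vocab_size, (4,))  # (batch_size,)
loss = criterion(log_probs, targets)
print(loss.item())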
Training Hyperparameters (Base Model):
- Layers: N = 6 (both encoder and decoder)
- d_model: 512
- d_ff: 2048
- Attention heads: 8
- Dropout: 0.1
- Label smoothing: ε = 0.1
- Training steps: 100,000
- Batch size: ~25,000 source + target tokens
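These numbers plug directly into the EncoderLayer defined earlier, as a quick consistency check:
# Base-model hyperparameters applied to the EncoderLayer above
layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout=0.1)
x = torch.randn(2, 10, 512)
print(layer(x).shape)  # torch.Size([2, 10, 512])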
Results and Impact
WMT 2014 Translation Tasks
The transformer achieved state-of-the-art results:
| Model | BLEU Score (EN-DE) | BLEU Score (EN-FR) |
|---|---|---|
| Previous SOTA | 26.3 | 40.4 |
| Transformer (base) | 27.3 | 38.1 |
| Transformer (big) | 28.4 | 41.8 |
Training Efficiency
Transformer advantages:
- Training time: the base model trained in about 12 hours on 8 P100 GPUs, and even the big model's 3.5 days was a small fraction of the training cost of earlier state-of-the-art models
- Parallelization: Process entire sequences at once
- Inference speed: competitive with or faster than RNN-based models, though decoding remains autoregressive (one token at a time)
Why "Attention Is All You Need"?
The paper's bold title emphasizes that attention alone is sufficient:
- No recurrence: Eliminated RNN/LSTM layers entirely
- No convolution: Didn't need CNNs for feature extraction
- Pure attention: Self-attention handles all sequence modeling
Paper's Legacy:
This architecture became the foundation for:
- BERT (2018): Encoder-only transformer
- GPT (2018-present): Decoder-only transformer
- T5 (2019): Text-to-text transformer
- Vision Transformers (2020): Transformers for images
- AlphaFold (2021): Protein structure prediction
The transformer architecture now dominates AI research across domains.
Key Takeaways from the Paper
- Self-Attention: Allows each token to attend to all others in parallel
- Multi-Head Attention: Multiple attention patterns capture different relationships
- Positional Encoding: Provides position information without recurrence
- Scalability: Parallelization enables training on massive datasets
- Generalization: Architecture works across different domains (text, vision, audio)
Summary
"Attention Is All You Need" introduced the transformer architecture with:
- Scaled dot-product attention for efficient similarity computation
- Multi-head attention for diverse relationship modeling
- Encoder-decoder structure for sequence-to-sequence tasks
- Positional encodings to maintain sequential information
- Complete parallelization for faster training and inference
This paper's impact extends far beyond NLP, influencing nearly every area of modern deep learning.