Foundation of Transformers

Paper: Attention Is All You Need (Simplified)

A simplified walkthrough of the groundbreaking 2017 paper that introduced the Transformer architecture, revolutionizing natural language processing and deep learning.

25 min read · Transformer · Attention · Paper · Architecture

Attention Is All You Need

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin (2017)

In 2017, Google researchers published a paper that would fundamentally change deep learning. Instead of using recurrent or convolutional layers, they proposed the Transformer - a model based entirely on attention mechanisms. This lesson breaks down the paper's key concepts in an accessible way.

The Problem with RNNs and LSTMs

Before transformers, sequence-to-sequence tasks (translation, summarization) relied on RNNs and LSTMs.

Sequential Processing Bottleneck

python
# RNN processes tokens one at a time (pseudo-code)
hidden_state = initial_state

for token in sequence:
    hidden_state = rnn_cell(token, hidden_state)
    # Cannot process next token until current one is done

Problems:

  1. Sequential dependency: Can't parallelize - must wait for t-1 to compute t
  2. Memory bottleneck: All information compressed into fixed-size hidden state
  3. Long-range dependencies: Information gets lost over long sequences
  4. Slow training: Processing 1000 tokens requires 1000 sequential steps

The Sequential Bottleneck:

For a sentence with 50 words, an RNN requires 50 sequential operations. Since GPUs are built for parallel computation, this sequential processing is extremely inefficient. Transformers solve this by processing all tokens simultaneously.
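
To make the contrast concrete, here is a minimal, illustrative sketch (not from the paper): the recurrent loop needs one dependent step per token, while all pairwise interaction scores come out of a single matrix multiplication.

python
import torch

seq_len, d = 50, 64
x = torch.randn(seq_len, d)   # 50 token embeddings
W = torch.randn(d, d)

# Sequential: each step depends on the previous hidden state
h = torch.zeros(d)
for t in range(seq_len):      # 50 dependent steps, cannot be parallelized
    h = torch.tanh(x[t] @ W + h)

# Parallel: interaction scores for every pair of tokens in one operation
scores = x @ x.T              # (50, 50), computed at once on the GPU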

The Core Idea: Self-Attention

Self-Attention: A mechanism that allows each position in a sequence to attend to all positions (including itself) in the same sequence, computing a weighted combination based on how relevant each position is to the current position.

The transformer's key insight is self-attention: every token can directly attend to every other token in one parallel operation.

Intuition: Reading Comprehension

Consider the sentence: "The animal didn't cross the street because it was too tired."

Question: What does "it" refer to?

A human instantly knows "it" = "animal" by attending to relevant context. Self-attention does exactly this:

Attention weights for "it":
- "The"     : 0.01
- "animal"  : 0.75  ← High attention!
- "didn't"  : 0.02
- "cross"   : 0.03
- "street"  : 0.05
- "because" : 0.01
- "it"      : 0.08
- "was"     : 0.02
- "too"     : 0.02
- "tired"   : 0.01

Self-Attention Mechanism

The paper introduces Scaled Dot-Product Attention:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Where:

  • Q (Query): "What am I looking for?"
  • K (Key): "What do I contain?"
  • V (Value): "What do I actually represent?"

Simplified Implementation

python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention

    Args:
        Q: Query matrix (batch_size, seq_len, d_k)
        K: Key matrix   (batch_size, seq_len, d_k)
        V: Value matrix (batch_size, seq_len, d_v)
        mask: Optional mask to prevent attending to certain positions

    Returns:
        output: Weighted sum of values (batch_size, seq_len, d_v)
        attention_weights: Attention distribution (batch_size, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # scores shape: (batch_size, seq_len, seq_len)

    # Apply mask if provided (for padding or future tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Compute weighted sum of values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights


# Example usage
batch_size = 2
seq_len = 5
d_model = 8

# Create random Q, K, V matrices
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)

print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)
print("\nSample attention weights (first sequence, first token):")
print(weights[0, 0, :])  # Shows which tokens this token attended to

Why Scale by √d_k?

Without scaling, dot products grow large with higher dimensions, pushing softmax into regions with extremely small gradients. Scaling by √d_k keeps values in a reasonable range for stable training.
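
A quick numerical check (illustrative, not from the paper) makes this visible: for random vectors with unit-variance components, the dot product's standard deviation grows like √d_k, and the unscaled softmax collapses to a near one-hot distribution with tiny gradients everywhere else.

python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512
q = torch.randn(10, d_k)
k = torch.randn(10, d_k)

raw = q @ k.T                  # standard deviation grows like sqrt(d_k) ≈ 22.6
scaled = raw / (d_k ** 0.5)    # back to roughly unit scale

print(raw.std().item(), scaled.std().item())

# Unscaled scores saturate the softmax: almost all mass on one position
print(F.softmax(raw[0], dim=-1).max().item(),
      F.softmax(scaled[0], dim=-1).max().item())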

Multi-Head Attention

Multi-Head Attention: A technique that runs multiple attention mechanisms (heads) in parallel, each learning to focus on different aspects of the input, then concatenates their outputs. This allows the model to capture diverse relationships simultaneously.

Instead of single attention, the paper uses multiple attention heads in parallel.

The Concept

Multiple heads allow the model to attend to different aspects simultaneously:

  • Head 1: Might focus on syntactic relationships (subject-verb)
  • Head 2: Might focus on semantic relationships (word meanings)
  • Head 3: Might focus on positional patterns (consecutive words)

Implementation

python
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        Multi-Head Attention Module

        Args:
            d_model: Model dimension (e.g., 512)
            num_heads: Number of attention heads (e.g., 8)
        """
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)"""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads"""
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q, K, V: (batch_size, seq_len, d_model)
            mask: Optional mask

        Returns:
            output: (batch_size, seq_len, d_model)
        """
        # Linear projections
        Q = self.W_q(Q)  # (batch_size, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)

        # Split into multiple heads
        Q = self.split_heads(Q)  # (batch_size, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Apply attention
        output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        # output: (batch_size, num_heads, seq_len, d_k)

        # Combine heads
        output = self.combine_heads(output)  # (batch_size, seq_len, d_model)

        # Final linear projection
        output = self.W_o(output)

        return output


# Example usage
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

output = mha(x, x, x)  # Self-attention: Q=K=V
print("Output shape:", output.shape)  # (2, 10, 512)

Paper's Configuration:

The original paper uses:

  • d_model = 512 (model dimension)
  • num_heads = 8 (attention heads)
  • d_k = d_v = 64 (dimension per head = 512/8)

This allows parallel computation of 8 different attention patterns.

The Complete Transformer Architecture

The paper's architecture has two main components: Encoder and Decoder.

Encoder Structure

Residual Connection: A shortcut connection that adds the input of a layer directly to its output, helping gradients flow through deep networks and preventing degradation. Mathematically: output = LayerFunction(input) + input.

Each encoder layer contains:

  1. Multi-Head Self-Attention
  2. Add & Norm (Residual connection + Layer Normalization)
  3. Feed-Forward Network
  4. Add & Norm
python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension (512)
            num_heads: Number of attention heads (8)
            d_ff: Feed-forward dimension (2048)
            dropout: Dropout rate
        """
        super(EncoderLayer, self).__init__()

        # Multi-head self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input (batch_size, seq_len, d_model)
            mask: Attention mask

        Returns:
            output: (batch_size, seq_len, d_model)
        """
        # Self-attention with residual connection and normalization
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward with residual connection and normalization
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))

        return x
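
Stacking N = 6 such layers gives the paper's full encoder. Below is a minimal sketch of the stack, reusing the EncoderLayer above (class and argument names here are illustrative, not the paper's reference code):

python
class Encoder(nn.Module):
    def __init__(self, num_layers=6, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super(Encoder, self).__init__()
        # N identical layers applied in sequence
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        # Each layer refines the representation of every position
        for layer in self.layers:
            x = layer(x, mask)
        return x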

Decoder Structure

Each decoder layer contains:

  1. Masked Multi-Head Self-Attention (can't see future tokens)
  2. Add & Norm
  3. Multi-Head Cross-Attention (attends to encoder output)
  4. Add & Norm
  5. Feed-Forward Network
  6. Add & Norm
python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()

        # Masked self-attention (for target sequence)
        self.masked_self_attention = MultiHeadAttention(d_model, num_heads)

        # Cross-attention (attending to encoder output)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder input (batch_size, tgt_len, d_model)
            encoder_output: Encoder output (batch_size, src_len, d_model)
            src_mask: Encoder mask
            tgt_mask: Decoder mask (prevents looking at future tokens)

        Returns:
            output: (batch_size, tgt_len, d_model)
        """
        # Masked self-attention on target sequence
        self_attn_output = self.masked_self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))

        # Cross-attention: Query from decoder, Key & Value from encoder
        cross_attn_output = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        # Feed-forward
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))

        return x
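
The tgt_mask is what enforces the "can't see future tokens" rule. One simple way to build such a causal mask (an illustrative helper, not code from the paper) is a lower-triangular matrix; the attention function above then fills positions where the mask is 0 with -1e9:

python
def make_causal_mask(seq_len):
    """Position i may attend only to positions <= i (1 = allowed, 0 = masked)."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    # Shape (1, 1, seq_len, seq_len) broadcasts over batch and heads
    return mask.unsqueeze(0).unsqueeze(0)

tgt_mask = make_causal_mask(5)
print(tgt_mask[0, 0])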

Position-wise Feed-Forward Networks

After attention, each position passes through the same feed-forward network independently.

FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
python
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        """
        Args:
            d_model: Input/output dimension (512)
            d_ff: Hidden layer dimension (2048)
        """
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        return self.fc2(self.relu(self.fc1(x)))

Paper Details:

  • Inner layer dimension: d_ff = 2048
  • Output dimension: d_model = 512
  • This 4x expansion provides representational capacity

Positional Encoding

Positional Encoding: A technique that injects information about token positions into the input embeddings using fixed mathematical functions (like sine and cosine), compensating for the transformer's lack of inherent sequential ordering.

Since transformers process all tokens in parallel, they need explicit position information.

Sinusoidal Positional Encoding

The paper uses sine and cosine functions:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
python
import numpy as np

def get_positional_encoding(max_seq_len, d_model):
    """
    Generate sinusoidal positional encodings

    Args:
        max_seq_len: Maximum sequence length
        d_model: Model dimension

    Returns:
        pos_encoding: (max_seq_len, d_model)
    """
    pos_encoding = np.zeros((max_seq_len, d_model))

    for pos in range(max_seq_len):
        for i in range(0, d_model, 2):
            # Apply sine to even indices
            pos_encoding[pos, i] = np.sin(pos / (10000 ** (i / d_model)))

            # Apply cosine to odd indices
            if i + 1 < d_model:
                pos_encoding[pos, i + 1] = np.cos(pos / (10000 ** (i / d_model)))

    return torch.FloatTensor(pos_encoding)


# Visualize positional encodings
import matplotlib.pyplot as plt

pe = get_positional_encoding(100, 512)
plt.figure(figsize=(12, 6))
plt.imshow(pe.numpy(), cmap='RdBu', aspect='auto')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.colorbar()
plt.title('Sinusoidal Positional Encoding')
plt.show()
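
In the model itself, these encodings are simply added to the token embeddings before the first layer (the paper also scales the embeddings by √d_model). A minimal sketch with illustrative sizes:

python
vocab_size, d_model, seq_len = 10000, 512, 20

embedding = nn.Embedding(vocab_size, d_model)
pe = get_positional_encoding(100, d_model)   # precomputed up to max length 100

tokens = torch.randint(0, vocab_size, (2, seq_len))       # (batch, seq_len)
x = embedding(tokens) * (d_model ** 0.5) + pe[:seq_len]   # add position information
print(x.shape)  # (2, 20, 512)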

Training Details from the Paper

Optimizer: Adam with Custom Learning Rate Schedule

python
class NoamOptimizer:
    """Learning rate schedule from the paper"""

    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.optimizer.step()

    def get_lr(self):
        """
        Learning rate formula from paper:
        lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
        """
        return (self.d_model ** -0.5) * min(
            self.step_num ** -0.5,
            self.step_num * (self.warmup_steps ** -1.5)
        )


# Usage (assumes `model` is a full Transformer whose parameters we optimize)
optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = NoamOptimizer(optimizer, d_model=512, warmup_steps=4000)
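
To get a feel for the schedule, the formula can be evaluated directly: the learning rate rises linearly during warmup, peaks at warmup_steps, then decays with the inverse square root of the step number. A small illustrative check:

python
def noam_lr(step, d_model=512, warmup_steps=4000):
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in [1, 1000, 4000, 10000, 100000]:
    print(step, f"{noam_lr(step):.2e}")
# Rises during warmup, peaks at step 4000, then decays as step^-0.5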

Regularization Techniques

1. Residual Dropout

python
dropout_rate = 0.1  # Applied to the output of each sub-layer before it is added to the residual, and to the sums of embeddings and positional encodings

2. Label Smoothing

python
class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size, smoothing=0.1):
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='batchmean')
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.vocab_size = vocab_size

    def forward(self, pred, target):
        """
        Args:
            pred: (batch_size, vocab_size) - log probabilities
            target: (batch_size,) - true labels
        """
        true_dist = torch.zeros_like(pred)
        true_dist.fill_(self.smoothing / (self.vocab_size - 1))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        return self.criterion(pred, true_dist)
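
A quick usage sketch (sizes are illustrative): because the loss wraps KLDivLoss, the predictions must be log-probabilities, e.g. from log_softmax.

python
vocab_size = 1000
criterion = LabelSmoothingLoss(vocab_size, smoothing=0.1)

logits = torch.randn(4, vocab_size)           # raw decoder outputs for 4 tokens
log_probs = F.log_softmax(logits, dim=-1)     # KLDivLoss expects log-probabilities
targets = torch.tensor([3, 17, 256, 999])

loss = criterion(log_probs, targets)
print(loss.item())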

Training Hyperparameters (Base Model):

  • Layers: N = 6 (both encoder and decoder)
  • d_model: 512
  • d_ff: 2048
  • Attention heads: 8
  • Dropout: 0.1
  • Label smoothing: ε = 0.1
  • Training steps: 100,000
  • Batch size: ≈25,000 source tokens and ≈25,000 target tokens per batch

Results and Impact

WMT 2014 Translation Tasks

The transformer achieved state-of-the-art results:

Model              | BLEU (EN-DE) | BLEU (EN-FR)
-------------------|--------------|-------------
Previous SOTA      | 26.3         | 40.4
Transformer (base) | 27.3         | 38.1
Transformer (big)  | 28.4         | 41.8

Training Efficiency

Transformer advantages:

  • Training time: ~12 hours on 8 GPUs for the base model (~3.5 days for the big model), a fraction of the training cost of previous state-of-the-art models
  • Parallelization: Process entire sequences at once
  • Inference speed: Faster than RNN-based models

Why "Attention Is All You Need"?

The paper's bold title emphasizes that attention alone is sufficient:

  1. No recurrence: Eliminated RNN/LSTM layers entirely
  2. No convolution: Didn't need CNNs for feature extraction
  3. Pure attention: Self-attention handles all sequence modeling

Paper's Legacy:

This architecture became the foundation for:

  • BERT (2018): Encoder-only transformer
  • GPT (2018-present): Decoder-only transformer
  • T5 (2019): Text-to-text transformer
  • Vision Transformers (2020): Transformers for images
  • AlphaFold (2021): Protein structure prediction

The transformer architecture now dominates AI research across domains.

Key Takeaways from the Paper

  1. Self-Attention: Allows each token to attend to all others in parallel
  2. Multi-Head Attention: Multiple attention patterns capture different relationships
  3. Positional Encoding: Provides position information without recurrence
  4. Scalability: Parallelization enables training on massive datasets
  5. Generalization: Architecture works across different domains (text, vision, audio)

Summary

"Attention Is All You Need" introduced the transformer architecture with:

  • Scaled dot-product attention for efficient similarity computation
  • Multi-head attention for diverse relationship modeling
  • Encoder-decoder structure for sequence-to-sequence tasks
  • Positional encodings to maintain sequential information
  • Complete parallelization for faster training and inference

This paper's impact extends far beyond NLP, influencing nearly every area of modern deep learning.