Encoder-Decoder Architecture
The original transformer uses an encoder-decoder architecture for sequence-to-sequence tasks like translation. This lesson explains how both components work together and how they differ from encoder-only (BERT) and decoder-only (GPT) models.
The Encoder-Decoder Paradigm
Sequence-to-Sequence (Seq2Seq): A neural network architecture pattern that transforms one sequence into another potentially different-length sequence, using an encoder to compress the input and a decoder to generate the output.
Sequence-to-Sequence Tasks
Some tasks require mapping one sequence to another:
- Machine Translation: English → French
- Summarization: Long document → Summary
- Speech Recognition: Audio → Text
- Image Captioning: Image → Text description
The Two-Stage Process
Encoder: Understands the input
- Processes source sequence
- Creates rich representations
- Bidirectional (can see entire input)
Decoder: Generates the output
- Creates target sequence one token at a time
- Attends to encoder outputs
- Autoregressive (sees only previous outputs)
Encoder vs Decoder:
- Encoder: "What does this mean?" (comprehension)
- Decoder: "How do I say this?" (generation)
Think of translation: the encoder understands the English sentence, the decoder writes the French equivalent.
The Encoder
The encoder processes the entire input sequence to create contextual representations.
Encoder Layer Structure
Each encoder layer has two sub-layers:
- Multi-Head Self-Attention
- Feed-Forward Network
Both use residual connections and layer normalization.
Complete Encoder Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class EncoderLayer(nn.Module):
"""Single encoder layer from 'Attention Is All You Need'"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
"""
Args:
d_model: Model dimension (e.g., 512)
num_heads: Number of attention heads (e.g., 8)
d_ff: Feed-forward inner dimension (e.g., 2048)
dropout: Dropout probability
"""
super(EncoderLayer, self).__init__()
# Multi-head self-attention
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, src_mask=None):
"""
Args:
x: Input (batch, seq_len, d_model)
src_mask: Attention mask for source
Returns:
output: (batch, seq_len, d_model)
"""
# Self-attention with residual connection
attn_output, _ = self.self_attn(
x, x, x,
attn_mask=src_mask,
need_weights=False
)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# Feed-forward with residual connection
ffn_output = self.ffn(x)
x = x + self.dropout2(ffn_output)
x = self.norm2(x)
return x
class Encoder(nn.Module):
"""Stack of N encoder layers"""
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout=0.1, max_len=5000):
"""
Args:
vocab_size: Size of vocabulary
d_model: Model dimension
num_layers: Number of encoder layers (N)
num_heads: Number of attention heads
d_ff: Feed-forward dimension
dropout: Dropout probability
max_len: Maximum sequence length for positional encoding
"""
super(Encoder, self).__init__()
self.d_model = d_model
# Token embedding
self.embedding = nn.Embedding(vocab_size, d_model)
# Positional encoding
self.pos_encoding = self.create_positional_encoding(max_len, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# Stack of encoder layers
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
def create_positional_encoding(self, max_len, d_model):
"""Create sinusoidal positional encodings"""
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe.unsqueeze(0) # (1, max_len, d_model)
def forward(self, src, src_mask=None):
"""
Args:
src: Source tokens (batch, src_len)
src_mask: Source mask
Returns:
output: Encoder output (batch, src_len, d_model)
"""
batch_size, seq_len = src.size()
# Embedding + positional encoding
x = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
x = x + self.pos_encoding[:, :seq_len, :].to(src.device)
x = self.dropout(x)
# Pass through encoder layers
for layer in self.layers:
x = layer(x, src_mask)
return x
# Example usage
vocab_size = 10000
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
encoder = Encoder(vocab_size, d_model, num_layers, num_heads, d_ff)
# Sample input
batch_size = 2
src_len = 10
src = torch.randint(0, vocab_size, (batch_size, src_len))
# Forward pass
encoder_output = encoder(src)
print("Encoder output shape:", encoder_output.shape) # (2, 10, 512)
What the Encoder Learns
The encoder creates contextualized representations where each token's embedding incorporates information from the entire sequence.
# Input: "The cat sat on the mat"
# Token embeddings (before encoder): independent, context-free
# Encoder output: each token embedding now contains context
# For example, "mat" in encoder output understands:
# - It's a noun (from surrounding words)
# - It's related to "sat" (positional relationship)
# - It's associated with "cat" (semantic relationship)
Bidirectional Context:
The encoder can see the entire input simultaneously:
- "sat" sees "The cat" (left context)
- AND "on the mat" (right context)
This bidirectionality makes encoders excellent for understanding tasks (classification, NER, etc.).
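We can observe this contextualization directly with the (untrained) Encoder instantiated above: feed the same token ID in two different contexts and compare its output representations. A minimal sketch; the token IDs are arbitrary.
# Sketch: the same token ID gets different representations in different contexts.
encoder.eval()  # disable dropout for a deterministic comparison
with torch.no_grad():
    seq_a = torch.tensor([[42, 7, 9]])   # token 42 in one context
    seq_b = torch.tensor([[42, 3, 5]])   # token 42 in another context
    rep_a = encoder(seq_a)[0, 0]         # representation of token 42, context A
    rep_b = encoder(seq_b)[0, 0]         # representation of token 42, context B
    sim = F.cosine_similarity(rep_a, rep_b, dim=0)
    print(f"Cosine similarity across contexts: {sim.item():.3f}")  # < 1.0: context matters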
The Decoder
The decoder generates output tokens sequentially, using both self-attention and cross-attention.
Decoder Layer Structure
Each decoder layer has three sub-layers:
- Masked Multi-Head Self-Attention (on target sequence)
- Multi-Head Cross-Attention (attending to encoder output)
- Feed-Forward Network
Complete Decoder Implementation
class DecoderLayer(nn.Module):
"""Single decoder layer from 'Attention Is All You Need'"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(DecoderLayer, self).__init__()
# Masked self-attention (for target sequence)
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Cross-attention (attending to encoder output)
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
"""
Args:
x: Decoder input (batch, tgt_len, d_model)
encoder_output: Encoder output (batch, src_len, d_model)
tgt_mask: Causal mask for target sequence
            src_mask: Mask for source positions, used here in cross-attention
                (as an attn_mask it must broadcast to (tgt_len, src_len);
                source padding is more often handled via key_padding_mask)
Returns:
output: (batch, tgt_len, d_model)
"""
# 1. Masked self-attention on target sequence
attn_output, _ = self.self_attn(
x, x, x,
attn_mask=tgt_mask,
need_weights=False
)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# 2. Cross-attention: attend to encoder output
# Query from decoder, Key & Value from encoder
attn_output, _ = self.cross_attn(
x, encoder_output, encoder_output,
attn_mask=src_mask,
need_weights=False
)
x = x + self.dropout2(attn_output)
x = self.norm2(x)
# 3. Feed-forward
ffn_output = self.ffn(x)
x = x + self.dropout3(ffn_output)
x = self.norm3(x)
return x
class Decoder(nn.Module):
"""Stack of N decoder layers"""
def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout=0.1, max_len=5000):
super(Decoder, self).__init__()
self.d_model = d_model
# Token embedding
self.embedding = nn.Embedding(vocab_size, d_model)
# Positional encoding
self.pos_encoding = self.create_positional_encoding(max_len, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# Stack of decoder layers
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# Output projection
self.output_proj = nn.Linear(d_model, vocab_size)
def create_positional_encoding(self, max_len, d_model):
"""Create sinusoidal positional encodings"""
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe.unsqueeze(0)
def forward(self, tgt, encoder_output, tgt_mask=None, src_mask=None):
"""
Args:
tgt: Target tokens (batch, tgt_len)
encoder_output: Encoder output (batch, src_len, d_model)
tgt_mask: Causal mask for target
src_mask: Mask for source
Returns:
output: Logits (batch, tgt_len, vocab_size)
"""
batch_size, seq_len = tgt.size()
# Embedding + positional encoding
x = self.embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
x = x + self.pos_encoding[:, :seq_len, :].to(tgt.device)
x = self.dropout(x)
# Pass through decoder layers
for layer in self.layers:
x = layer(x, encoder_output, tgt_mask, src_mask)
# Project to vocabulary
output = self.output_proj(x)
return output
# Example usage
tgt_vocab_size = 10000
decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff)
# Sample input
tgt_len = 8
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))
# Create causal mask
causal_mask = torch.triu(torch.ones(tgt_len, tgt_len) * float('-inf'), diagonal=1)
# Forward pass
decoder_output = decoder(tgt, encoder_output, tgt_mask=causal_mask)
print("Decoder output shape:", decoder_output.shape) # (2, 8, 10000)
Masked Self-Attention
Causal Mask: A triangular mask applied during attention that prevents each position from attending to future positions, ensuring the model can only use information from the current and previous positions (required for autoregressive generation).
The decoder uses a causal mask to prevent positions from attending to future tokens:
def create_causal_mask(seq_len):
"""
Create causal (look-ahead) mask
    Returns an additive mask (-inf strictly above the diagonal, 0 elsewhere):
[[0, -inf, -inf],
[0, 0, -inf],
[0, 0, 0]]
"""
mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
return mask
# Example
mask = create_causal_mask(5)
print("Causal mask:")
print(mask)
# Interpretation:
# Position 0: can only see position 0
# Position 1: can see positions 0-1
# Position 2: can see positions 0-2
# etc.
Why Masking?
During training, we have the full target sequence. Without masking, the decoder could "cheat" by seeing future tokens.
Example (translation training):
- Target: "Le chat dort" (The cat sleeps)
- Without mask: predicting "chat" can see "dort"
- With mask: predicting "chat" only sees "Le" (realistic)
This ensures the model learns to generate sequentially, matching inference conditions.
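A quick way to convince yourself the mask works (a minimal sketch with a standalone attention layer, not the full model): perturb a future token and check that the outputs at earlier positions are unchanged.
# Sketch: with a causal mask, changing token t cannot affect outputs at positions < t.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
mask = create_causal_mask(5)
x1 = torch.randn(1, 5, 8)
x2 = x1.clone()
x2[0, -1] = torch.randn(8)  # perturb only the last token
with torch.no_grad():
    out1, _ = attn(x1, x1, x1, attn_mask=mask)
    out2, _ = attn(x2, x2, x2, attn_mask=mask)
print(torch.allclose(out1[0, :4], out2[0, :4]))  # True: positions 0-3 never saw position 4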
Cross-Attention
Cross-attention allows the decoder to focus on relevant parts of the input:
# Cross-attention flow:
# 1. Query (Q): "What do I need from the source?" ← from decoder
# 2. Key (K): "What does the source offer?" ← from encoder
# 3. Value (V): "Source information to retrieve" ← from encoder
# Example: Translating "The cat sleeps" → "Le chat dort"
# When generating "chat":
# - Decoder queries what it needs
# - High attention to encoder's "cat" representation
# - Retrieves "cat" information to generate "chat"
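The shapes make this asymmetry concrete (a standalone sketch, independent of the model above): queries come from the target side, keys and values from the source side, so the attention weights form a (tgt_len, src_len) map.
# Sketch: cross-attention produces one distribution over source positions
# for every target position.
cross = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
dec_states = torch.randn(1, 3, 16)  # 3 target positions (queries)
enc_states = torch.randn(1, 5, 16)  # 5 source positions (keys and values)
out, weights = cross(dec_states, enc_states, enc_states, need_weights=True)
print(out.shape)      # torch.Size([1, 3, 16]): one output per target position
print(weights.shape)  # torch.Size([1, 3, 5]): attention over source, per target position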
Complete Transformer
Combining encoder and decoder:
class Transformer(nn.Module):
"""Complete Transformer (Encoder-Decoder)"""
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
d_model=512,
num_layers=6,
num_heads=8,
d_ff=2048,
dropout=0.1,
max_len=5000
):
super(Transformer, self).__init__()
# Encoder
self.encoder = Encoder(
src_vocab_size, d_model, num_layers,
num_heads, d_ff, dropout, max_len
)
# Decoder
self.decoder = Decoder(
tgt_vocab_size, d_model, num_layers,
num_heads, d_ff, dropout, max_len
)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
"""
Args:
src: Source tokens (batch, src_len)
tgt: Target tokens (batch, tgt_len)
            src_mask: Source attention mask (note: encoder self-attention and
                decoder cross-attention expect different mask shapes; source
                padding is typically handled via key_padding_mask instead)
tgt_mask: Target causal mask
Returns:
output: Logits (batch, tgt_len, tgt_vocab_size)
"""
# Encode source
encoder_output = self.encoder(src, src_mask)
# Decode to generate target
decoder_output = self.decoder(tgt, encoder_output, tgt_mask, src_mask)
return decoder_output
def encode(self, src, src_mask=None):
"""Encode source sequence"""
return self.encoder(src, src_mask)
def decode(self, tgt, encoder_output, tgt_mask=None, src_mask=None):
"""Decode given encoder output"""
return self.decoder(tgt, encoder_output, tgt_mask, src_mask)
# Create model
src_vocab_size = 10000
tgt_vocab_size = 8000
model = Transformer(src_vocab_size, tgt_vocab_size)
# Sample data
batch_size = 2
src_len = 10
tgt_len = 8
src = torch.randint(0, src_vocab_size, (batch_size, src_len))
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))
# Create masks
tgt_mask = create_causal_mask(tgt_len)
# Forward pass
output = model(src, tgt, tgt_mask=tgt_mask)
print("Output shape:", output.shape) # (2, 8, 8000)
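As a rough sanity check, the base configuration above (d_model=512, six layers on each side) lands in the tens of millions of parameters; the exact count depends on the two vocabulary sizes.
# Count trainable parameters of the model built above
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")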
Training the Transformer
Training Loop
import torch.optim as optim
# Model, loss, optimizer
model = Transformer(src_vocab_size, tgt_vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore loss on padding (assumes pad token ID 0)
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
def train_step(model, src, tgt, criterion, optimizer):
"""
Single training step
Args:
src: Source tokens (batch, src_len)
tgt: Target tokens (batch, tgt_len)
"""
model.train()
    # Prepare target: decoder input = tgt[:, :-1], labels = tgt[:, 1:] (shifted by one)
tgt_input = tgt[:, :-1]
tgt_label = tgt[:, 1:]
# Create causal mask
tgt_len = tgt_input.size(1)
tgt_mask = create_causal_mask(tgt_len).to(src.device)
# Forward pass
output = model(src, tgt_input, tgt_mask=tgt_mask)
# Reshape for loss computation
output = output.reshape(-1, output.size(-1))
tgt_label = tgt_label.reshape(-1)
# Compute loss
loss = criterion(output, tgt_label)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
# Training loop (synthetic data so the example runs end to end)
num_epochs = 2
data_loader = [
    (torch.randint(1, src_vocab_size, (batch_size, src_len)),
     torch.randint(1, tgt_vocab_size, (batch_size, tgt_len)))
    for _ in range(5)
]
for epoch in range(num_epochs):
    for src, tgt in data_loader:
        loss = train_step(model, src, tgt, criterion, optimizer)
    print(f"Epoch {epoch}: loss = {loss:.4f}")
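The original paper also scheduled the learning rate: a linear warmup followed by inverse-square-root decay. A minimal sketch with LambdaLR (note that LambdaLR multiplies the optimizer's base lr, so the base lr is set to 1.0 here to reproduce the paper's formula):
# Sketch of the 'Attention Is All You Need' schedule:
# lrate = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
def transformer_lr(step, d_model=512, warmup=4000):
    step = max(step, 1)  # avoid division by zero on the first call
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

optimizer = optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=transformer_lr)
# Call scheduler.step() after each optimizer.step()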
Teacher Forcing
Teacher Forcing: A training technique where the model receives ground-truth previous tokens as input (rather than its own predictions) during training, accelerating learning but creating a potential train-test mismatch.
During training, we use teacher forcing: feed ground-truth tokens as decoder input, not model predictions.
# Teacher forcing:
# Input: <BOS> Le chat dort
# Target: Le chat dort <EOS>
# The model predicts:
# Position 0: predicts "Le" given "<BOS>"
# Position 1: predicts "chat" given "<BOS> Le" (ground truth "Le")
# Position 2: predicts "dort" given "<BOS> Le chat" (ground truth)
Why Teacher Forcing?
- Faster training: Always use correct previous tokens
- Stable gradients: Errors don't compound
- Drawback: Train-test mismatch (at inference, model uses its own predictions)
Advanced techniques like scheduled sampling gradually mix teacher forcing with model predictions.
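A two-pass variant is one simple way to approximate scheduled sampling with a transformer (a hedged sketch, not the original schedule of Bengio et al.): first run the decoder under teacher forcing to get its own predictions, then replace a random fraction of the ground-truth input tokens with those predictions for the actual update.
def scheduled_sampling_step(model, src, tgt, criterion, optimizer, sample_p=0.25):
    """Two-pass scheduled-sampling sketch: mix model predictions into decoder input."""
    model.train()
    tgt_input, tgt_label = tgt[:, :-1], tgt[:, 1:]
    tgt_mask = create_causal_mask(tgt_input.size(1)).to(src.device)
    # Pass 1: predictions under teacher forcing (no gradient needed)
    with torch.no_grad():
        preds = model(src, tgt_input, tgt_mask=tgt_mask).argmax(dim=-1)
    # The prediction at position i-1 is the model's guess for the input at position i
    shifted_preds = torch.roll(preds, shifts=1, dims=1)
    mix = torch.rand(tgt_input.shape, device=src.device) < sample_p
    mix[:, 0] = False  # always keep <BOS> as the first input
    mixed_input = torch.where(mix, shifted_preds, tgt_input)
    # Pass 2: real update on the mixed input
    output = model(src, mixed_input, tgt_mask=tgt_mask)
    loss = criterion(output.reshape(-1, output.size(-1)), tgt_label.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()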
Inference (Generation)
Autoregressive Generation: A text generation process where the model produces one token at a time, feeding each newly generated token back as input to generate the next token, continuing until an end condition is met.
At inference, we generate one token at a time:
def generate(model, src, max_len=50, start_token=1, end_token=2):
"""
Autoregressive generation
Args:
model: Trained transformer
src: Source tokens (batch, src_len)
max_len: Maximum generation length
start_token: <BOS> token ID
end_token: <EOS> token ID
Returns:
generated: Generated token IDs (batch, gen_len)
"""
model.eval()
batch_size = src.size(0)
# Encode source once
with torch.no_grad():
encoder_output = model.encode(src)
# Initialize target with <BOS>
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=src.device)
for _ in range(max_len - 1):
# Create causal mask
tgt_len = tgt.size(1)
tgt_mask = create_causal_mask(tgt_len).to(src.device)
# Decode
with torch.no_grad():
output = model.decode(tgt, encoder_output, tgt_mask=tgt_mask)
# Get next token (greedy)
next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
# Append to sequence
tgt = torch.cat([tgt, next_token], dim=1)
# Stop if all sequences generated <EOS>
if (next_token == end_token).all():
break
return tgt
# Example generation
src = torch.randint(0, src_vocab_size, (1, 10))
generated = generate(model, src)
print("Generated sequence:", generated)
Decoding Strategies
1. Greedy Decoding
next_token = output.argmax(dim=-1)
- Simple, fast
- May miss better sequences
2. Beam Search
# Keep top-k hypotheses at each step
# Explore multiple paths simultaneously
- Better quality
- More computation
3. Sampling
probs = F.softmax(output[:, -1, :] / temperature, dim=-1)  # last-step logits only
next_token = torch.multinomial(probs, num_samples=1)
- More diverse
- Controllable via temperature (a top-k variant is sketched below)
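A common refinement (not shown above; a hedged sketch) is top-k sampling, which samples only among the k most likely tokens to avoid picking low-probability noise:
def sample_top_k(logits, temperature=1.0, k=50):
    """Sample the next token from the top-k logits. logits: (batch, vocab_size)."""
    topk_vals, topk_idx = torch.topk(logits / temperature, k, dim=-1)
    probs = F.softmax(topk_vals, dim=-1)           # renormalize over the top-k set
    choice = torch.multinomial(probs, num_samples=1)  # index into the top-k set
    return topk_idx.gather(-1, choice)             # map back to vocabulary IDs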
Encoder-Only vs Decoder-Only vs Encoder-Decoder
Comparison
| Architecture | Examples | Use Cases | Key Feature |
|---|---|---|---|
| Encoder-Only | BERT, RoBERTa | Classification, NER, QA | Bidirectional context |
| Decoder-Only | GPT, LLaMA, Claude | Text generation, Chat | Autoregressive generation |
| Encoder-Decoder | T5, BART, mT5 | Translation, Summarization | Separate comprehension & generation |
When to Use Each
Encoder-Only:
# Good for: Understanding tasks
# Example: Sentiment classification
input: "This movie is great!"
output: Positive (0.95)
Decoder-Only:
# Good for: Open-ended generation
# Example: Continue text
input: "Once upon a time"
output: "there was a brave knight who..."
Encoder-Decoder:
# Good for: Constrained transformation
# Example: Translation
input: "Hello, how are you?"
output: "Bonjour, comment allez-vous?"
Modern Trend:
Decoder-only models (GPT-style) have become dominant because:
- Simpler architecture (one component)
- Scale better to huge sizes
- Can do both understanding and generation
- Easier to train with next-token prediction
However, encoder-decoder models still excel at tasks requiring explicit input-output separation.
Summary
The encoder-decoder transformer architecture consists of:
Encoder:
- Processes entire input in parallel
- Bidirectional self-attention
- Creates contextualized representations
Decoder:
- Generates output autoregressively
- Masked self-attention (causal)
- Cross-attention to encoder
- One token at a time during inference
Key Mechanisms:
- Self-attention: understand within sequence
- Cross-attention: connect input to output
- Masking: prevent future information leakage
- Residual connections: enable deep networks
- Layer normalization: stabilize training
This architecture revolutionized seq2seq tasks and inspired all modern transformer variants.