Foundation of Transformers

Encoder-Decoder Architecture

Comprehensive guide to the original transformer's encoder-decoder architecture. Learn how encoders process input and decoders generate output in sequence-to-sequence tasks.

20 min read · Transformer · Encoder · Decoder · Architecture

The original transformer uses an encoder-decoder architecture for sequence-to-sequence tasks like translation. This lesson explains how both components work together and how they differ from encoder-only (BERT) and decoder-only (GPT) models.

The Encoder-Decoder Paradigm

Sequence-to-Sequence (Seq2Seq): A neural network architecture pattern that transforms one sequence into another potentially different-length sequence, using an encoder to compress the input and a decoder to generate the output.

Sequence-to-Sequence Tasks

Some tasks require mapping one sequence to another:

  • Machine Translation: English → French
  • Summarization: Long document → Summary
  • Speech Recognition: Audio → Text
  • Image Captioning: Image → Text description

The Two-Stage Process

Encoder: Understands the input

  • Processes source sequence
  • Creates rich representations
  • Bidirectional (can see entire input)

Decoder: Generates the output

  • Creates target sequence one token at a time
  • Attends to encoder outputs
  • Autoregressive (sees only previous outputs)

Encoder vs Decoder:

  • Encoder: "What does this mean?" (comprehension)
  • Decoder: "How do I say this?" (generation)

Think of translation: the encoder understands the English sentence, the decoder writes the French equivalent.

The Encoder

The encoder processes the entire input sequence to create contextual representations.

Encoder Layer Structure

Each encoder layer has two sub-layers:

  1. Multi-Head Self-Attention
  2. Feed-Forward Network

Both use residual connections and layer normalization.

Complete Encoder Implementation

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    """Single encoder layer from 'Attention Is All You Need'"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension (e.g., 512)
            num_heads: Number of attention heads (e.g., 8)
            d_ff: Feed-forward inner dimension (e.g., 2048)
            dropout: Dropout probability
        """
        super(EncoderLayer, self).__init__()

        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        """
        Args:
            x: Input (batch, seq_len, d_model)
            src_mask: Attention mask for source

        Returns:
            output: (batch, seq_len, d_model)
        """
        # Self-attention with residual connection
        attn_output, _ = self.self_attn(
            x, x, x,
            attn_mask=src_mask,
            need_weights=False
        )
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)

        # Feed-forward with residual connection
        ffn_output = self.ffn(x)
        x = x + self.dropout2(ffn_output)
        x = self.norm2(x)

        return x


class Encoder(nn.Module):
    """Stack of N encoder layers"""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout=0.1, max_len=5000):
        """
        Args:
            vocab_size: Size of vocabulary
            d_model: Model dimension
            num_layers: Number of encoder layers (N)
            num_heads: Number of attention heads
            d_ff: Feed-forward dimension
            dropout: Dropout probability
            max_len: Maximum sequence length for positional encoding
        """
        super(Encoder, self).__init__()

        self.d_model = d_model

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.pos_encoding = self.create_positional_encoding(max_len, d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Stack of encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def create_positional_encoding(self, max_len, d_model):
        """Create sinusoidal positional encodings"""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        return pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, src, src_mask=None):
        """
        Args:
            src: Source tokens (batch, src_len)
            src_mask: Source mask

        Returns:
            output: Encoder output (batch, src_len, d_model)
        """
        batch_size, seq_len = src.size()

        # Embedding + positional encoding
        x = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        x = x + self.pos_encoding[:, :seq_len, :].to(src.device)
        x = self.dropout(x)

        # Pass through encoder layers
        for layer in self.layers:
            x = layer(x, src_mask)

        return x


# Example usage
vocab_size = 10000
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048

encoder = Encoder(vocab_size, d_model, num_layers, num_heads, d_ff)

# Sample input
batch_size = 2
src_len = 10
src = torch.randint(0, vocab_size, (batch_size, src_len))

# Forward pass
encoder_output = encoder(src)
print("Encoder output shape:", encoder_output.shape)  # (2, 10, 512)

What the Encoder Learns

The encoder creates contextualized representations where each token's embedding incorporates information from the entire sequence.

python
# Input: "The cat sat on the mat"
# Token embeddings (before encoder): independent, context-free
# Encoder output: each token embedding now contains context

# For example, "mat" in encoder output understands:
# - It's a noun (from surrounding words)
# - It's related to "sat" (positional relationship)
# - It's associated with "cat" (semantic relationship)

Bidirectional Context:

The encoder can see the entire input simultaneously:

  • "sat" sees "The cat" to its left
  • AND "on the mat" to its right

This bidirectionality makes encoders excellent for understanding tasks (classification, NER, etc.).
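
To make this concrete, here is a small sketch (toy dimensions, untrained weights, PyTorch's built-in `nn.MultiheadAttention`) contrasting unmasked, encoder-style attention with the causal masking the decoder will use:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
x = torch.randn(1, 4, 8)  # batch of 1, sequence of 4 tokens

# Encoder-style (bidirectional): no mask, every position attends to all 4 tokens
_, w_bidir = attn(x, x, x)

# Decoder-style (causal): -inf above the diagonal blocks future positions
mask = torch.triu(torch.ones(4, 4) * float('-inf'), diagonal=1)
_, w_causal = attn(x, x, x, attn_mask=mask)

print((w_bidir > 0).all())                               # every weight nonzero
print(torch.triu(w_causal[0], diagonal=1).abs().sum())   # attention to the future: zero
```

The upper triangle of the causal attention weights is exactly zero: no position attends to a later one.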

The Decoder

The decoder generates output tokens sequentially, using both self-attention and cross-attention.

Decoder Layer Structure

Each decoder layer has three sub-layers:

  1. Masked Multi-Head Self-Attention (on target sequence)
  2. Multi-Head Cross-Attention (attending to encoder output)
  3. Feed-Forward Network

Complete Decoder Implementation

python
class DecoderLayer(nn.Module):
    """Single decoder layer from 'Attention Is All You Need'"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()

        # Masked self-attention (for target sequence)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Cross-attention (attending to encoder output)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        """
        Args:
            x: Decoder input (batch, tgt_len, d_model)
            encoder_output: Encoder output (batch, src_len, d_model)
            tgt_mask: Causal mask for target sequence
            src_mask: Mask for source sequence

        Returns:
            output: (batch, tgt_len, d_model)
        """
        # 1. Masked self-attention on target sequence
        attn_output, _ = self.self_attn(
            x, x, x,
            attn_mask=tgt_mask,
            need_weights=False
        )
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)

        # 2. Cross-attention: attend to encoder output
        # Query from decoder, Key & Value from encoder
        attn_output, _ = self.cross_attn(
            x, encoder_output, encoder_output,
            attn_mask=src_mask,
            need_weights=False
        )
        x = x + self.dropout2(attn_output)
        x = self.norm2(x)

        # 3. Feed-forward
        ffn_output = self.ffn(x)
        x = x + self.dropout3(ffn_output)
        x = self.norm3(x)

        return x


class Decoder(nn.Module):
    """Stack of N decoder layers"""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout=0.1, max_len=5000):
        super(Decoder, self).__init__()

        self.d_model = d_model

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.pos_encoding = self.create_positional_encoding(max_len, d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Output projection
        self.output_proj = nn.Linear(d_model, vocab_size)

    def create_positional_encoding(self, max_len, d_model):
        """Create sinusoidal positional encodings"""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        return pe.unsqueeze(0)

    def forward(self, tgt, encoder_output, tgt_mask=None, src_mask=None):
        """
        Args:
            tgt: Target tokens (batch, tgt_len)
            encoder_output: Encoder output (batch, src_len, d_model)
            tgt_mask: Causal mask for target
            src_mask: Mask for source

        Returns:
            output: Logits (batch, tgt_len, vocab_size)
        """
        batch_size, seq_len = tgt.size()

        # Embedding + positional encoding
        x = self.embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        x = x + self.pos_encoding[:, :seq_len, :].to(tgt.device)
        x = self.dropout(x)

        # Pass through decoder layers
        for layer in self.layers:
            x = layer(x, encoder_output, tgt_mask, src_mask)

        # Project to vocabulary
        output = self.output_proj(x)

        return output


# Example usage
tgt_vocab_size = 10000
decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff)

# Sample input
tgt_len = 8
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))

# Create causal mask
causal_mask = torch.triu(torch.ones(tgt_len, tgt_len) * float('-inf'), diagonal=1)

# Forward pass
decoder_output = decoder(tgt, encoder_output, tgt_mask=causal_mask)
print("Decoder output shape:", decoder_output.shape)  # (2, 8, 10000)

Masked Self-Attention

Causal Mask: A triangular mask applied during attention that prevents each position from attending to future positions, ensuring the model can only use information from the current and previous positions (required for autoregressive generation).

The decoder uses a causal mask to prevent positions from attending to future tokens:

python
def create_causal_mask(seq_len):
    """
    Create causal (look-ahead) mask

    Returns a mask with -inf above the diagonal and 0 on/below it:
    [[0, -inf, -inf],
     [0,    0, -inf],
     [0,    0,    0]]
    """
    mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
    return mask


# Example
mask = create_causal_mask(5)
print("Causal mask:")
print(mask)

# Interpretation:
# Position 0: can only see position 0
# Position 1: can see positions 0-1
# Position 2: can see positions 0-2
# etc.
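
The effect is easiest to see numerically: adding the mask before the softmax forces every weight on a future position to exactly zero (self-contained sketch; the helper is repeated so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F

def create_causal_mask(seq_len):
    # -inf above the diagonal, 0 on and below it
    return torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)

scores = torch.randn(3, 3)  # raw attention scores for a length-3 sequence
weights = F.softmax(scores + create_causal_mask(3), dim=-1)

# exp(-inf) = 0, so each row's weights on future positions vanish:
# row 0 puts everything on position 0, row 1 spreads over positions 0-1, etc.
print(weights)
```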

Why Masking?

During training, we have the full target sequence. Without masking, the decoder could "cheat" by seeing future tokens.

Example (translation training):

  • Target: "Le chat dort" (The cat sleeps)
  • Without mask: predicting "chat" can see "dort"
  • With mask: predicting "chat" only sees "Le" (realistic)

This ensures the model learns to generate sequentially, matching inference conditions.

Cross-Attention

Cross-attention allows the decoder to focus on relevant parts of the input:

python
# Cross-attention flow:
# 1. Query (Q): "What do I need from the source?"  ← from decoder
# 2. Key (K): "What does the source offer?"        ← from encoder
# 3. Value (V): "Source information to retrieve"   ← from encoder

# Example: Translating "The cat sleeps" → "Le chat dort"
# When generating "chat":
#   - Decoder queries what it needs
#   - High attention to encoder's "cat" representation
#   - Retrieves "cat" information to generate "chat"
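
The defining shape property of cross-attention: queries come from the target side, keys and values from the source side, so the attention weight matrix is (tgt_len, src_len). A toy check with `nn.MultiheadAttention` (untrained, arbitrary dimensions):

```python
import torch
import torch.nn as nn

d_model = 16
cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=2, batch_first=True)

decoder_states = torch.randn(1, 3, d_model)  # 3 target tokens so far (queries)
encoder_output = torch.randn(1, 5, d_model)  # 5 source tokens (keys and values)

out, weights = cross_attn(decoder_states, encoder_output, encoder_output)
print(out.shape)      # torch.Size([1, 3, 16]): one vector per target position
print(weights.shape)  # torch.Size([1, 3, 5]): each target token's weights over 5 source tokens
```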

Complete Transformer

Combining encoder and decoder:

python
class Transformer(nn.Module):
    """Complete Transformer (Encoder-Decoder)"""

    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model=512,
        num_layers=6,
        num_heads=8,
        d_ff=2048,
        dropout=0.1,
        max_len=5000
    ):
        super(Transformer, self).__init__()

        # Encoder
        self.encoder = Encoder(
            src_vocab_size, d_model, num_layers,
            num_heads, d_ff, dropout, max_len
        )

        # Decoder
        self.decoder = Decoder(
            tgt_vocab_size, d_model, num_layers,
            num_heads, d_ff, dropout, max_len
        )

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: Source tokens (batch, src_len)
            tgt: Target tokens (batch, tgt_len)
            src_mask: Source attention mask
            tgt_mask: Target causal mask

        Returns:
            output: Logits (batch, tgt_len, tgt_vocab_size)
        """
        # Encode source
        encoder_output = self.encoder(src, src_mask)

        # Decode to generate target
        decoder_output = self.decoder(tgt, encoder_output, tgt_mask, src_mask)

        return decoder_output

    def encode(self, src, src_mask=None):
        """Encode source sequence"""
        return self.encoder(src, src_mask)

    def decode(self, tgt, encoder_output, tgt_mask=None, src_mask=None):
        """Decode given encoder output"""
        return self.decoder(tgt, encoder_output, tgt_mask, src_mask)


# Create model
src_vocab_size = 10000
tgt_vocab_size = 8000

model = Transformer(src_vocab_size, tgt_vocab_size)

# Sample data
batch_size = 2
src_len = 10
tgt_len = 8

src = torch.randint(0, src_vocab_size, (batch_size, src_len))
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))

# Create masks
tgt_mask = create_causal_mask(tgt_len)

# Forward pass
output = model(src, tgt, tgt_mask=tgt_mask)
print("Output shape:", output.shape)  # (2, 8, 8000)

Training the Transformer

Training Loop

python
import torch.optim as optim

# Model, loss, optimizer
model = Transformer(src_vocab_size, tgt_vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

def train_step(model, src, tgt, criterion, optimizer):
    """
    Single training step

    Args:
        src: Source tokens (batch, src_len)
        tgt: Target tokens (batch, tgt_len)
    """
    model.train()

    # Prepare target: input = tgt[:-1], label = tgt[1:]
    tgt_input = tgt[:, :-1]
    tgt_label = tgt[:, 1:]

    # Create causal mask
    tgt_len = tgt_input.size(1)
    tgt_mask = create_causal_mask(tgt_len).to(src.device)

    # Forward pass
    output = model(src, tgt_input, tgt_mask=tgt_mask)

    # Reshape for loss computation
    output = output.reshape(-1, output.size(-1))
    tgt_label = tgt_label.reshape(-1)

    # Compute loss
    loss = criterion(output, tgt_label)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


# Training loop (pseudo-code)
for epoch in range(num_epochs):
    for batch in data_loader:
        src, tgt = batch
        loss = train_step(model, src, tgt, criterion, optimizer)
        print(f"Loss: {loss:.4f}")
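
One caveat: the fixed lr=0.0001 above is a simplification. "Attention Is All You Need" pairs Adam with a warmup-then-decay schedule, lr = d_model^-0.5 · min(step^-0.5, step · warmup^-1.5). A sketch with `LambdaLR` (the single dummy parameter stands in for `model.parameters()`):

```python
import torch
import torch.optim as optim

d_model, warmup_steps = 512, 4000
params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
optimizer = optim.Adam(params, lr=1.0, betas=(0.9, 0.98), eps=1e-9)

def noam_lambda(step):
    step = max(step, 1)  # LambdaLR calls with step=0 once at construction
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)

# The learning rate rises linearly for warmup_steps, then decays as 1/sqrt(step)
for step in range(1, 101):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # still tiny and climbing during warmup
```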

Teacher Forcing

Teacher Forcing: A training technique where the model receives ground-truth previous tokens as input (rather than its own predictions) during training, accelerating learning but creating a potential train-test mismatch.

During training, we use teacher forcing: feed ground-truth tokens as decoder input, not model predictions.

python
# Teacher forcing:
# Input:  <BOS> Le chat dort
# Target: Le chat dort <EOS>

# The model predicts:
# Position 0: predicts "Le" given "<BOS>"
# Position 1: predicts "chat" given "<BOS> Le" (ground truth "Le")
# Position 2: predicts "dort" given "<BOS> Le chat" (ground truth)
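
This is exactly the one-position shift that train_step performs with `tgt[:, :-1]` and `tgt[:, 1:]`. On concrete (made-up) token IDs, with `<BOS>`=1 and `<EOS>`=2:

```python
import torch

# "<BOS> Le chat dort <EOS>" as hypothetical token IDs
tgt = torch.tensor([[1, 57, 304, 891, 2]])

tgt_input = tgt[:, :-1]  # [[1, 57, 304, 891]]  "<BOS> Le chat dort"
tgt_label = tgt[:, 1:]   # [[57, 304, 891, 2]]  "Le chat dort <EOS>"

# At position i the decoder sees the ground-truth prefix tgt_input[:, :i+1]
# (teacher forcing) and is trained to predict tgt_label[:, i]
print(tgt_input.tolist(), tgt_label.tolist())
```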

Why Teacher Forcing?

  • Faster training: Always use correct previous tokens
  • Stable gradients: Errors don't compound
  • Drawback: Train-test mismatch (at inference, model uses its own predictions)

Advanced techniques like scheduled sampling gradually mix teacher forcing with model predictions.
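
The core move in scheduled sampling is a per-position coin flip between the ground-truth token and the model's own prediction. A minimal sketch (`mix_tokens` is a hypothetical helper, not part of the code above; in practice use_model_prob is ramped up over training):

```python
import torch

def mix_tokens(ground_truth, model_predictions, use_model_prob):
    """Per position, keep the ground-truth token, or swap in the model's
    own prediction with probability use_model_prob."""
    use_model = torch.rand(ground_truth.shape) < use_model_prob
    return torch.where(use_model, model_predictions, ground_truth)

gt = torch.tensor([[1, 57, 304, 891]])
pred = torch.tensor([[1, 57, 999, 891]])  # the model got one token wrong
mixed = mix_tokens(gt, pred, use_model_prob=0.25)
print(mixed)  # each position comes from gt or pred, chosen independently
```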

Inference (Generation)

Autoregressive Generation: A text generation process where the model produces one token at a time, feeding each newly generated token back as input to generate the next token, continuing until an end condition is met.

At inference, we generate one token at a time:

python
def generate(model, src, max_len=50, start_token=1, end_token=2):
    """
    Autoregressive generation

    Args:
        model: Trained transformer
        src: Source tokens (batch, src_len)
        max_len: Maximum generation length
        start_token: <BOS> token ID
        end_token: <EOS> token ID

    Returns:
        generated: Generated token IDs (batch, gen_len)
    """
    model.eval()
    batch_size = src.size(0)

    # Encode source once
    with torch.no_grad():
        encoder_output = model.encode(src)

    # Initialize target with <BOS>
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=src.device)

    for _ in range(max_len - 1):
        # Create causal mask
        tgt_len = tgt.size(1)
        tgt_mask = create_causal_mask(tgt_len).to(src.device)

        # Decode
        with torch.no_grad():
            output = model.decode(tgt, encoder_output, tgt_mask=tgt_mask)

        # Get next token (greedy)
        next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)

        # Append to sequence
        tgt = torch.cat([tgt, next_token], dim=1)

        # Stop if all sequences generated <EOS>
        if (next_token == end_token).all():
            break

    return tgt


# Example generation
src = torch.randint(0, src_vocab_size, (1, 10))
generated = generate(model, src)
print("Generated sequence:", generated)

Decoding Strategies

1. Greedy Decoding

python
next_token = output.argmax(dim=-1)

  • Simple, fast
  • May miss better sequences

2. Beam Search

python
# Keep top-k hypotheses at each step
# Explore multiple paths simultaneously

  • Better quality
  • More computation
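
To make those comments concrete, here is a toy beam search. It drives a hypothetical step_fn that returns next-token log-probabilities, rather than the Transformer class above, so it stays short and runnable:

```python
import torch

def beam_search(step_fn, start_token, beam_width=3, max_len=5):
    """Toy beam search: step_fn(seq) -> log-probs over the vocabulary."""
    beams = [([start_token], 0.0)]  # (sequence, cumulative log-prob)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            top = torch.topk(step_fn(seq), beam_width)
            for lp, tok in zip(top.values, top.indices):
                candidates.append((seq + [tok.item()], score + lp.item()))
        # Keep only the beam_width highest-scoring hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0][0]  # best-scoring sequence

# Dummy "model": strongly prefers token (last_token + 1) mod 10
def step_fn(seq):
    logits = torch.zeros(10)
    logits[(seq[-1] + 1) % 10] = 5.0
    return torch.log_softmax(logits, dim=-1)

print(beam_search(step_fn, start_token=1))  # [1, 2, 3, 4, 5, 6]
```

At each step the beam keeps beam_width partial hypotheses instead of a single greedy choice, which is how it can recover sequences greedy decoding misses.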

3. Sampling

python
probs = F.softmax(output / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

  • More diverse
  • Controllable via temperature

Encoder-Only vs Decoder-Only vs Encoder-Decoder

Comparison

| Architecture | Examples | Use Cases | Key Feature |
| --- | --- | --- | --- |
| Encoder-Only | BERT, RoBERTa | Classification, NER, QA | Bidirectional context |
| Decoder-Only | GPT, LLaMA, Claude | Text generation, Chat | Autoregressive generation |
| Encoder-Decoder | T5, BART, mT5 | Translation, Summarization | Separate comprehension & generation |

When to Use Each

Encoder-Only:

python
# Good for: Understanding tasks
# Example: Sentiment classification
input: "This movie is great!"
output: Positive (0.95)

Decoder-Only:

python
# Good for: Open-ended generation
# Example: Continue text
input: "Once upon a time"
output: "there was a brave knight who..."

Encoder-Decoder:

python
# Good for: Constrained transformation
# Example: Translation
input: "Hello, how are you?"
output: "Bonjour, comment allez-vous?"

Modern Trend:

Decoder-only models (GPT-style) have become dominant because:

  1. Simpler architecture (one component)
  2. Better scaling to very large model sizes
  3. A single model for both understanding and generation
  4. A simple training objective (next-token prediction)

However, encoder-decoder models still excel at tasks requiring explicit input-output separation.

Summary

The encoder-decoder transformer architecture consists of:

Encoder:

  • Processes entire input in parallel
  • Bidirectional self-attention
  • Creates contextualized representations

Decoder:

  • Generates output autoregressively
  • Masked self-attention (causal)
  • Cross-attention to encoder
  • One token at a time during inference

Key Mechanisms:

  • Self-attention: understand within sequence
  • Cross-attention: connect input to output
  • Masking: prevent future information leakage
  • Residual connections: enable deep networks
  • Layer normalization: stabilize training

This architecture revolutionized seq2seq tasks and inspired all modern transformer variants.