intermediate
Modern LLM Architectures

Decoder-Only Models (GPT, LLaMA)

Understand why decoder-only architectures dominate modern LLMs, how causal masking enables autoregressive generation, and the success of models like GPT and LLaMA

15 min read


Decoder-only models have become the dominant architecture for large language models. From GPT to LLaMA to Claude, most modern LLMs use this simple yet powerful design. Let's understand why decoder-only architectures work so well and how they generate text.

Why Decoder-Only Works

Decoder-Only Architecture: A transformer architecture using only the decoder stack with causal attention, processing sequences left-to-right. Despite being designed for generation, it achieves strong performance on both generation and understanding tasks through scale.

Architectural Simplicity

python
"""
Transformer Architecture Comparison:

1. Original Transformer (Encoder-Decoder):
   - Encoder: Bidirectional attention
   - Decoder: Causal attention + cross-attention
   - Use case: Translation, seq2seq tasks
   - Complexity: Two separate stacks

2. Encoder-Only (BERT):
   - Only encoder: Bidirectional attention
   - Use case: Understanding, classification
   - Cannot generate text naturally

3. Decoder-Only (GPT):
   - Only decoder: Causal attention
   - Use case: Generation, understanding, everything!
   - Simplicity: Single stack
   - Success: Scales incredibly well
"""

import torch
import torch.nn as nn

class DecoderOnlyAdvantages:
    """Why decoder-only models dominate"""

    def __init__(self):
        self.advantages = {
            'simplicity': {
                'description': 'Single stack of transformer blocks',
                'benefit': 'Easier to implement, debug, and scale',
                'example': 'Same basic architecture from GPT (117M) to GPT-3 (175B params)'
            },
            'unified_objective': {
                'description': 'Single pre-training task (next-token prediction)',
                'benefit': 'No need to balance multiple objectives',
                'example': 'vs BERT with MLM + NSP'
            },
            'generation_native': {
                'description': 'Designed for autoregressive generation',
                'benefit': 'Natural text generation without tricks',
                'example': 'GPT generates coherent long-form text'
            },
            'task_flexibility': {
                'description': 'Handles both understanding and generation',
                'benefit': 'Single model for all tasks',
                'example': 'GPT-3 does QA, translation, summarization, etc.'
            },
            'scaling_properties': {
                'description': 'Clean scaling laws observed',
                'benefit': 'Predictable performance improvements',
                'example': 'Loss decreases predictably with model size'
            }
        }

    def compare_architectures(self):
        """Compare different transformer architectures"""
        print("Transformer Architecture Comparison:\n")

        comparisons = {
            'Encoder-Decoder (T5)': {
                'attention_types': 'Bidirectional + Causal + Cross',
                'components': 'Encoder stack + Decoder stack',
                'parameters': '~2x for the same depth and width (two stacks)',
                'best_for': 'Seq2seq tasks (translation)'
            },
            'Encoder-Only (BERT)': {
                'attention_types': 'Bidirectional only',
                'components': 'Encoder stack only',
                'parameters': '1x',
                'best_for': 'Understanding tasks (classification)'
            },
            'Decoder-Only (GPT)': {
                'attention_types': 'Causal only',
                'components': 'Decoder stack only',
                'parameters': '1x',
                'best_for': 'Generation + understanding (everything!)'
            }
        }

        for arch, specs in comparisons.items():
            print(f"{arch}:")
            for key, value in specs.items():
                print(f"  {key}: {value}")
            print()

# Demonstrate advantages
advantages = DecoderOnlyAdvantages()
advantages.compare_architectures()

Emergent Understanding: Despite being trained only on next-token prediction (a generation task), decoder-only models develop strong understanding capabilities. This suggests that predicting text well at scale requires much of what we call understanding, so the two abilities are closely linked rather than separate.
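The whole pre-training objective fits in a few lines: shift the sequence by one position so that the logits at position t are scored against the token at position t + 1. The sketch below uses random logits as a stand-in for a real model's output (an assumption for illustration), and `next_token_loss` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Causal LM loss: logits[:, t] is trained to predict input_ids[:, t + 1].

    logits: [batch, seq_len, vocab_size], input_ids: [batch, seq_len]
    """
    shift_logits = logits[:, :-1, :]  # last position has no next token to predict
    shift_labels = input_ids[:, 1:]   # first token is never a prediction target
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )

torch.manual_seed(0)
vocab_size, batch, seq_len = 100, 2, 8
input_ids = torch.randint(0, vocab_size, (batch, seq_len))
logits = torch.randn(batch, seq_len, vocab_size)  # stand-in for model(input_ids)

loss = next_token_loss(logits, input_ids)
print(f"loss = {loss.item():.3f}")
```

Every position in the sequence contributes a training signal in a single forward pass, which is one reason this objective is so data-efficient to train on.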

Causal Masking

Causal Attention: An attention mechanism where each token can only attend to itself and previous tokens (not future ones), implemented using a triangular mask. This enables autoregressive generation where the model predicts one token at a time.

The key mechanism that enables autoregressive generation.
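In practice the mask is often applied as an additive bias (0 where attention is allowed, -inf where it is not) before the softmax, because exp(-inf) = 0 removes future positions exactly. The sketch below builds that bias and compares the result with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention(..., is_causal=True)`; it assumes PyTorch 2.0 or later, and `manual_causal_attention` is an illustrative name.

```python
import torch
import torch.nn.functional as F

def manual_causal_attention(q, k, v):
    """q, k, v: [batch, heads, seq_len, d_k]. Returns (output, attention weights)."""
    seq_len, d_k = q.size(-2), q.size(-1)
    # Additive bias: 0 on and below the diagonal, -inf above it (future keys)
    bias = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5 + bias
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))

out_manual, weights = manual_causal_attention(q, k, v)
# The fused kernel applies the same lower-triangular mask when is_causal=True
out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(weights[0, 0])  # upper triangle is exactly 0
print(torch.allclose(out_manual, out_fused, atol=1e-5))
```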

Causal Attention Implementation

python
"""
Causal Masking: Each position can only attend to itself and previous positions

Without mask (bidirectional):
Position 0 sees: [0, 1, 2, 3]
Position 1 sees: [0, 1, 2, 3]
Position 2 sees: [0, 1, 2, 3]
Position 3 sees: [0, 1, 2, 3]

With causal mask (unidirectional):
Position 0 sees: [0]
Position 1 sees: [0, 1]
Position 2 sees: [0, 1, 2]
Position 3 sees: [0, 1, 2, 3]
"""

def create_causal_mask(seq_len):
    """
    Create causal attention mask

    Returns:
        mask: Lower triangular matrix [seq_len, seq_len]
        0 = masked (cannot attend), 1 = visible (can attend)
    """
    # Create lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len))

    return mask

def visualize_causal_mask():
    """Visualize causal attention mask"""
    import matplotlib.pyplot as plt

    seq_len = 8
    mask = create_causal_mask(seq_len)

    plt.figure(figsize=(8, 8))
    plt.imshow(mask, cmap='Blues', interpolation='nearest')
    plt.xlabel('Key Position (attending to)')
    plt.ylabel('Query Position (attending from)')
    plt.title('Causal Attention Mask\n(1=can attend, 0=masked)')

    # Add grid
    for i in range(seq_len + 1):
        plt.axhline(i - 0.5, color='gray', linewidth=0.5)
        plt.axvline(i - 0.5, color='gray', linewidth=0.5)

    plt.xticks(range(seq_len))
    plt.yticks(range(seq_len))
    plt.colorbar()
    plt.tight_layout()
    plt.savefig('causal_mask.png', dpi=150)

    print("Causal mask visualization saved")
    print("\nMask matrix:")
    print(mask.int())
    print("\nInterpretation:")
    print("Position 0 can only see position 0")
    print("Position 1 can see positions 0-1")
    print("Position 2 can see positions 0-2")
    print("Position 7 can see positions 0-7")

visualize_causal_mask()

class CausalSelfAttention(nn.Module):
    """Causal self-attention mechanism"""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q, K, V projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Causal mask buffer (registered as buffer, not parameter)
        self.register_buffer(
            'causal_mask',
            torch.tril(torch.ones(1024, 1024)).view(1, 1, 1024, 1024)
        )

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # Project Q, K, V
        Q = self.q_proj(x)  # [batch, seq_len, d_model]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Now: [batch, num_heads, seq_len, d_k]

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        # scores: [batch, num_heads, seq_len, seq_len]

        # Apply causal mask
        mask = self.causal_mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax and dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        out = torch.matmul(attn_weights, V)
        # out: [batch, num_heads, seq_len, d_k]

        # Concatenate heads
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, d_model)

        # Output projection
        out = self.out_proj(out)

        return out, attn_weights

# Demonstrate causal attention
def demonstrate_causal_attention():
    """Show how causal masking affects attention"""

    attn = CausalSelfAttention(d_model=512, num_heads=8)
    attn.eval()  # disable dropout so each row of attention weights sums to 1

    # Create sample input
    batch_size, seq_len = 1, 6
    x = torch.randn(batch_size, seq_len, 512)

    # Forward pass
    output, attn_weights = attn(x)

    # Show attention pattern for one head
    print("Causal Attention Pattern (Head 0):")
    print(attn_weights[0, 0].detach().numpy())
    print("\nNotice: Each row (query) only attends to itself and previous positions")
    print("Upper triangle is all zeros (masked)")

demonstrate_causal_attention()

Why Masking Works: Causal masking prevents information leakage during training. Without it, the model could "cheat" by looking at future tokens when predicting the current token, which wouldn't be possible during generation.
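The "no cheating" property can be checked directly: change the token at position t, and the causally masked outputs at every earlier position must stay the same. A minimal sketch using single-head attention with Q = K = V = x (an illustrative simplification, not the `CausalSelfAttention` class above):

```python
import torch

def causal_attention(x):
    """Single-head causal self-attention with Q = K = V = x. x: [seq_len, d]."""
    seq_len, d = x.shape
    scores = x @ x.T / d ** 0.5
    mask = torch.tril(torch.ones(seq_len, seq_len))
    scores = scores.masked_fill(mask == 0, float('-inf'))  # hide future positions
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(6, 16)
x_changed = x.clone()
x_changed[4] = torch.randn(16)  # alter the "future" token at position 4

out, out_changed = causal_attention(x), causal_attention(x_changed)
print(torch.allclose(out[:4], out_changed[:4]))  # positions 0-3 are unaffected
print(torch.allclose(out[4:], out_changed[4:]))  # positions 4-5 can see the change
```

Without the `masked_fill` line, every position would change, which is exactly the leakage that would let the model read its own training targets.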

Autoregressive Generation

How decoder-only models generate text one token at a time.

Generation Process

python
"""
Autoregressive Generation Process:

1. Start with prompt/context
2. Model predicts probability distribution over next token
3. Sample or select next token
4. Append to sequence
5. Repeat until done

Example:
Prompt: "The cat sat"
Step 1: "The cat sat" → predict "on" (or other options)
Step 2: "The cat sat on" → predict "the"
Step 3: "The cat sat on the" → predict "mat"
Result: "The cat sat on the mat"
"""

class AutoregressiveGenerator:
    """Autoregressive text generation"""

    def __init__(self, model, tokenizer, max_length=50):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    @torch.no_grad()
    def generate_greedy(self, prompt, max_new_tokens=20):
        """
        Greedy decoding: Always pick most probable token

        Args:
            prompt: Input text
            max_new_tokens: Number of tokens to generate

        Returns:
            Generated text
        """
        # Tokenize prompt
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        for _ in range(max_new_tokens):
            # Forward pass
            outputs = self.model(input_ids)
            logits = outputs.logits  # [batch, seq_len, vocab_size]

            # Get logits for last position
            next_token_logits = logits[0, -1, :]

            # Greedy: select most probable token
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop if EOS token
            if next_token.item() == self.tokenizer.eos_token_id:
                break

        # Decode
        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text

    @torch.no_grad()
    def generate_sampling(self, prompt, max_new_tokens=20, temperature=1.0, top_p=0.9):
        """
        Sampling decoding: Sample from probability distribution

        Args:
            prompt: Input text
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling threshold

        Returns:
            Generated text
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        for _ in range(max_new_tokens):
            outputs = self.model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]

            # Apply temperature
            next_token_logits = next_token_logits / temperature

            # Top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits[indices_to_remove] = float('-inf')

            # Sample from distribution
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            if next_token.item() == self.tokenizer.eos_token_id:
                break

        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text

# Demonstrate generation strategies
def demonstrate_generation():
    """Show different generation strategies"""
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    prompt = "The future of artificial intelligence is"

    print("Autoregressive Generation Strategies:\n")
    print(f"Prompt: '{prompt}'\n")

    # Greedy decoding
    outputs = model.generate(
        tokenizer.encode(prompt, return_tensors='pt'),
        max_new_tokens=20,
        do_sample=False  # Greedy
    )
    greedy_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Greedy: {greedy_text}\n")

    # Sampling (temperature=0.7)
    outputs = model.generate(
        tokenizer.encode(prompt, return_tensors='pt'),
        max_new_tokens=20,
        do_sample=True,
        temperature=0.7
    )
    sampling_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Sampling (T=0.7): {sampling_text}\n")

    # Top-p sampling
    outputs = model.generate(
        tokenizer.encode(prompt, return_tensors='pt'),
        max_new_tokens=20,
        do_sample=True,
        top_p=0.9,
        temperature=0.8
    )
    topp_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Top-p (p=0.9): {topp_text}\n")

    # Beam search
    outputs = model.generate(
        tokenizer.encode(prompt, return_tensors='pt'),
        max_new_tokens=20,
        num_beams=4,
        early_stopping=True
    )
    beam_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Beam Search: {beam_text}")

demonstrate_generation()
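Temperature and top-p are easiest to understand on a single toy distribution. The sketch below pulls the filtering logic from `generate_sampling` into a standalone function (`top_p_filter` is a hypothetical name) and applies it to five made-up logits.

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Keep the smallest set of top tokens whose cumulative probability reaches top_p."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    remove[1:] = remove[:-1].clone()  # shift right: keep the token that crosses top_p
    remove[0] = False                 # always keep the most likely token
    filtered = logits.clone()
    filtered[sorted_idx[remove]] = float('-inf')
    return filtered

logits = torch.tensor([3.0, 2.0, 1.0, 0.0, -1.0])

# Temperature rescales logits before the softmax
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")

probs = torch.softmax(top_p_filter(logits, top_p=0.9), dim=-1)
print(f"top-p=0.9: {[round(p, 3) for p in probs.tolist()]}")
```

Lower temperature concentrates probability on the top token and higher temperature flattens the distribution. Top-p then drops the low-probability tail (here the last two tokens) and renormalizes, so sampling can never pick them.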

KV Cache Optimization

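KV Cache: An inference optimization that stores the key and value tensors already computed for previous tokens. At each generation step, only the new token's keys and values are computed and appended, and its query attends over the cached entries. This avoids recomputing K and V for the whole prefix and dramatically speeds up inference, at the cost of memory that grows linearly with sequence length.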

python
"""
KV Cache: Optimization for autoregressive generation

Problem:
- At each step, recompute attention for entire sequence
- Wasteful: previous positions' K and V don't change

Solution:
- Cache K and V from previous steps
- Only compute K, V for new token
- Dramatically faster generation
"""

class CausalAttentionWithKVCache(nn.Module):
    """Efficient causal attention with KV caching"""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None):
        """
        Forward with optional KV cache

        Args:
            x: Input [batch, seq_len, d_model]
            past_kv: Cached (K, V) from previous steps

        Returns:
            output: Attention output
            new_kv: Updated KV cache
        """
        batch_size, seq_len, d_model = x.shape

        # Compute Q, K, V for current input
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Reshape
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Use cached K, V if available
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)  # Concatenate along seq dimension
            V = torch.cat([past_V, V], dim=2)

        # Compute attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Causal mask. The new queries sit at absolute positions
        # past_len .. total_len - 1, so query i may attend to keys 0 .. past_len + i.
        # (Without a cache, past_len = 0 and this is the usual lower-triangular mask.)
        total_len = K.size(2)
        past_len = total_len - seq_len
        mask = torch.tril(torch.ones(seq_len, total_len, device=scores.device), diagonal=past_len)
        mask = mask.view(1, 1, seq_len, total_len)

        scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn_weights, V)

        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.out_proj(out)

        # Return output and new cache
        new_kv = (K, V)
        return out, new_kv

# Compare with and without KV cache
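@torch.no_grad()  # inference only: skip autograd bookkeeping
def benchmark_kv_cache():
    """Show speedup from KV caching"""
    import time

    # eval() disables dropout in CausalSelfAttention
    model_no_cache = CausalSelfAttention(d_model=768, num_heads=12).eval()
    model_with_cache = CausalAttentionWithKVCache(d_model=768, num_heads=12).eval()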

    # Simulate generation
    batch_size = 1
    prompt_len = 50
    generate_len = 50

    x_prompt = torch.randn(batch_size, prompt_len, 768)

    # Without cache: recompute everything each step
    start = time.time()
    x = x_prompt
    for i in range(generate_len):
        new_token = torch.randn(batch_size, 1, 768)
        x = torch.cat([x, new_token], dim=1)
        out, _ = model_no_cache(x)  # Recompute for entire sequence
    time_no_cache = time.time() - start

    # With cache: only compute new tokens
    start = time.time()
    out, kv_cache = model_with_cache(x_prompt, past_kv=None)
    for i in range(generate_len):
        new_token = torch.randn(batch_size, 1, 768)
        out, kv_cache = model_with_cache(new_token, past_kv=kv_cache)
    time_with_cache = time.time() - start

    print("KV Cache Benchmark:")
    print(f"Without cache: {time_no_cache:.4f}s")
    print(f"With cache: {time_with_cache:.4f}s")
    print(f"Speedup: {time_no_cache / time_with_cache:.2f}x")

benchmark_kv_cache()
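A useful correctness check for any KV cache is that incremental decoding reproduces full causal attention. The sketch below uses a minimal single-head attention with random projection matrices (illustrative assumptions, not the class above) and compares "prefill the prompt, then feed one token at a time" against a single full pass. The mask offset `diagonal=past_len` is what lets each new query see every cached position.

```python
import torch

torch.manual_seed(0)
d = 16
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

def attend(x, cache=None):
    """x: [new_len, d]; cache: (K, V) for earlier positions, or None."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    if cache is not None:
        k = torch.cat([cache[0], k], dim=0)  # append new keys after cached ones
        v = torch.cat([cache[1], v], dim=0)
    new_len, total_len = q.size(0), k.size(0)
    past_len = total_len - new_len
    # Query i sits at absolute position past_len + i: it may see keys 0 .. past_len + i
    mask = torch.tril(torch.ones(new_len, total_len), diagonal=past_len)
    scores = (q @ k.T / d ** 0.5).masked_fill(mask == 0, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v, (k, v)

x = torch.randn(10, d)       # 6 prompt tokens + 4 generated tokens
full_out, _ = attend(x)      # no cache: one pass over all 10 positions

out, cache = attend(x[:6])   # prefill the prompt
step_outs = [out]
for t in range(6, 10):       # then decode one token at a time
    out, cache = attend(x[t:t + 1], cache)
    step_outs.append(out)
cached_out = torch.cat(step_outs, dim=0)

print(torch.allclose(full_out, cached_out, atol=1e-5))
```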

Memory Trade-off: KV caching speeds up generation but increases memory usage. For very long sequences or large batch sizes, memory can become a bottleneck.
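The memory cost is easy to estimate: every layer stores one key vector and one value vector of width `num_heads * head_dim` per token. The sketch below assumes standard multi-head attention (no key/value sharing across heads) and fp16 storage (2 bytes per value), using the LLaMA-7B shape (32 layers, 32 heads, head_dim 128) as the example; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Bytes held by the KV cache: K and V (factor 2) per layer, per token."""
    return 2 * num_layers * num_heads * head_dim * seq_len * batch_size * bytes_per_value

per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
full_context = kv_cache_bytes(32, 32, 128, seq_len=4096)

print(f"per token: {per_token / 2**20:.2f} MiB")       # 0.50 MiB
print(f"4096 tokens: {full_context / 2**30:.2f} GiB")  # 2.00 GiB
```

The cost scales linearly with batch size: a batch of 8 full-length sequences would need about 16 GiB for the cache alone, on top of the model weights.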

Modern Decoder-Only Models

LLaMA Architecture

python
"""
LLaMA (Large Language Model Meta AI):
Decoder-only models with openly released weights (LLaMA 1 under a non-commercial research license), optimized for efficiency

Key improvements over GPT:
1. Pre-normalization (like GPT-2, but with RMSNorm)
2. SwiGLU activation (instead of GELU)
3. Rotary Position Embeddings (RoPE, instead of learned)
4. Removed biases in linear layers

Models:
- LLaMA-7B: 7 billion parameters
- LLaMA-13B: 13 billion parameters
- LLaMA-33B: 33 billion parameters
- LLaMA-65B: 65 billion parameters

LLaMA 2 (2023):
- Released in 7B, 13B, and 70B sizes
- Trained on 2 trillion tokens (vs 1.0T-1.4T for LLaMA 1)
- Longer context (4096 tokens, vs 2048)
- Commercial use allowed under Meta's Llama 2 Community License
"""

class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization (used in LLaMA)

    Simpler and faster than LayerNorm:
    - No mean subtraction
    - No bias term
    - Only normalizes by RMS
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        # Normalize
        x = x / rms
        # Scale
        return x * self.weight

class SwiGLU(nn.Module):
    """
    SwiGLU activation (used in LLaMA)

    Combines Swish activation with gating:
    SwiGLU(x, W, V) = Swish(xW) ⊗ xV

    Where Swish(x) = x * sigmoid(x)
    """

    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or int(dim * 8/3)  # LLaMA uses 8/3 ratio

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: Swish(xW1) ⊗ xW3 then project back with W2
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) used in LLaMA

    Instead of adding position embeddings, RoPE rotates
    the query and key vectors based on their positions
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos and sin
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x):
        """Rotate half the hidden dims"""
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, seq_len):
        """Apply rotary embeddings to queries and keys"""
        cos = self.cos_cached[:seq_len, ...]
        sin = self.sin_cached[:seq_len, ...]

        # Apply rotation
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)

        return q_embed, k_embed

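class LLaMABlock(nn.Module):
    """
    Simplified LLaMA-style block: pre-norm (RMSNorm) attention and SwiGLU FFN.

    For brevity this reuses CausalSelfAttention, so it omits RoPE and keeps
    biased projections; a faithful LLaMA block applies RotaryPositionalEmbedding
    to Q and K and uses bias-free linear layers.
    """

    def __init__(self, dim, num_heads, multiple_of=256):  # multiple_of is unused here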
        super().__init__()

        # Attention
        self.attention_norm = RMSNorm(dim)
        self.attention = CausalSelfAttention(dim, num_heads)

        # Feed-forward
        self.ffn_norm = RMSNorm(dim)
        self.feed_forward = SwiGLU(dim)

    def forward(self, x):
        # Attention block with residual
        h = x + self.attention(self.attention_norm(x))[0]

        # FFN block with residual
        out = h + self.feed_forward(self.ffn_norm(h))

        return out

# Compare normalizations
def compare_normalizations():
    """Compare LayerNorm vs RMSNorm"""

    dim = 768
    x = torch.randn(2, 10, dim) + 3.0  # nonzero mean, so the difference is visible

    # LayerNorm
    ln = nn.LayerNorm(dim)
    ln_out = ln(x)

    # RMSNorm
    rms = RMSNorm(dim)
    rms_out = rms(x)

    print("Normalization Comparison:")
    print(f"Input mean: {x.mean():.6f}, std: {x.std():.6f}")
    print(f"LayerNorm mean: {ln_out.mean():.6f}, std: {ln_out.std():.6f}")
    print(f"RMSNorm mean: {rms_out.mean():.6f}, std: {rms_out.std():.6f}")
    print("\nRMSNorm doesn't center (mean ≠ 0) but normalizes scale")

compare_normalizations()
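RoPE's key property is that the score between a rotated query at position m and a rotated key at position n depends only on the offset m - n. The sketch below reimplements the same `rotate_half` formulation used in `RotaryPositionalEmbedding`, but for single vectors (the `rope` helper is illustrative), and checks that shifting both positions by the same amount leaves the score unchanged.

```python
import torch

def rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate vector x (even length) as if it sat at position `pos`."""
    dim = x.size(-1)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = pos * inv_freq
    emb = torch.cat((angles, angles))             # dims i and i + dim/2 share an angle
    x1, x2 = x[: dim // 2], x[dim // 2:]
    rotated_half = torch.cat((-x2, x1))
    return x * emb.cos() + rotated_half * emb.sin()  # 2D rotation of each (i, i + dim/2) pair

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

score_a = rope(q, 10) @ rope(k, 7)     # offset m - n = 3
score_b = rope(q, 110) @ rope(k, 107)  # same offset, both shifted by 100
score_c = rope(q, 10) @ rope(k, 4)     # different offset (6)

print(score_a.item(), score_b.item(), score_c.item())
```

Because each pair of dimensions is rotated by an angle proportional to position, the dot product only sees the difference of the two angles. Relative position is therefore encoded inside the attention score itself, rather than by adding an absolute position vector to the input.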

LLaMA's Impact: By releasing model weights openly, LLaMA democratized access to large language models and sparked a wave of innovation in fine-tuning, quantization, and efficient deployment.
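The `SwiGLU` class above approximates LLaMA's feed-forward width as `int(dim * 8/3)`, and `LLaMABlock` accepts a `multiple_of` argument without using it. In Meta's reference LLaMA code, the width starts at 4 * dim and is scaled by 2/3, so the three SwiGLU matrices (3 * d * 8d/3 = 8d^2) cost about as much as the two matrices of a standard 4x MLP (2 * d * 4d = 8d^2). It is then rounded up to a multiple of `multiple_of`. A sketch of that rule (`llama_ffn_hidden_dim` is a hypothetical name):

```python
def llama_ffn_hidden_dim(dim: int, multiple_of: int = 256) -> int:
    """SwiGLU hidden width: 2/3 of a 4x MLP, rounded up to a multiple of `multiple_of`."""
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)  # keep parameter count close to a 4x MLP
    # Round up to the next multiple of `multiple_of` (hardware-friendly sizes)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

for dim in (4096, 5120):
    print(dim, "->", llama_ffn_hidden_dim(dim))
```

For dim 4096 and 5120 this gives 11008 and 13824, which should match the published feed-forward widths of LLaMA-7B and LLaMA-13B.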

Why Decoder-Only Dominates

python
"""
Why Decoder-Only Models Won:

1. SIMPLICITY SCALES:
   - Single architecture for all tasks
   - Easier to optimize at scale
   - Fewer hyperparameters to tune

2. NEXT-TOKEN PREDICTION IS ENOUGH:
   - Simple objective, powerful results
   - Learns both understanding and generation
   - No need for complex multi-task objectives

3. IN-CONTEXT LEARNING:
   - Can adapt to new tasks from examples in context
   - No fine-tuning needed (for large models)
   - More flexible than task-specific models

4. GENERATION IS KING:
   - Most useful applications involve generation
   - Chat, code completion, writing assistance
   - Encoder-only models can't do this naturally

5. EMPIRICAL SUCCESS:
   - GPT-3 showed massive scale works
   - Scaling laws are clean and predictable
   - Industry standardized on decoder-only
"""

# Scaling laws for decoder-only models
def plot_scaling_laws():
    """Visualize scaling laws for decoder-only models"""
    import matplotlib.pyplot as plt
    import numpy as np

    # Parameters (in billions)
    params = np.array([0.125, 0.35, 1.3, 6.7, 13, 175])

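    # Illustrative loss curve from the Kaplan et al. (2020) power law:
    # L(N) = (N_c / N) ** alpha_N, with N_c ≈ 8.8e13 and alpha_N ≈ 0.076
    # (N counts non-embedding parameters; assumes data and compute are not limiting)
    N_c = 8.8e13
    alpha = 0.076
    loss = (N_c / (params * 1e9)) ** alpha

    # Purely illustrative saturating curve, not measured benchmark data
    task_perf = 100 * (1 - np.exp(-params / 50))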

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Training loss vs parameters
    ax1.plot(params, loss, 'o-', linewidth=2, markersize=8)
    ax1.set_xscale('log')
    ax1.set_xlabel('Parameters (billions)')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Scaling Law: Loss Decreases Predictably')
    ax1.grid(alpha=0.3)

    # Plot 2: Task performance vs parameters
    ax2.plot(params, task_perf, 's-', linewidth=2, markersize=8, color='green')
    ax2.set_xscale('log')
    ax2.set_xlabel('Parameters (billions)')
    ax2.set_ylabel('Task Performance')
    ax2.set_title('Downstream Task Performance vs Model Size')
    ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('decoder_scaling_laws.png', dpi=150)

    print("Scaling laws visualization saved")
    print("\nKey insight: Performance improves predictably with scale")

plot_scaling_laws()

# Compare model families
decoder_models_timeline = """
Decoder-Only Models Timeline:

2018: GPT (117M) - Proves pre-training + fine-tuning works
2019: GPT-2 (1.5B) - Zero-shot task transfer emerges
2020: GPT-3 (175B) - Few-shot learning, massive scale
2021-2022: GPT-J (6B), GPT-NeoX (20B) - Open-source alternatives
2023: LLaMA (7B-65B) - Efficient, openly released weights
2023: LLaMA 2 (7B-70B) - Commercially usable
2023: Mistral (7B) - Grouped-query and sliding-window attention
2023: Mixtral (8x7B) - Sparse mixture of experts
2024: Many more...

Common thread: All decoder-only!
"""

print(decoder_models_timeline)

Practice Exercise

python
# Exercise: Implement a simple decoder-only model
class SimpleDecoderLM(nn.Module):
    """
    Minimal decoder-only language model

    Exercise: each TODO marks a step; try writing it yourself before reading the solution below it
    """

    def __init__(self, vocab_size, d_model=512, num_layers=6, num_heads=8):
        super().__init__()

        # TODO: Add token embeddings
        self.token_embed = nn.Embedding(vocab_size, d_model)

        # TODO: Add position embeddings
        self.pos_embed = nn.Embedding(1024, d_model)

        # TODO: Add transformer blocks
        self.blocks = nn.ModuleList([
            LLaMABlock(d_model, num_heads) for _ in range(num_layers)
        ])

        # TODO: Add output head
        self.output_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape

        # TODO: Compute embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.pos_embed(positions)

        # TODO: Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # TODO: Compute logits
        x = self.output_norm(x)
        logits = self.lm_head(x)

        return logits

# Test the model
model = SimpleDecoderLM(vocab_size=50000, d_model=512, num_layers=6)
dummy_input = torch.randint(0, 50000, (2, 10))  # batch_size=2, seq_len=10
output = model(dummy_input)

print(f"Model output shape: {output.shape}")  # Should be [2, 10, 50000]
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Exercise questions
exercise_questions = """
Practice Exercises:

1. Why is causal masking necessary during training, and why is it
   trivially satisfied when generating one new token at a time?
   What would happen without it?

2. Implement a function to compute the extra memory used by the
   KV cache when generating 100 tokens from a 50-token prompt.

3. Compare: Estimate the attention FLOPs needed to generate 50 tokens
   with a 12-layer model, with and without KV caching.

4. Design: How would you modify the decoder architecture to handle
   2x longer context efficiently?

5. Explain: Why is RoPE (rotary embeddings) easier to extend to
   longer sequences than learned absolute position embeddings?
"""

print(exercise_questions)
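The parameter total printed for `SimpleDecoderLM` can be checked by hand. The sketch below counts parameters analytically for the same configuration, assuming the definitions from this page: `CausalSelfAttention` has four biased `d_model x d_model` projections, `SwiGLU` has three bias-free matrices of hidden width `int(d_model * 8/3)`, and each `RMSNorm` has one weight vector. `count_params` is a hypothetical helper.

```python
def count_params(vocab_size, d_model, num_layers, max_positions=1024):
    hidden = int(d_model * 8 / 3)                   # SwiGLU hidden width
    attention = 4 * (d_model * d_model + d_model)   # Q, K, V, out projections with bias
    ffn = 3 * d_model * hidden                      # w1, w2, w3 (no bias)
    norms = 2 * d_model                             # two RMSNorm weight vectors per block
    per_block = attention + ffn + norms

    embeddings = vocab_size * d_model + max_positions * d_model  # token + position tables
    head = d_model + d_model * vocab_size           # final RMSNorm + lm_head (no bias)
    return num_layers * per_block + embeddings + head

total = count_params(vocab_size=50000, d_model=512, num_layers=6)
print(f"{total:,}")  # 70,614,528
```

More than 70% of that total sits in the two 50,000 x 512 matrices (token embedding and `lm_head`). Tying them, as GPT-2 does, would remove 25.6M parameters from this small model.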
