Decoder-Only Models (GPT, LLaMA)
Decoder-only models have become the dominant architecture for large language models. From GPT to LLaMA to Claude, most modern LLMs use this simple yet powerful design. Let's understand why decoder-only architectures work so well and how they generate text.
Why Decoder-Only Works
Decoder-Only Architecture: A transformer architecture using only the decoder stack with causal attention, processing sequences left-to-right. Despite being designed for generation, it achieves strong performance on both generation and understanding tasks through scale.
Architectural Simplicity
"""
Transformer Architecture Comparison:
1. Original Transformer (Encoder-Decoder):
- Encoder: Bidirectional attention
- Decoder: Causal attention + cross-attention
- Use case: Translation, seq2seq tasks
- Complexity: Two separate stacks
2. Encoder-Only (BERT):
- Only encoder: Bidirectional attention
- Use case: Understanding, classification
- Cannot generate text naturally
3. Decoder-Only (GPT):
- Only decoder: Causal attention
- Use case: Generation, understanding, everything!
- Simplicity: Single stack
- Success: Scales incredibly well
"""
import torch
import torch.nn as nn

class DecoderOnlyAdvantages:
    """Why decoder-only models dominate"""

    def __init__(self):
        self.advantages = {
            'simplicity': {
                'description': 'Single stack of transformer blocks',
                'benefit': 'Easier to implement, debug, and scale',
                'example': 'GPT uses the same architecture from 117M to 175B+ params'
            },
            'unified_objective': {
                'description': 'Single pre-training task (next-token prediction)',
                'benefit': 'No need to balance multiple objectives',
                'example': 'vs BERT with MLM + NSP'
            },
            'generation_native': {
                'description': 'Designed for autoregressive generation',
                'benefit': 'Natural text generation without tricks',
                'example': 'GPT generates coherent long-form text'
            },
            'task_flexibility': {
                'description': 'Handles both understanding and generation',
                'benefit': 'Single model for all tasks',
                'example': 'GPT-3 does QA, translation, summarization, etc.'
            },
            'scaling_properties': {
                'description': 'Clean scaling laws observed',
                'benefit': 'Predictable performance improvements',
                'example': 'Loss decreases predictably with model size'
            }
        }

    def compare_architectures(self):
        """Compare different transformer architectures"""
        print("Transformer Architecture Comparison:\n")
        comparisons = {
            'Encoder-Decoder (T5)': {
                'attention_types': 'Bidirectional + Causal + Cross',
                'components': 'Encoder stack + Decoder stack',
                'parameters': '2x (separate encoder/decoder)',
                'best_for': 'Seq2seq tasks (translation)'
            },
            'Encoder-Only (BERT)': {
                'attention_types': 'Bidirectional only',
                'components': 'Encoder stack only',
                'parameters': '1x',
                'best_for': 'Understanding tasks (classification)'
            },
            'Decoder-Only (GPT)': {
                'attention_types': 'Causal only',
                'components': 'Decoder stack only',
                'parameters': '1x',
                'best_for': 'Generation + understanding (everything!)'
            }
        }
        for arch, specs in comparisons.items():
            print(f"{arch}:")
            for key, value in specs.items():
                print(f"  {key}: {value}")
            print()

# Demonstrate advantages
advantages = DecoderOnlyAdvantages()
advantages.compare_architectures()
Emergent Understanding: Despite being trained only on next-token prediction (a generation task), decoder-only models develop strong understanding capabilities. This suggests that generation and understanding are two sides of the same coin.
Causal Masking
Causal Attention: An attention mechanism where each token can only attend to itself and previous tokens (not future ones), implemented using a triangular mask. This enables autoregressive generation where the model predicts one token at a time.
Causal masking is the key mechanism that enables autoregressive generation.
Causal Attention Implementation
"""
Causal Masking: Each position can only attend to itself and previous positions
Without mask (bidirectional):
Position 0 sees: [0, 1, 2, 3]
Position 1 sees: [0, 1, 2, 3]
Position 2 sees: [0, 1, 2, 3]
Position 3 sees: [0, 1, 2, 3]
With causal mask (unidirectional):
Position 0 sees: [0]
Position 1 sees: [0, 1]
Position 2 sees: [0, 1, 2]
Position 3 sees: [0, 1, 2, 3]
"""
def create_causal_mask(seq_len):
    """
    Create a causal attention mask.

    Returns:
        mask: Lower triangular matrix [seq_len, seq_len]
              0 = masked (cannot attend), 1 = visible (can attend)
    """
    # Lower triangular matrix: position i can see positions 0..i
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

def visualize_causal_mask():
    """Visualize the causal attention mask"""
    import matplotlib.pyplot as plt

    seq_len = 8
    mask = create_causal_mask(seq_len)

    plt.figure(figsize=(8, 8))
    plt.imshow(mask, cmap='Blues', interpolation='nearest')
    plt.xlabel('Key Position (attending to)')
    plt.ylabel('Query Position (attending from)')
    plt.title('Causal Attention Mask\n(1=can attend, 0=masked)')

    # Add grid
    for i in range(seq_len + 1):
        plt.axhline(i - 0.5, color='gray', linewidth=0.5)
        plt.axvline(i - 0.5, color='gray', linewidth=0.5)
    plt.xticks(range(seq_len))
    plt.yticks(range(seq_len))
    plt.colorbar()
    plt.tight_layout()
    plt.savefig('causal_mask.png', dpi=150)
    print("Causal mask visualization saved")

    print("\nMask matrix:")
    print(mask.int())
    print("\nInterpretation:")
    print("Position 0 can only see position 0")
    print("Position 1 can see positions 0-1")
    print("Position 2 can see positions 0-2")
    print("Position 7 can see positions 0-7")

visualize_causal_mask()
class CausalSelfAttention(nn.Module):
    """Causal self-attention mechanism"""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q, K, V projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Causal mask (registered as a buffer, not a parameter)
        self.register_buffer(
            'causal_mask',
            torch.tril(torch.ones(1024, 1024)).view(1, 1, 1024, 1024)
        )

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # Project Q, K, V
        Q = self.q_proj(x)  # [batch, seq_len, d_model]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Reshape for multi-head attention: [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Attention scores: [batch, num_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Apply causal mask
        mask = self.causal_mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax and dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values: [batch, num_heads, seq_len, d_k]
        out = torch.matmul(attn_weights, V)

        # Concatenate heads
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, d_model)

        # Output projection
        out = self.out_proj(out)
        return out, attn_weights

# Demonstrate causal attention
def demonstrate_causal_attention():
    """Show how causal masking affects attention"""
    attn = CausalSelfAttention(d_model=512, num_heads=8)

    # Sample input
    batch_size, seq_len = 1, 6
    x = torch.randn(batch_size, seq_len, 512)

    # Forward pass
    output, attn_weights = attn(x)

    # Show the attention pattern for one head
    print("Causal Attention Pattern (Head 0):")
    print(attn_weights[0, 0].detach().numpy())
    print("\nNotice: each row (query) only attends to itself and previous positions")
    print("Upper triangle is all zeros (masked)")

demonstrate_causal_attention()
Why Masking Works: Causal masking prevents information leakage during training. Without it, the model could "cheat" by looking at future tokens when predicting the current token, which wouldn't be possible during generation.
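The training side of this is worth making concrete: because the mask blocks future positions, a single forward pass produces a valid next-token prediction at every position, and the loss simply compares each position's logits with the token that actually comes next. A minimal sketch (random logits stand in for a model's output):

```python
import torch
import torch.nn.functional as F

vocab_size, seq_len = 100, 8
logits = torch.randn(1, seq_len, vocab_size)        # model output [batch, seq, vocab]
input_ids = torch.randint(0, vocab_size, (1, seq_len))

# Logits at position t predict token t+1, so targets are the inputs shifted left
shift_logits = logits[:, :-1, :]                    # predictions for positions 1..T-1
shift_labels = input_ids[:, 1:]                     # the tokens actually at those positions
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
)
print(loss.item())
```

This "teacher forcing" setup is why causal language models train so efficiently: every token in the batch contributes a training signal in parallel.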
Autoregressive Generation
How decoder-only models generate text one token at a time.
Generation Process
"""
Autoregressive Generation Process:
1. Start with prompt/context
2. Model predicts probability distribution over next token
3. Sample or select next token
4. Append to sequence
5. Repeat until done
Example:
Prompt: "The cat sat"
Step 1: "The cat sat" → predict "on" (or other options)
Step 2: "The cat sat on" → predict "the"
Step 3: "The cat sat on the" → predict "mat"
Result: "The cat sat on the mat"
"""
class AutoregressiveGenerator:
    """Autoregressive text generation"""

    def __init__(self, model, tokenizer, max_length=50):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    @torch.no_grad()
    def generate_greedy(self, prompt, max_new_tokens=20):
        """
        Greedy decoding: always pick the most probable token.

        Args:
            prompt: Input text
            max_new_tokens: Number of tokens to generate

        Returns:
            Generated text
        """
        # Tokenize prompt
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        for _ in range(max_new_tokens):
            # Forward pass
            outputs = self.model(input_ids)
            logits = outputs.logits  # [batch, seq_len, vocab_size]

            # Logits for the last position
            next_token_logits = logits[0, -1, :]

            # Greedy: select the most probable token
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Append to the sequence
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop at EOS token
            if next_token.item() == self.tokenizer.eos_token_id:
                break

        # Decode
        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

    @torch.no_grad()
    def generate_sampling(self, prompt, max_new_tokens=20, temperature=1.0, top_p=0.9):
        """
        Sampling decoding: sample from the probability distribution.

        Args:
            prompt: Input text
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling threshold

        Returns:
            Generated text
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')

        for _ in range(max_new_tokens):
            outputs = self.model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]

            # Apply temperature
            next_token_logits = next_token_logits / temperature

            # Top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold,
            # shifting right so the first token past the threshold is kept
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits[indices_to_remove] = float('-inf')

            # Sample from the filtered distribution
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == self.tokenizer.eos_token_id:
                break

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Demonstrate generation strategies
def demonstrate_generation():
    """Show different generation strategies"""
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    prompt = "The future of artificial intelligence is"
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    print("Autoregressive Generation Strategies:\n")
    print(f"Prompt: '{prompt}'\n")

    # Greedy decoding
    outputs = model.generate(input_ids, max_new_tokens=20, do_sample=False)
    print(f"Greedy: {tokenizer.decode(outputs[0], skip_special_tokens=True)}\n")

    # Sampling (temperature=0.7)
    outputs = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.7)
    print(f"Sampling (T=0.7): {tokenizer.decode(outputs[0], skip_special_tokens=True)}\n")

    # Top-p (nucleus) sampling
    outputs = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_p=0.9, temperature=0.8)
    print(f"Top-p (p=0.9): {tokenizer.decode(outputs[0], skip_special_tokens=True)}\n")

    # Beam search
    outputs = model.generate(input_ids, max_new_tokens=20, num_beams=4, early_stopping=True)
    print(f"Beam Search: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

demonstrate_generation()
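Temperature's effect is easiest to see on a toy distribution in isolation: dividing the logits by T < 1 sharpens the distribution toward the top token, while T > 1 flattens it toward uniform. A tiny standalone sketch:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for T in [0.5, 1.0, 2.0]:
    # Lower T exaggerates logit differences; higher T shrinks them
    probs = torch.softmax(logits / T, dim=-1)
    print(f"T={T}: {[round(p, 3) for p in probs.tolist()]}")
```

At T=0.5 the top token dominates; at T=2.0 the three options become nearly equiprobable, which is why high temperatures read as "more random".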
KV Cache Optimization
KV Cache: A memory optimization that stores previously computed key and value tensors during generation, avoiding redundant computation of attention for past tokens and dramatically speeding up inference at the cost of increased memory usage.
"""
KV Cache: Optimization for autoregressive generation
Problem:
- At each step, recompute attention for entire sequence
- Wasteful: previous positions' K and V don't change
Solution:
- Cache K and V from previous steps
- Only compute K, V for new token
- Dramatically faster generation
"""
class CausalAttentionWithKVCache(nn.Module):
    """Efficient causal attention with KV caching"""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None):
        """
        Forward with optional KV cache.

        Args:
            x: Input [batch, seq_len, d_model]
            past_kv: Cached (K, V) from previous steps

        Returns:
            output: Attention output
            new_kv: Updated KV cache
        """
        batch_size, seq_len, d_model = x.shape

        # Compute Q, K, V for the current input only
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Reshape to [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Prepend cached K, V if available
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)  # concatenate along the seq dimension
            V = torch.cat([past_V, V], dim=2)

        # Attention over the full (cached + new) sequence
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Causal mask: query i sits at absolute position past_len + i and may
        # attend to keys 0..past_len + i. With no cache, past_len = 0 and this
        # reduces to the usual lower-triangular mask.
        total_len = K.size(2)
        past_len = total_len - seq_len
        mask = torch.tril(torch.ones(seq_len, total_len), diagonal=past_len)
        mask = mask.view(1, 1, seq_len, total_len)
        scores = scores.masked_fill(mask.to(scores.device) == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn_weights, V)

        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.out_proj(out)

        # Return output and the updated cache
        return out, (K, V)

# Compare with and without KV cache
def benchmark_kv_cache():
    """Show the speedup from KV caching"""
    import time

    model_no_cache = CausalSelfAttention(d_model=768, num_heads=12)
    model_with_cache = CausalAttentionWithKVCache(d_model=768, num_heads=12)

    # Simulate generation
    batch_size = 1
    prompt_len = 50
    generate_len = 50
    x_prompt = torch.randn(batch_size, prompt_len, 768)

    # Without cache: recompute attention over the whole sequence each step
    start = time.time()
    x = x_prompt
    for _ in range(generate_len):
        new_token = torch.randn(batch_size, 1, 768)
        x = torch.cat([x, new_token], dim=1)
        out, _ = model_no_cache(x)
    time_no_cache = time.time() - start

    # With cache: process the prompt once, then one token per step
    start = time.time()
    out, kv_cache = model_with_cache(x_prompt, past_kv=None)
    for _ in range(generate_len):
        new_token = torch.randn(batch_size, 1, 768)
        out, kv_cache = model_with_cache(new_token, past_kv=kv_cache)
    time_with_cache = time.time() - start

    print("KV Cache Benchmark:")
    print(f"Without cache: {time_no_cache:.4f}s")
    print(f"With cache:    {time_with_cache:.4f}s")
    print(f"Speedup: {time_no_cache / time_with_cache:.2f}x")

benchmark_kv_cache()
Memory Trade-off: KV caching speeds up generation but increases memory usage. For very long sequences or large batch sizes, memory can become a bottleneck.
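The memory cost is easy to estimate: per layer, the cache holds one K and one V tensor of shape [batch, heads, seq, head_dim]. A back-of-the-envelope sketch (`kv_cache_bytes` is a hypothetical helper; the config below is LLaMA-7B-like, with 32 layers, 32 heads, and head dimension 128, at a 4096-token context in fp16):

```python
def kv_cache_bytes(num_layers, num_heads, d_head, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # 2 tensors (K and V) per layer, each [batch, heads, seq, d_head]
    return 2 * num_layers * batch_size * num_heads * seq_len * d_head * bytes_per_elem

# LLaMA-7B-like config at 4096 tokens, fp16: about 2.1 GB per sequence
print(kv_cache_bytes(32, 32, 128, 4096) / 1e9, "GB")
```

Note the cache grows linearly in both sequence length and batch size, which is why serving systems often hit memory limits long before compute limits.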
Modern Decoder-Only Models
LLaMA Architecture
"""
LLaMA (Large Language Model Meta AI):
Open-source decoder-only models optimized for efficiency
Key improvements over GPT:
1. Pre-normalization (like GPT-2, but with RMSNorm)
2. SwiGLU activation (instead of GELU)
3. Rotary Position Embeddings (RoPE, instead of learned)
4. Removed biases in linear layers
Models:
- LLaMA-7B: 7 billion parameters
- LLaMA-13B: 13 billion parameters
- LLaMA-33B: 33 billion parameters
- LLaMA-65B: 65 billion parameters
LLaMA 2 (2023):
- Same sizes plus 70B
- Trained on 2 trillion tokens (vs 1.4T for LLaMA 1)
- Longer context (4096 tokens)
- Commercial use allowed
"""
class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization (used in LLaMA).

    Simpler and faster than LayerNorm:
    - No mean subtraction
    - No bias term
    - Only normalizes by the RMS
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root mean square, then apply a learned scale
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.weight

class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward layer (used in LLaMA).

    Combines the Swish (SiLU) activation with gating:
        SwiGLU(x) = (Swish(x W1) ⊗ x W3) W2
    where Swish(x) = x * sigmoid(x)
    """

    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or int(dim * 8 / 3)  # LLaMA uses an 8/3 ratio
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # Swish(x W1) gated by x W3, then projected back with W2
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE), used in LLaMA.

    Instead of adding position embeddings, RoPE rotates the
    query and key vectors based on their positions.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos and sin tables
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x):
        """Rotate half the hidden dims"""
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, seq_len):
        """Apply rotary embeddings to queries and keys"""
        cos = self.cos_cached[:seq_len, ...]
        sin = self.sin_cached[:seq_len, ...]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

class LLaMABlock(nn.Module):
    """LLaMA transformer block (pre-normalization with RMSNorm)"""

    def __init__(self, dim, num_heads):
        super().__init__()
        # Attention
        self.attention_norm = RMSNorm(dim)
        self.attention = CausalSelfAttention(dim, num_heads)
        # Feed-forward
        self.ffn_norm = RMSNorm(dim)
        self.feed_forward = SwiGLU(dim)

    def forward(self, x):
        # Attention block with residual
        h = x + self.attention(self.attention_norm(x))[0]
        # FFN block with residual
        return h + self.feed_forward(self.ffn_norm(h))

# Compare normalizations
def compare_normalizations():
    """Compare LayerNorm vs RMSNorm"""
    dim = 768
    x = torch.randn(2, 10, dim)

    ln_out = nn.LayerNorm(dim)(x)
    rms_out = RMSNorm(dim)(x)

    print("Normalization Comparison:")
    print(f"Input mean: {x.mean():.6f}, std: {x.std():.6f}")
    print(f"LayerNorm mean: {ln_out.mean():.6f}, std: {ln_out.std():.6f}")
    print(f"RMSNorm mean: {rms_out.mean():.6f}, std: {rms_out.std():.6f}")
    print("\nRMSNorm doesn't center (mean ≠ 0) but normalizes the scale")

compare_normalizations()
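RoPE's defining property, that the dot product between a rotated query and key depends only on their relative offset, can be checked numerically. The sketch below uses the interleaved RoPE formulation (a common variant of the rotate-half form; `rope` is an illustrative helper, not LLaMA's exact code):

```python
import torch

def rope(x, pos, base=10000):
    # Rotate consecutive pairs of dims by the angle pos * theta_i
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = pos * inv_freq
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

q, k = torch.randn(64), torch.randn(64)
# Attention score depends only on the relative offset (here 3),
# not on the absolute positions:
a = torch.dot(rope(q, 5), rope(k, 2))        # positions 5 and 2
b = torch.dot(rope(q, 105), rope(k, 102))    # same offset, shifted by 100
print(torch.allclose(a, b, atol=1e-3))  # → True
```

This relative-position property is one intuition for why RoPE tends to extrapolate to positions beyond those seen in training better than learned absolute embeddings.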
LLaMA's Impact: By releasing model weights openly, LLaMA democratized access to large language models and sparked a wave of innovation in fine-tuning, quantization, and efficient deployment.
Why Decoder-Only Dominates
"""
Why Decoder-Only Models Won:
1. SIMPLICITY SCALES:
- Single architecture for all tasks
- Easier to optimize at scale
- Fewer hyperparameters to tune
2. NEXT-TOKEN PREDICTION IS ENOUGH:
- Simple objective, powerful results
- Learns both understanding and generation
- No need for complex multi-task objectives
3. IN-CONTEXT LEARNING:
- Can adapt to new tasks from examples in context
- No fine-tuning needed (for large models)
- More flexible than task-specific models
4. GENERATION IS KING:
- Most useful applications involve generation
- Chat, code completion, writing assistance
- Encoder-only models can't do this naturally
5. EMPIRICAL SUCCESS:
- GPT-3 showed massive scale works
- Scaling laws are clean and predictable
- Industry standardized on decoder-only
"""
# Scaling laws for decoder-only models
def plot_scaling_laws():
    """Visualize scaling laws for decoder-only models"""
    import matplotlib.pyplot as plt
    import numpy as np

    # Parameters (in billions)
    params = np.array([0.125, 0.35, 1.3, 6.7, 13, 175])

    # Approximate loss, following the power-law form from scaling-law papers:
    # Loss ≈ (N / N0)^(-alpha), with alpha ≈ 0.076
    N0 = 8.8e9
    alpha = 0.076
    loss = (params * 1e9 / N0) ** (-alpha) * 2.5 + 1.5

    # Illustrative downstream task performance
    task_perf = 100 * (1 - np.exp(-params / 50))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Training loss vs parameters
    ax1.plot(params, loss, 'o-', linewidth=2, markersize=8)
    ax1.set_xscale('log')
    ax1.set_xlabel('Parameters (billions)')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Scaling Law: Loss Decreases Predictably')
    ax1.grid(alpha=0.3)

    # Plot 2: Task performance vs parameters
    ax2.plot(params, task_perf, 's-', linewidth=2, markersize=8, color='green')
    ax2.set_xscale('log')
    ax2.set_xlabel('Parameters (billions)')
    ax2.set_ylabel('Task Performance')
    ax2.set_title('Downstream Task Performance vs Model Size')
    ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('decoder_scaling_laws.png', dpi=150)
    print("Scaling laws visualization saved")
    print("\nKey insight: performance improves predictably with scale")

plot_scaling_laws()

# Compare model families
decoder_models_timeline = """
Decoder-Only Models Timeline:

2018: GPT (117M)       - Proves pre-training + fine-tuning works
2019: GPT-2 (1.5B)     - Zero-shot task transfer emerges
2020: GPT-3 (175B)     - Few-shot learning, massive scale
2021: GPT-J (6B)       - Open-source alternative
2022: GPT-NeoX (20B)   - Larger open-source model
2023: LLaMA (7B-65B)   - Efficient, open weights
2023: LLaMA 2 (7B-70B) - Commercially usable
2023: Mistral 7B       - Sliding-window and grouped-query attention
2024: Many more...

Common thread: all decoder-only!
"""
print(decoder_models_timeline)
Practice Exercise
# Exercise: Implement a simple decoder-only model
class SimpleDecoderLM(nn.Module):
    """
    Minimal decoder-only language model.

    Exercise: study each component, then try re-implementing it from scratch.
    """

    def __init__(self, vocab_size, d_model=512, num_layers=6, num_heads=8):
        super().__init__()
        # Token embeddings
        self.token_embed = nn.Embedding(vocab_size, d_model)
        # Learned position embeddings
        self.pos_embed = nn.Embedding(1024, d_model)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            LLaMABlock(d_model, num_heads) for _ in range(num_layers)
        ])
        # Output head
        self.output_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape

        # Token + position embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.pos_embed(positions)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Logits over the vocabulary
        x = self.output_norm(x)
        return self.lm_head(x)

# Test the model
model = SimpleDecoderLM(vocab_size=50000, d_model=512, num_layers=6)
dummy_input = torch.randint(0, 50000, (2, 10))  # batch_size=2, seq_len=10
output = model(dummy_input)
print(f"Model output shape: {output.shape}")  # Should be [2, 10, 50000]
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Exercise questions
exercise_questions = """
Practice Exercises:

1. Why is causal masking necessary during training but not during
   generation? What would happen without it?

2. Implement a function to compute the memory savings from KV caching
   when generating 100 tokens from a 50-token prompt.

3. Compare: calculate FLOPs for one forward pass with and without
   KV caching for a 12-layer model generating 50 tokens.

4. Design: how would you modify the decoder architecture to handle
   2x longer context efficiently?

5. Explain: why does RoPE (rotary embeddings) generalize better to
   longer sequences than learned position embeddings?
"""
print(exercise_questions)
Further Reading
- Attention Is All You Need (Original Transformer)
- Language Models are Unsupervised Multitask Learners (GPT-2)
- LLaMA: Open and Efficient Foundation Language Models
- LLaMA 2: Open Foundation and Fine-Tuned Chat Models
- RoFormer: Enhanced Transformer with Rotary Position Embedding
- GLU Variants Improve Transformer