Layer Normalization and Residual Connections
While attention mechanisms get most of the glory, layer normalization and residual connections are the unsung heroes that make training deep transformers possible. Without them, models like GPT-3 with 96 layers simply wouldn't train.
The Problem: Training Deep Networks
Vanishing and Exploding Gradients
In deep networks, gradients can become problematically small or large as they backpropagate through layers.
import torch
import torch.nn as nn
# Simulate gradient flow through 50 layers without residual connections
class DeepNetworkWithoutResiduals(nn.Module):
def __init__(self, d_model, num_layers):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(d_model, d_model) for _ in range(num_layers)
])
def forward(self, x):
for layer in self.layers:
x = torch.tanh(layer(x)) # Non-linearity
return x
# Test gradient flow
model = DeepNetworkWithoutResiduals(d_model=512, num_layers=50)
x = torch.randn(1, 512, requires_grad=True)
output = model(x)
loss = output.sum()
loss.backward()
# Check gradient magnitude at input
print(f"Input gradient magnitude: {x.grad.abs().mean():.10f}")
# Likely very close to 0 (vanishing gradient)
The Gradient Problem:
In a 50-layer network with tanh activations, gradients can shrink exponentially. If each layer multiplies the gradient by 0.5, after 50 layers you get 0.5^50 ≈ 10^-15, effectively zero. This makes the early layers impossible to train.
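A two-line check makes this decay concrete (using the illustrative 0.5-per-layer factor from above):
factor, num_layers = 0.5, 50  # illustrative per-layer gradient scaling factor
print(f"Gradient scale after {num_layers} layers: {factor ** num_layers:.2e}")
# ~8.9e-16, effectively zero for float32 training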
Internal Covariate Shift
As network parameters update during training, the distribution of inputs to each layer constantly changes, making training unstable.
Residual Connections: Highway to Gradient Flow
The Core Idea
Instead of learning the full transformation H(x) directly, each sub-layer learns a residual F(x) = H(x) - x, which is added back to the input:
Output = x + F(x)
This creates a "highway" for gradients to flow directly backward.
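As a minimal sketch of the pattern (the Residual wrapper name here is just for illustration, not a standard PyTorch module), any sub-layer whose output matches its input shape can be wrapped this way:
class Residual(nn.Module):
    """Wrap any sub-layer F so the block computes x + F(x)."""
    def __init__(self, sublayer):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x):
        # The skip path carries x through untouched; only the adjustment is learned
        return x + self.sublayer(x)

# Example: a residual feed-forward block
block = Residual(nn.Sequential(nn.Linear(512, 512), nn.GELU()))
out = block(torch.randn(2, 512))  # same shape as the input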
Mathematical Intuition
During backpropagation:
∂Loss/∂x = ∂Loss/∂Output × ∂Output/∂x
= ∂Loss/∂Output × ∂(x + F(x))/∂x
= ∂Loss/∂Output × (1 + ∂F(x)/∂x)
The "+1" term ensures gradients can always flow backward, even if ∂F(x)/∂x becomes small.
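This is easy to verify with autograd. The sketch below shrinks a layer's weights so that ∂F(x)/∂x is tiny, then compares the input gradient with and without the skip connection:
# Shrink the weights so the residual branch contributes almost nothing to the gradient
layer = nn.Linear(512, 512)
with torch.no_grad():
    layer.weight *= 0.01

x = torch.randn(1, 512, requires_grad=True)
(x + layer(x)).sum().backward()            # with the residual connection
grad_with = x.grad.abs().mean().item()

x = torch.randn(1, 512, requires_grad=True)
layer(x).sum().backward()                  # without the residual connection
grad_without = x.grad.abs().mean().item()

print(f"With residual: {grad_with:.4f}, without: {grad_without:.4f}")
# The "+1" from the skip path keeps the input gradient near 1; the plain layer's is near 0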
Implementation
class TransformerBlockWithResiduals(nn.Module):
"""Transformer block with residual connections."""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Residual connection around attention
attn_output, _ = self.attention(x, x, x)
x = x + self.dropout(attn_output) # ← Residual connection
# Residual connection around feed-forward
ff_output = self.feed_forward(x)
x = x + self.dropout(ff_output) # ← Residual connection
return x
# Test gradient flow with residuals
class DeepNetworkWithResiduals(nn.Module):
def __init__(self, d_model, num_layers):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(d_model, d_model) for _ in range(num_layers)
])
def forward(self, x):
for layer in self.layers:
# Residual connection
x = x + torch.tanh(layer(x))
return x
# Test gradient flow
model = DeepNetworkWithResiduals(d_model=512, num_layers=50)
x = torch.randn(1, 512, requires_grad=True)
output = model(x)
loss = output.sum()
loss.backward()
print(f"Input gradient magnitude with residuals: {x.grad.abs().mean():.6f}")
# Much larger than without residuals!
Residual Learning:
Residual connections were introduced in ResNet (2015) for computer vision. The key insight: it's easier to learn a small adjustment to the input (residual) than to learn the entire transformation from scratch. The network can always choose F(x) = 0 to pass the input through unchanged.
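One way to see this identity behavior directly is to zero the residual branch's final projection (a zero-init trick some training recipes use); under that assumption, the block starts out as an exact identity:
# With F(x) = 0, the residual block is exactly the identity function
ff = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
nn.init.zeros_(ff[-1].weight)  # zero the final projection so ff(x) == 0
nn.init.zeros_(ff[-1].bias)

x = torch.randn(2, 512)
out = x + ff(x)
print(torch.allclose(out, x))  # True: the input passes through unchanged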
Layer Normalization: Stabilizing Activations
Why Not Batch Normalization?
Batch normalization works well for CNNs but has issues with transformers:
- Variable sequence lengths: batches mix sequences of different lengths, so padding tokens would pollute the batch statistics
- Small batch sizes: NLP often uses small batches due to memory constraints, which makes batch statistics noisy
- Cross-example dependence: each token's normalization would depend on the other sequences in the batch, and training-time batch statistics can diverge from the running statistics used at inference (see the comparison sketch below)
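A small comparison makes the cross-example dependence concrete. With LayerNorm, a sequence's output is identical no matter what else is in the batch; with BatchNorm it is not (shapes and values below are arbitrary):
seq_len, d_model = 4, 8
a = torch.randn(1, seq_len, d_model)
b = torch.randn(1, seq_len, d_model) * 5 + 3  # a second, very different sequence

ln = nn.LayerNorm(d_model)
# LayerNorm: normalizing `a` alone vs. alongside `b` gives the same result
print(torch.allclose(ln(a), ln(torch.cat([a, b]))[:1]))  # True

bn = nn.BatchNorm1d(d_model)
# BatchNorm1d expects (batch, channels, length), so move d_model to the channel dim
out_a = bn(a.transpose(1, 2))
out_ab = bn(torch.cat([a, b]).transpose(1, 2))[:1]
print(torch.allclose(out_a, out_ab))  # False: the output depends on the rest of the batch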
Layer Normalization Formula
LayerNorm normalizes across the feature dimension (d_model) for each example independently:
LN(x) = γ × (x - μ) / √(σ² + ε) + β
where:
μ = mean(x) across features
σ² = variance(x) across features
γ, β = learnable scale and shift parameters
ε = small constant for numerical stability (e.g., 10^-5)
Implementation from Scratch
class LayerNorm(nn.Module):
"""Layer Normalization implementation from scratch."""
def __init__(self, d_model, eps=1e-5):
"""
Args:
d_model: Model dimension
eps: Small constant for numerical stability
"""
super().__init__()
self.eps = eps
# Learnable parameters
self.gamma = nn.Parameter(torch.ones(d_model)) # Scale
self.beta = nn.Parameter(torch.zeros(d_model)) # Shift
def forward(self, x):
"""
Args:
x: Input tensor (batch_size, seq_len, d_model)
Returns:
Normalized tensor (batch_size, seq_len, d_model)
"""
# Compute mean and variance across the last dimension (features)
mean = x.mean(dim=-1, keepdim=True) # (batch, seq_len, 1)
var = x.var(dim=-1, keepdim=True, unbiased=False) # (batch, seq_len, 1)
# Normalize
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# Scale and shift
output = self.gamma * x_norm + self.beta
return output
# Test LayerNorm
batch_size, seq_len, d_model = 2, 10, 512
ln = LayerNorm(d_model)
x = torch.randn(batch_size, seq_len, d_model) * 10 + 5 # Arbitrary mean and std
output = ln(x)
# Verify normalization
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {output.mean():.4f}, std: {output.std():.4f}")
print(f"Per-example output mean: {output[0].mean(dim=-1)}") # Should be close to 0
print(f"Per-example output std: {output[0].std(dim=-1)}") # Should be close to 1
Why Layer Normalization Works
1. Reduces Internal Covariate Shift: Keeps activations in a stable range throughout training.
2. Smooths Loss Landscape: Makes the optimization landscape easier to navigate.
3. Allows Higher Learning Rates: More stable training enables faster convergence.
import matplotlib.pyplot as plt
import numpy as np
# Visualize effect of LayerNorm on activation distributions
def plot_activation_distribution():
# Generate activations through 10 layers without normalization
x_no_norm = torch.randn(1000, 512)
activations_no_norm = []
for _ in range(10):
x_no_norm = torch.tanh(nn.Linear(512, 512)(x_no_norm))
activations_no_norm.append(x_no_norm.detach().flatten().numpy())
# Generate activations through 10 layers with LayerNorm
x_with_norm = torch.randn(1000, 512)
ln = LayerNorm(512)
activations_with_norm = []
for _ in range(10):
x_with_norm = torch.tanh(nn.Linear(512, 512)(x_with_norm))
x_with_norm = ln(x_with_norm)
activations_with_norm.append(x_with_norm.detach().flatten().numpy())
# Plot distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Without normalization
for i, acts in enumerate(activations_no_norm):
ax1.hist(acts, bins=50, alpha=0.3, label=f'Layer {i+1}')
ax1.set_title('Activation Distribution WITHOUT LayerNorm')
ax1.set_xlabel('Activation Value')
ax1.set_ylabel('Frequency')
ax1.legend()
# With normalization
for i, acts in enumerate(activations_with_norm):
ax2.hist(acts, bins=50, alpha=0.3, label=f'Layer {i+1}')
ax2.set_title('Activation Distribution WITH LayerNorm')
ax2.set_xlabel('Activation Value')
ax2.set_ylabel('Frequency')
ax2.legend()
plt.tight_layout()
plt.show()
plot_activation_distribution()
Observation:
Without LayerNorm, activation distributions shift and shrink through layers (collapsing to near zero). With LayerNorm, distributions remain stable and centered, enabling effective gradient flow.
Pre-Norm vs Post-Norm
Post-Norm (Original Transformer)
class PostNormTransformerBlock(nn.Module):
"""Original 'Attention Is All You Need' architecture."""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Attention sub-layer
attn_output, _ = self.attention(x, x, x)
x = self.norm1(x + self.dropout(attn_output)) # Add then normalize
# Feed-forward sub-layer
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output)) # Add then normalize
return x
Pre-Norm (Modern Transformers)
class PreNormTransformerBlock(nn.Module):
"""Modern architecture used in GPT, LLaMA, etc."""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads)
self.norm1 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Attention sub-layer
normed_x = self.norm1(x) # Normalize first
attn_output, _ = self.attention(normed_x, normed_x, normed_x)
x = x + self.dropout(attn_output)
# Feed-forward sub-layer
normed_x = self.norm2(x) # Normalize first
ff_output = self.feed_forward(normed_x)
x = x + self.dropout(ff_output)
return x
Pre-Norm Advantages
1. Easier Training: Gradients flow more smoothly through the network (compare the gradient sketch after this list).
2. Reduced Warm-up Requirements: Pre-Norm models can typically be trained with little or no learning-rate warm-up.
3. Better for Very Deep Models: Scales more reliably to 50+ layers.
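The gradient-flow claim can be checked with the two blocks defined above: stack each variant, backpropagate a dummy loss, and compare the gradient magnitude that reaches the input (a rough sketch; exact numbers depend on initialization and depth):
def input_grad_magnitude(block_class, num_layers=24, d_model=256, num_heads=4, d_ff=1024):
    blocks = nn.ModuleList([block_class(d_model, num_heads, d_ff) for _ in range(num_layers)])
    blocks.eval()  # disable dropout for a cleaner comparison
    x = torch.randn(16, 2, d_model, requires_grad=True)  # (seq_len, batch, d_model)
    h = x
    for block in blocks:
        h = block(h)
    h.sum().backward()
    return x.grad.abs().mean().item()

print(f"Post-Norm input gradient: {input_grad_magnitude(PostNormTransformerBlock):.6f}")
print(f"Pre-Norm input gradient:  {input_grad_magnitude(PreNormTransformerBlock):.6f}")
# The identity path in Pre-Norm typically keeps the input gradient much healthier at initialization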
Modern Choice:
GPT-2/3, LLaMA, and most modern transformers use Pre-Norm architecture. It's become the de facto standard for large language models due to superior training stability.
RMSNorm: Simplified Layer Normalization
RMSNorm (used in LLaMA) simplifies LayerNorm by removing mean centering:
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Simplified version of LayerNorm used in LLaMA.
"""
def __init__(self, d_model, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
"""
Args:
x: (batch_size, seq_len, d_model)
"""
# Compute RMS
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
# Normalize and scale
x_norm = x / rms
return self.weight * x_norm
# Compare LayerNorm vs RMSNorm
ln = LayerNorm(512)
rms = RMSNorm(512)
x = torch.randn(2, 10, 512)
ln_output = ln(x)
rms_output = rms(x)
print(f"LayerNorm output mean: {ln_output.mean():.6f}, std: {ln_output.std():.6f}")
print(f"RMSNorm output mean: {rms_output.mean():.6f}, std: {rms_output.std():.6f}")
# RMSNorm is typically a bit faster (no mean computation)
# Rough GPU benchmark; requires CUDA, and synchronizes so that asynchronous
# kernel launches don't distort the timings
import time
x = torch.randn(100, 2048, 4096).cuda()  # ~3 GB tensor; shrink the batch dim if memory is tight
ln = LayerNorm(4096).cuda()
rms = RMSNorm(4096).cuda()
# Warm-up
for _ in range(10):
    _ = ln(x)
    _ = rms(x)
torch.cuda.synchronize()
# Benchmark
start = time.time()
for _ in range(100):
    _ = ln(x)
torch.cuda.synchronize()
ln_time = time.time() - start
start = time.time()
for _ in range(100):
    _ = rms(x)
torch.cuda.synchronize()
rms_time = time.time() - start
print(f"LayerNorm time: {ln_time:.4f}s")
print(f"RMSNorm time: {rms_time:.4f}s")
print(f"Speedup: {ln_time/rms_time:.2f}x")
Complete Transformer Block
Putting it all together:
class ModernTransformerBlock(nn.Module):
"""
Complete transformer block with:
- Pre-layer normalization
- Residual connections
- Multi-head attention
- Feed-forward network
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1, use_rms_norm=False):
super().__init__()
# Choose normalization type
norm_class = RMSNorm if use_rms_norm else nn.LayerNorm
self.norm1 = norm_class(d_model)
self.norm2 = norm_class(d_model)
# Attention
self.attention = nn.MultiheadAttention(
d_model,
num_heads,
dropout=dropout,
batch_first=True
)
# Feed-forward
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(), # Modern activation
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
def forward(self, x, mask=None):
# Pre-norm attention with residual
normed = self.norm1(x)
attn_output, _ = self.attention(normed, normed, normed, attn_mask=mask)
x = x + attn_output
# Pre-norm feed-forward with residual
normed = self.norm2(x)
ff_output = self.feed_forward(normed)
x = x + ff_output
return x
# Build a complete transformer
class Transformer(nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.layers = nn.ModuleList([
ModernTransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
for layer in self.layers:
x = layer(x, mask)
return self.final_norm(x)
# Example: GPT-2 small configuration
model = Transformer(
num_layers=12,
d_model=768,
num_heads=12,
d_ff=3072, # 4 * d_model
dropout=0.1
)
x = torch.randn(2, 1024, 768) # (batch, seq_len, d_model)
output = model(x)
print(f"Output shape: {output.shape}")
Summary
Layer normalization and residual connections are essential for training deep transformers:
Residual Connections:
- Create gradient highways for deep networks
- Enable training of 50+ layer models
- Allow learning residual functions (easier than full transformations)
Layer Normalization:
- Stabilizes activation distributions
- Reduces internal covariate shift
- Enables higher learning rates
Modern Best Practices:
- Use Pre-Norm architecture (GPT-2/3, LLaMA)
- Consider RMSNorm for efficiency (LLaMA, Mistral)
- Always include residual connections around sub-layers
Without these techniques, modern large language models would be impossible to train.