Mixture of Experts Deep Dive
Master the Mixture of Experts architecture behind models like Mixtral and Switch Transformers (and, reportedly, GPT-4), enabling massive scale with sparse activation.
What You'll Learn: MoE models activate only a subset of parameters for each input, dramatically increasing model capacity while keeping inference costs manageable. This is how models scale to trillions of parameters.
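To make the scaling claim concrete, here is a quick back-of-the-envelope parameter count for a single FFN layer. The numbers are illustrative, not taken from any particular model:

# Rough parameter arithmetic for one FFN layer (illustrative numbers)
d_model, d_ff = 4096, 14336
num_experts, top_k = 8, 2

dense_ffn_params = 2 * d_model * d_ff              # single dense FFN
moe_total_params = num_experts * dense_ffn_params  # stored across all experts
moe_active_params = top_k * dense_ffn_params       # actually run per token

print(f"Dense FFN params:     {dense_ffn_params / 1e6:.0f}M")
print(f"MoE total params:     {moe_total_params / 1e6:.0f}M (8x the capacity)")
print(f"MoE active per token: {moe_active_params / 1e6:.0f}M (only 2x the compute)")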
Understanding Mixture of Experts
Core Concepts
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
class SimpleExpert(nn.Module):
"""A single expert network (FFN)"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.activation = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through expert"""
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class SimpleMoE(nn.Module):
"""
Simple Mixture of Experts implementation
Instead of one large FFN, we have N smaller expert FFNs.
A router decides which experts process each token.
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_experts: int = 8,
top_k: int = 2
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# Create experts
self.experts = nn.ModuleList([
SimpleExpert(input_dim, hidden_dim, output_dim)
for _ in range(num_experts)
])
# Router network
self.router = nn.Linear(input_dim, num_experts)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass with expert routing
Args:
x: Input tensor [batch_size, seq_len, input_dim]
Returns:
output: Mixed expert outputs [batch_size, seq_len, output_dim]
router_logits: Router logits for load balancing loss
"""
batch_size, seq_len, input_dim = x.shape
# Flatten for routing
x_flat = x.view(-1, input_dim) # [batch_size * seq_len, input_dim]
# Compute router logits
router_logits = self.router(x_flat) # [batch_size * seq_len, num_experts]
# Get top-k experts for each token
routing_weights, selected_experts = torch.topk(
router_logits,
self.top_k,
dim=-1
)
# Normalize routing weights
routing_weights = F.softmax(routing_weights, dim=-1)
# Initialize output
output = torch.zeros(
batch_size * seq_len,
self.experts[0].fc2.out_features,
device=x.device
)
        # Process each token through its selected experts. The nested Python
        # loops are kept for readability; see OptimizedMoE later in this
        # lesson for a batched formulation.
for i in range(self.top_k):
# Get expert indices for this position
expert_idx = selected_experts[:, i]
# Process each expert
for expert_num in range(self.num_experts):
# Get tokens routed to this expert
expert_mask = (expert_idx == expert_num)
if expert_mask.any():
# Get inputs for this expert
expert_input = x_flat[expert_mask]
# Process through expert
expert_output = self.experts[expert_num](expert_input)
# Add weighted output
output[expert_mask] += (
routing_weights[expert_mask, i].unsqueeze(1) *
expert_output
)
# Reshape output
output = output.view(batch_size, seq_len, -1)
return output, router_logits
def load_balancing_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
"""
Compute load balancing loss to encourage equal expert usage
Args:
router_logits: Router output [batch_size * seq_len, num_experts]
Returns:
Load balancing loss
"""
# Calculate the fraction of tokens routed to each expert
routing_weights = F.softmax(router_logits, dim=-1)
expert_usage = routing_weights.mean(dim=0) # [num_experts]
        # Ideally each expert receives 1/num_experts of the routing mass.
        # Penalize deviation using the squared coefficient of variation (CV^2)
        # of the per-expert usage: ~0 when balanced, large when skewed.
        cv_squared = expert_usage.var() / (expert_usage.mean() ** 2)
        return cv_squared * self.num_experts
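# Quick sanity check of the CV^2 signal with hand-picked usage vectors:
# perfectly balanced usage gives ~0 loss, skewed usage a much larger one.
balanced = torch.full((8,), 1 / 8)
skewed = torch.tensor([0.65, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05])
for name, usage in [("balanced", balanced), ("skewed", skewed)]:
    cv_squared = usage.var() / (usage.mean() ** 2)
    print(f"{name}: CV^2 = {cv_squared.item():.4f}")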
# Example usage
def demonstrate_moe():
"""Demonstrate MoE forward pass"""
# Create simple MoE
moe = SimpleMoE(
input_dim=512,
hidden_dim=2048,
output_dim=512,
num_experts=8,
top_k=2
)
# Create sample input
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, 512)
# Forward pass
output, router_logits = moe(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Router logits shape: {router_logits.shape}")
# Calculate load balancing loss
lb_loss = moe.load_balancing_loss(router_logits)
print(f"Load balancing loss: {lb_loss.item():.4f}")
# Analyze expert usage
routing_weights = F.softmax(router_logits, dim=-1)
expert_usage = routing_weights.mean(dim=0)
print("\nExpert usage distribution:")
for i, usage in enumerate(expert_usage):
print(f" Expert {i}: {usage.item()*100:.2f}%")
demonstrate_moe()
Advanced Router Mechanisms
Router Design: The router is the critical component of an MoE layer: it decides which experts see which tokens. Different routing strategies trade off quality, efficiency, and load balance.
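One subtle design decision is the order of top-k selection and softmax. The sketch below (toy logits, 4 experts, k=2; not tied to any particular model) shows the difference: selecting first and then applying softmax renormalizes the mass over the chosen experts, while applying softmax first keeps the global probabilities, which then sum to less than 1 unless renormalized.

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # toy router logits for 4 experts

# Option A: top-k first, then softmax over the selected logits
# (the approach used by SimpleMoE above)
top_vals, top_idx = torch.topk(logits, k=2)
weights_a = F.softmax(top_vals, dim=-1)

# Option B: softmax over all experts first, then keep the top-k probabilities
probs = F.softmax(logits, dim=-1)
weights_b, _ = torch.topk(probs, k=2)

print(weights_a, weights_a.sum())  # sums to 1.0 by construction
print(weights_b, weights_b.sum())  # sums to < 1.0 unless renormalized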
Top-K Routing with Load Balancing
class AdvancedRouter(nn.Module):
"""Advanced router with multiple routing strategies"""
def __init__(
self,
input_dim: int,
num_experts: int,
top_k: int = 2,
noise_std: float = 0.1,
expert_capacity_factor: float = 1.25
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.noise_std = noise_std
        # Stored for capacity-aware variants; not enforced by this router
        # (see SwitchRouter below for explicit capacity handling)
        self.expert_capacity_factor = expert_capacity_factor
self.router = nn.Linear(input_dim, num_experts)
def forward(
self,
x: torch.Tensor,
training: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
        Route inputs to experts with noisy top-k selection
Returns:
routing_weights: Weights for selected experts
selected_experts: Indices of selected experts
router_logits: Raw router outputs
"""
# Compute router logits
router_logits = self.router(x)
# Add noise during training for exploration
if training and self.noise_std > 0:
noise = torch.randn_like(router_logits) * self.noise_std
router_logits = router_logits + noise
# Top-k routing
routing_weights, selected_experts = torch.topk(
router_logits,
self.top_k,
dim=-1
)
# Softmax over selected experts
routing_weights = F.softmax(routing_weights, dim=-1)
return routing_weights, selected_experts, router_logits
def compute_auxiliary_loss(
self,
router_logits: torch.Tensor
) -> torch.Tensor:
"""
Compute auxiliary loss for load balancing
Combines:
1. Load balancing loss (encourage equal expert usage)
        2. Router z-loss (keeps router logits small for numerical stability)
"""
# Load balancing loss
routing_weights = F.softmax(router_logits, dim=-1)
expert_usage = routing_weights.mean(dim=0)
        # Importance: mean router probability mass assigned to each expert
importance = expert_usage
# Load: fraction of tokens routed to each expert (for top-k)
_, selected = torch.topk(router_logits, self.top_k, dim=-1)
load = torch.zeros(self.num_experts, device=router_logits.device)
for i in range(self.num_experts):
load[i] = (selected == i).float().sum() / selected.numel()
# Load balancing loss: importance * load
load_balancing_loss = (importance * load).sum() * self.num_experts
        # Router z-loss (from ST-MoE): penalizes large logits so the
        # router softmax remains numerically stable
        router_z_loss = torch.logsumexp(router_logits, dim=-1).pow(2).mean()
return load_balancing_loss + 0.01 * router_z_loss
class SwitchRouter(nn.Module):
"""
Switch Transformer routing (top-1 with capacity)
Switch routing uses only top-1 expert per token but with
capacity constraints to prevent overloading popular experts.
"""
def __init__(
self,
input_dim: int,
num_experts: int,
capacity_factor: float = 1.25,
jitter_noise: float = 0.01
):
super().__init__()
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.jitter_noise = jitter_noise
self.router = nn.Linear(input_dim, num_experts, bias=False)
def forward(
self,
x: torch.Tensor,
training: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Switch routing with capacity constraints
Returns:
            dispatch_tensor: Binary routing assignments [tokens, experts]
            combine_tensor: Routing weights for combining [tokens, experts]
aux_loss: Auxiliary loss for load balancing
"""
num_tokens = x.shape[0]
# Compute router logits
router_logits = self.router(x)
        # Add jitter noise during training (simplified here as additive
        # Gaussian noise on the logits; the original paper multiplies the
        # router inputs by uniform noise)
if training:
router_logits = router_logits + torch.randn_like(router_logits) * self.jitter_noise
# Calculate expert capacity
expert_capacity = int(
self.capacity_factor * num_tokens / self.num_experts
)
# Get top-1 expert per token
router_probs = F.softmax(router_logits, dim=-1)
expert_weights, expert_indices = torch.max(router_probs, dim=-1)
# Create dispatch and combine tensors
dispatch_tensor = torch.zeros(
num_tokens,
self.num_experts,
device=x.device
)
combine_tensor = torch.zeros(
num_tokens,
self.num_experts,
device=x.device
)
# Track expert usage
expert_counts = torch.zeros(self.num_experts, device=x.device)
        # Assign tokens to experts, respecting capacity. Tokens that arrive
        # after an expert is full are dropped (they pass through the layer via
        # the residual connection). Looped per token for clarity; production
        # implementations vectorize this step.
for token_idx in range(num_tokens):
expert_idx = expert_indices[token_idx].item()
# Check capacity
if expert_counts[expert_idx] < expert_capacity:
dispatch_tensor[token_idx, expert_idx] = 1.0
combine_tensor[token_idx, expert_idx] = expert_weights[token_idx]
expert_counts[expert_idx] += 1
# Compute auxiliary loss
aux_loss = self._compute_load_balancing_loss(router_probs, expert_indices)
return dispatch_tensor, combine_tensor, aux_loss
def _compute_load_balancing_loss(
self,
router_probs: torch.Tensor,
expert_indices: torch.Tensor
) -> torch.Tensor:
"""Compute load balancing loss"""
num_tokens = router_probs.shape[0]
# Fraction of tokens routed to each expert
tokens_per_expert = torch.bincount(
expert_indices,
minlength=self.num_experts
).float() / num_tokens
# Fraction of router probability for each expert
prob_per_expert = router_probs.mean(dim=0)
# Load balancing loss
aux_loss = self.num_experts * (tokens_per_expert * prob_per_expert).sum()
return aux_loss
# Example usage
def demonstrate_routing():
"""Demonstrate different routing strategies"""
input_dim = 512
num_experts = 8
num_tokens = 100
# Create sample input
x = torch.randn(num_tokens, input_dim)
# Advanced router (top-k)
print("Top-K Routing:")
router_topk = AdvancedRouter(input_dim, num_experts, top_k=2)
weights, experts, logits = router_topk(x, training=True)
aux_loss = router_topk.compute_auxiliary_loss(logits)
print(f" Selected experts shape: {experts.shape}")
print(f" Routing weights shape: {weights.shape}")
print(f" Auxiliary loss: {aux_loss.item():.4f}")
# Switch router (top-1)
print("\nSwitch Routing:")
router_switch = SwitchRouter(input_dim, num_experts)
dispatch, combine, aux_loss = router_switch(x, training=True)
print(f" Dispatch tensor shape: {dispatch.shape}")
print(f" Combine tensor shape: {combine.shape}")
print(f" Auxiliary loss: {aux_loss.item():.4f}")
# Analyze expert usage
tokens_per_expert = dispatch.sum(dim=0)
print("\n Tokens per expert:")
for i, count in enumerate(tokens_per_expert):
print(f" Expert {i}: {count.item():.0f} tokens")
demonstrate_routing()
Switch Transformer Implementation
Switch Transformers: Google's Switch Transformer uses MoE with top-1 routing to scale to trillions of parameters while maintaining efficiency through sparse activation.
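The capacity factor is the main lever in Switch routing. A quick worked example with illustrative numbers:

# Expert capacity: the maximum number of tokens one expert may receive
num_tokens, num_experts, capacity_factor = 1024, 8, 1.25

expert_capacity = int(capacity_factor * num_tokens / num_experts)
print(expert_capacity)  # 160

# With perfectly balanced routing each expert would see 1024 / 8 = 128 tokens;
# the 1.25 factor leaves 25% headroom for imbalance. Tokens routed to a full
# expert are dropped and flow through the layer via the residual path only.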
class SwitchTransformerLayer(nn.Module):
"""
Complete Switch Transformer layer with MoE FFN
"""
def __init__(
self,
d_model: int,
num_experts: int,
expert_hidden_dim: int,
num_heads: int = 8,
capacity_factor: float = 1.25,
dropout: float = 0.1
):
super().__init__()
self.d_model = d_model
self.num_experts = num_experts
# Self-attention (standard)
self.self_attn = nn.MultiheadAttention(
d_model,
num_heads,
dropout=dropout,
batch_first=True
)
# Layer norms
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# MoE FFN
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, expert_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(expert_hidden_dim, d_model)
)
for _ in range(num_experts)
])
# Router
self.router = SwitchRouter(
d_model,
num_experts,
capacity_factor=capacity_factor
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
training: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass through Switch Transformer layer
Args:
x: Input tensor [batch_size, seq_len, d_model]
Returns:
output: Layer output
aux_loss: Auxiliary loss for load balancing
"""
batch_size, seq_len, d_model = x.shape
# Self-attention block
residual = x
x = self.norm1(x)
attn_output, _ = self.self_attn(x, x, x)
x = residual + self.dropout(attn_output)
# MoE FFN block
residual = x
x = self.norm2(x)
# Flatten for routing
x_flat = x.view(-1, d_model) # [batch_size * seq_len, d_model]
# Route to experts
dispatch_tensor, combine_tensor, aux_loss = self.router(
x_flat,
training=training
)
# Process through experts
expert_outputs = []
for expert_idx, expert in enumerate(self.experts):
# Get tokens assigned to this expert
expert_mask = dispatch_tensor[:, expert_idx] > 0
if expert_mask.any():
# Process through expert
expert_input = x_flat[expert_mask]
expert_output = expert(expert_input)
# Create full output tensor (zeros for non-assigned tokens)
full_expert_output = torch.zeros_like(x_flat)
full_expert_output[expert_mask] = expert_output
expert_outputs.append(full_expert_output)
else:
expert_outputs.append(torch.zeros_like(x_flat))
# Combine expert outputs
expert_outputs = torch.stack(expert_outputs, dim=-1) # [..., num_experts]
combined_output = (expert_outputs * combine_tensor.unsqueeze(1)).sum(dim=-1)
# Reshape
combined_output = combined_output.view(batch_size, seq_len, d_model)
# Add residual
output = residual + self.dropout(combined_output)
return output, aux_loss
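# Note: the expert loop in SwitchTransformerLayer.forward allocates a full
# [tokens, d_model] zero tensor per expert and stacks them, costing
# num_experts times the activation memory. An equivalent, lighter combine
# (a sketch, not wired into the class above) scatters outputs in place:
#
#   combined = torch.zeros_like(x_flat)
#   for expert_idx, expert in enumerate(self.experts):
#       token_idx = (dispatch_tensor[:, expert_idx] > 0).nonzero(as_tuple=True)[0]
#       if token_idx.numel() > 0:
#           weights = combine_tensor[token_idx, expert_idx].unsqueeze(1)
#           combined.index_add_(0, token_idx, expert(x_flat[token_idx]) * weights)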
class SwitchTransformer(nn.Module):
"""Complete Switch Transformer model"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_layers: int = 6,
num_experts: int = 8,
expert_hidden_dim: int = 2048,
num_heads: int = 8,
max_seq_len: int = 512,
dropout: float = 0.1,
aux_loss_weight: float = 0.01
):
super().__init__()
self.aux_loss_weight = aux_loss_weight
# Embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_seq_len, d_model)
# Transformer layers
self.layers = nn.ModuleList([
SwitchTransformerLayer(
d_model,
num_experts,
expert_hidden_dim,
num_heads,
dropout=dropout
)
for _ in range(num_layers)
])
# Output head
self.norm = nn.LayerNorm(d_model)
self.output = nn.Linear(d_model, vocab_size)
def forward(
self,
input_ids: torch.Tensor,
training: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass
Args:
input_ids: Input token IDs [batch_size, seq_len]
Returns:
logits: Output logits [batch_size, seq_len, vocab_size]
total_aux_loss: Total auxiliary loss
"""
batch_size, seq_len = input_ids.shape
# Create embeddings
token_embeds = self.token_embedding(input_ids)
positions = torch.arange(seq_len, device=input_ids.device)
position_embeds = self.position_embedding(positions)
x = token_embeds + position_embeds
# Process through layers
total_aux_loss = 0.0
for layer in self.layers:
x, aux_loss = layer(x, training=training)
total_aux_loss = total_aux_loss + aux_loss
# Output projection
x = self.norm(x)
logits = self.output(x)
# Scale auxiliary loss
total_aux_loss = total_aux_loss * self.aux_loss_weight
return logits, total_aux_loss
# Example usage
def train_switch_transformer():
"""Demonstrate Switch Transformer training"""
# Create model
model = SwitchTransformer(
vocab_size=10000,
d_model=512,
num_layers=4,
num_experts=8,
expert_hidden_dim=2048
)
# Create sample batch
batch_size = 4
seq_len = 32
input_ids = torch.randint(0, 10000, (batch_size, seq_len))
# Forward pass
logits, aux_loss = model(input_ids, training=True)
print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {logits.shape}")
print(f"Auxiliary loss: {aux_loss.item():.4f}")
# Calculate total loss (including aux loss)
targets = torch.randint(0, 10000, (batch_size, seq_len))
ce_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1)
)
total_loss = ce_loss + aux_loss
print(f"Cross-entropy loss: {ce_loss.item():.4f}")
print(f"Total loss: {total_loss.item():.4f}")
train_switch_transformer()
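In an actual training loop the combined loss is backpropagated as a single scalar; gradients from the auxiliary term flow back through the routers and regularize their assignments. A minimal sketch, assuming the SwitchTransformer defined above and random data:

model = SwitchTransformer(vocab_size=10000, d_model=512, num_layers=4,
                          num_experts=8, expert_hidden_dim=2048)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

input_ids = torch.randint(0, 10000, (4, 32))
targets = torch.randint(0, 10000, (4, 32))

optimizer.zero_grad()
logits, aux_loss = model(input_ids, training=True)
ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
loss = ce_loss + aux_loss  # aux_loss is already scaled by aux_loss_weight
loss.backward()
optimizer.step()
print(f"step loss: {loss.item():.4f}")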
Expert Specialization Analysis
Expert Specialization: In trained MoE models, experts often specialize in different types of inputs (e.g., math, code, general knowledge). Understanding this helps with debugging and optimization.
class ExpertAnalyzer:
"""Analyze expert behavior and specialization"""
def __init__(self, model: SwitchTransformer):
self.model = model
        self.num_experts = model.layers[0].num_experts
    def analyze_routing(self, input_ids: torch.Tensor):
        """Analyze which experts are activated for a batch of inputs"""
# Hook to capture routing decisions
routing_decisions = []
def routing_hook(module, input, output):
dispatch_tensor, _, _ = output
routing_decisions.append(dispatch_tensor)
# Register hooks
hooks = []
for layer in self.model.layers:
hook = layer.router.register_forward_hook(routing_hook)
hooks.append(hook)
# Forward pass
with torch.no_grad():
self.model(input_ids, training=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Analyze routing patterns
for layer_idx, dispatch in enumerate(routing_decisions):
print(f"\nLayer {layer_idx} routing:")
# dispatch: [batch_size * seq_len, num_experts]
tokens_per_expert = dispatch.sum(dim=0)
for expert_idx, count in enumerate(tokens_per_expert):
if count > 0:
print(f" Expert {expert_idx}: {count.item():.0f} tokens")
def analyze_expert_specialization(
self,
dataset: List[Tuple[torch.Tensor, str, str]]
):
"""
Analyze what types of inputs each expert specializes in
Args:
dataset: List of (input_ids, text, category) tuples
"""
# Track which experts activate for which categories
        expert_categories = {i: [] for i in range(self.num_experts)}
for input_ids, text, category in dataset:
# Get routing decisions
routing_decisions = []
def routing_hook(module, input, output):
dispatch_tensor, _, _ = output
routing_decisions.append(dispatch_tensor)
hooks = []
for layer in self.model.layers:
hook = layer.router.register_forward_hook(routing_hook)
hooks.append(hook)
with torch.no_grad():
self.model(input_ids.unsqueeze(0), training=False)
for hook in hooks:
hook.remove()
# Record which experts were used
dispatch = routing_decisions[0] # First layer
active_experts = dispatch.sum(dim=0) > 0
for expert_idx, is_active in enumerate(active_experts):
if is_active:
expert_categories[expert_idx].append(category)
# Analyze specialization
print("\nExpert Specialization Analysis:")
for expert_idx, categories in expert_categories.items():
if categories:
# Count category frequency
from collections import Counter
category_counts = Counter(categories)
print(f"\nExpert {expert_idx}:")
for cat, count in category_counts.most_common(3):
pct = count / len(categories) * 100
print(f" {cat}: {pct:.1f}%")
def visualize_expert_usage(self, routing_matrix: torch.Tensor):
"""Visualize expert usage patterns"""
import matplotlib.pyplot as plt
# routing_matrix: [num_tokens, num_experts]
expert_usage = routing_matrix.sum(dim=0).cpu().numpy()
plt.figure(figsize=(10, 6))
plt.bar(range(len(expert_usage)), expert_usage)
plt.xlabel('Expert Index')
plt.ylabel('Number of Tokens')
plt.title('Expert Usage Distribution')
plt.show()
return expert_usage
# Example usage
model = SwitchTransformer(vocab_size=10000, num_experts=8)
analyzer = ExpertAnalyzer(model)
# Analyze routing
input_ids = torch.randint(0, 10000, (2, 20))
analyzer.analyze_routing(input_ids)
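A compact quantitative summary of routing concentration is the entropy of the expert-usage distribution: it peaks at log(num_experts) when tokens are spread evenly and approaches zero when routing collapses onto one expert. The helper below, usage_entropy, is illustrative and not part of the analyzer above:

def usage_entropy(dispatch_tensor: torch.Tensor) -> float:
    """Entropy (in nats) of the expert-usage distribution of a dispatch tensor."""
    usage = dispatch_tensor.sum(dim=0)
    usage = usage / usage.sum().clamp_min(1e-9)
    return -(usage * usage.clamp_min(1e-9).log()).sum().item()

# Balanced routing over 8 experts gives entropy log(8) ~ 2.079; a collapse
# onto a single expert gives 0.
uniform = torch.ones(100, 8) / 8
collapsed = torch.zeros(100, 8)
collapsed[:, 0] = 1.0
print(f"uniform:   {usage_entropy(uniform):.3f}")
print(f"collapsed: {usage_entropy(collapsed):.3f}")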
Production Optimization
class OptimizedMoE(nn.Module):
"""Production-optimized MoE with various efficiency improvements"""
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2
    ):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
        # Batched expert weights: one fused parameter tensor per projection is
        # friendlier to gather/bmm than a ModuleList of separate experts.
        # Scaled initialization keeps activation magnitudes reasonable.
        self.expert_weights_1 = nn.Parameter(
            torch.randn(num_experts, input_dim, hidden_dim) * input_dim ** -0.5
        )
        self.expert_weights_2 = nn.Parameter(
            torch.randn(num_experts, hidden_dim, output_dim) * hidden_dim ** -0.5
        )
self.router = nn.Linear(input_dim, num_experts)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Optimized forward pass"""
batch_size, seq_len, input_dim = x.shape
# Route
x_flat = x.view(-1, input_dim)
router_logits = self.router(x_flat)
routing_weights, selected_experts = torch.topk(
router_logits,
self.top_k,
dim=-1
)
routing_weights = F.softmax(routing_weights, dim=-1)
        # Gather per-token expert weights: this removes the per-expert Python
        # loop but copies one weight matrix per token, trading memory for speed
        outputs = []
for k in range(self.top_k):
expert_idx = selected_experts[:, k]
# Gather expert weights
weights_1 = self.expert_weights_1[expert_idx] # [batch, input_dim, hidden_dim]
weights_2 = self.expert_weights_2[expert_idx] # [batch, hidden_dim, output_dim]
# Batch matrix multiply
hidden = torch.bmm(
x_flat.unsqueeze(1),
weights_1
).squeeze(1) # [batch, hidden_dim]
hidden = F.gelu(hidden)
output = torch.bmm(
hidden.unsqueeze(1),
weights_2
).squeeze(1) # [batch, output_dim]
# Weight by routing probability
output = output * routing_weights[:, k].unsqueeze(1)
outputs.append(output)
# Combine
final_output = torch.stack(outputs, dim=0).sum(dim=0)
return final_output.view(batch_size, seq_len, -1)
# Benchmark
def benchmark_moe():
"""Benchmark MoE performance"""
import time
model = OptimizedMoE(512, 2048, 512, num_experts=8)
x = torch.randn(4, 100, 512)
    # Run without autograd so we time pure forward compute
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = model(x)
        # Benchmark
        start = time.time()
        for _ in range(100):
            _ = model(x)
        elapsed = time.time() - start
    print(f"Average time per forward pass: {elapsed/100*1000:.2f}ms")
benchmark_moe()
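Note that the timing above measures CPU wall-clock time. On a GPU, kernels launch asynchronously, so an honest measurement must synchronize before reading the clock. A minimal sketch, assuming a CUDA device is available:

def benchmark_moe_gpu(num_iters: int = 100):
    """GPU benchmark with explicit synchronization (skips if CUDA is absent)."""
    import time
    if not torch.cuda.is_available():
        print("CUDA not available; skipping GPU benchmark")
        return
    device = torch.device("cuda")
    model = OptimizedMoE(512, 2048, 512, num_experts=8).to(device)
    x = torch.randn(4, 100, 512, device=device)
    with torch.no_grad():
        for _ in range(10):  # warmup (also triggers kernel compilation/caching)
            _ = model(x)
        torch.cuda.synchronize()  # ensure warmup kernels have finished
        start = time.time()
        for _ in range(num_iters):
            _ = model(x)
        torch.cuda.synchronize()  # wait for all queued kernels before timing
    print(f"Average GPU forward pass: {(time.time() - start) / num_iters * 1000:.2f}ms")

benchmark_moe_gpu()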
Summary
In this lesson, you learned:
- MoE fundamentals: Sparse activation and expert routing
- Router mechanisms: Top-k routing, load balancing, and capacity constraints
- Switch Transformers: Google's efficient top-1 MoE architecture
- Expert specialization: How experts develop specializations during training
- Production optimization: Efficient MoE implementation techniques
Mixture of Experts is a key technique enabling the next generation of massive language models while keeping inference costs manageable.