Mistral and Mixture of Experts
Mixture of Experts (MoE) is a powerful technique that dramatically increases model capacity while keeping computational cost manageable. Mistral AI brought MoE to production-grade open-weight models with Mixtral 8x7B, which reaches GPT-3.5-level performance on many benchmarks while activating only a fraction of its parameters per token.
The MoE Concept
The Core Idea
Instead of one feed-forward network, use multiple expert networks and route each token to a subset of experts.
# Traditional FFN
output = FFN(x) # All tokens use same FFN
# Mixture of Experts
router_weights = Router(x) # Decide which experts to use
top_k_experts = select_top_k(router_weights, k=2)
output = weighted_sum([Expert_i(x) for i in top_k_experts])
Benefits:
- More capacity: 8 experts give roughly 8x the FFN parameters (attention weights are shared)
- Same compute: Only use 2/8 experts per token
- Specialization: Experts learn different patterns
Sparse Activation:
MoE models have many parameters but only activate a fraction per token. Mixtral 8x7B has 47B total parameters but only uses ~13B per token (similar to a 13B dense model in compute).
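To see how this arithmetic plays out, here is a rough back-of-the-envelope sketch using Mixtral-like dimensions (dim=4096, FFN hidden size 14336, 32 layers, 8 experts, top-2 routing); it omits embeddings and ignores GQA's smaller K/V projections, so the numbers are only approximate.
# Approximate total vs. active parameters for a Mixtral-like MoE model
# (illustrative estimate: embeddings omitted, GQA's smaller K/V projections ignored)
dim, hidden = 4096, 14336
num_layers, num_experts, top_k = 32, 8, 2
ffn_params = 3 * dim * hidden        # SwiGLU FFN has three projection matrices
attn_params = 4 * dim * dim          # Q, K, V, O projections (rough)
total_params = num_layers * (attn_params + num_experts * ffn_params)
active_params = num_layers * (attn_params + top_k * ffn_params)
print(f"Total parameters: ~{total_params / 1e9:.1f}B")   # ~47B
print(f"Active per token: ~{active_params / 1e9:.1f}B")  # ~13B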
Components of MoE
1. Expert Networks
Each expert is typically a standard feed-forward network:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Expert(nn.Module):
"""
Single expert network (feed-forward network).
"""
def __init__(self, dim, hidden_dim, dropout=0.0):
"""
Args:
dim: Model dimension
hidden_dim: Hidden layer dimension (usually 4 * dim)
dropout: Dropout rate
"""
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
self.fc2 = nn.Linear(hidden_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: (batch, seq_len, dim) or (num_tokens, dim)
Returns:
Same shape as input
"""
# Standard FFN: up-projection, activation, down-projection
h = F.gelu(self.fc1(x))
h = self.dropout(h)
return self.fc2(h)
# Create multiple experts
num_experts = 8
dim = 512
hidden_dim = 2048
experts = nn.ModuleList([
Expert(dim, hidden_dim) for _ in range(num_experts)
])
print(f"Created {num_experts} experts")
print(f"Each expert has {sum(p.numel() for p in experts[0].parameters()):,} parameters")
print(f"Total: {sum(p.numel() for p in experts.parameters()):,} parameters")
2. Router Network
The router decides which experts process each token:
class Router(nn.Module):
"""
Router network for expert selection.
Learns to route each token to the most relevant experts.
"""
def __init__(self, dim, num_experts):
"""
Args:
dim: Model dimension
num_experts: Number of experts to choose from
"""
super().__init__()
self.num_experts = num_experts
# Linear layer to compute expert scores
self.gate = nn.Linear(dim, num_experts, bias=False)
def forward(self, x):
"""
Compute routing weights for each token.
Args:
x: (batch, seq_len, dim) or (num_tokens, dim)
Returns:
router_logits: (num_tokens, num_experts)
"""
# Flatten to (num_tokens, dim)
x = x.view(-1, x.shape[-1])
# Compute logits for each expert
router_logits = self.gate(x)
return router_logits
# Test router
router = Router(dim=512, num_experts=8)
x = torch.randn(2, 10, 512) # (batch=2, seq_len=10, dim=512)
router_logits = router(x)
print(f"Input shape: {x.shape}")
print(f"Router logits shape: {router_logits.shape}") # (20, 8)
# Convert to probabilities
router_probs = F.softmax(router_logits, dim=-1)
print(f"\nExample routing probabilities (token 0):")
print(router_probs[0])
3. Top-K Gating
Select the top-k experts for each token:
def top_k_gating(router_logits, k=2, use_softmax=True, use_noise=False):
"""
Select top-k experts for each token.
Args:
router_logits: (num_tokens, num_experts)
k: Number of experts to select
use_softmax: Whether to apply softmax to selected experts
use_noise: Add noise during training (for exploration)
Returns:
indices: (num_tokens, k) - Selected expert indices
weights: (num_tokens, k) - Routing weights
"""
num_tokens, num_experts = router_logits.shape
# Add noise during training (encourages exploration)
if use_noise and router_logits.requires_grad:
noise = torch.randn_like(router_logits) * 0.1
router_logits = router_logits + noise
# Get top-k expert indices and values
top_k_logits, top_k_indices = torch.topk(router_logits, k, dim=-1)
# top_k_indices: (num_tokens, k)
# top_k_logits: (num_tokens, k)
# Compute routing weights
if use_softmax:
# Normalize only over selected experts
weights = F.softmax(top_k_logits, dim=-1)
else:
# Use raw logits (sometimes used with auxiliary losses)
weights = top_k_logits
return top_k_indices, weights
# Test top-k gating
router_logits = torch.randn(5, 8) # 5 tokens, 8 experts
indices, weights = top_k_gating(router_logits, k=2)
print("Top-k Gating (k=2):")
print(f"Selected expert indices:\n{indices}")
print(f"\nRouting weights:\n{weights}")
print(f"\nWeights sum to 1: {weights.sum(dim=-1)}")
Load Balancing Challenge:
Without constraints, the router may send most tokens to the same few experts, leaving the others underused and their capacity wasted. Load balancing losses encourage an even distribution of tokens across experts.
Complete MoE Layer
class MixtureOfExperts(nn.Module):
"""
Complete Mixture of Experts layer.
Routes each token to top-k experts and combines their outputs.
"""
def __init__(
self,
dim,
num_experts=8,
expert_capacity=None,
k=2,
hidden_dim=None,
dropout=0.0
):
"""
Args:
dim: Model dimension
num_experts: Number of expert networks
expert_capacity: Max tokens per expert (for load balancing)
k: Number of experts to route each token to
hidden_dim: Expert hidden dimension (default: 4 * dim)
dropout: Dropout rate
"""
super().__init__()
self.dim = dim
self.num_experts = num_experts
self.k = k
self.expert_capacity = expert_capacity
if hidden_dim is None:
hidden_dim = 4 * dim
# Create experts
self.experts = nn.ModuleList([
Expert(dim, hidden_dim, dropout) for _ in range(num_experts)
])
# Router
self.router = Router(dim, num_experts)
def forward(self, x):
"""
Route tokens to experts and combine outputs.
Args:
x: (batch, seq_len, dim)
Returns:
output: (batch, seq_len, dim)
router_probs: Expert selection probabilities (for auxiliary loss)
"""
batch_size, seq_len, dim = x.shape
original_shape = x.shape
# Reshape to (num_tokens, dim)
x_flat = x.view(-1, dim)
num_tokens = x_flat.shape[0]
# Route tokens to experts
router_logits = self.router(x) # (num_tokens, num_experts)
router_probs = F.softmax(router_logits, dim=-1)
# Select top-k experts
top_k_indices, top_k_weights = top_k_gating(router_logits, self.k)
# top_k_indices: (num_tokens, k)
# top_k_weights: (num_tokens, k)
# Initialize output
output = torch.zeros_like(x_flat)
# Process each expert
for expert_idx in range(self.num_experts):
# Find tokens routed to this expert
expert_mask = (top_k_indices == expert_idx) # (num_tokens, k)
# Get token indices and their routing weights
token_indices, k_indices = torch.where(expert_mask)
if len(token_indices) == 0:
continue # No tokens for this expert
# Get tokens for this expert
expert_input = x_flat[token_indices] # (num_routed_tokens, dim)
# Apply expert
expert_output = self.experts[expert_idx](expert_input)
# Get routing weights for these tokens
expert_weights = top_k_weights[token_indices, k_indices] # (num_routed_tokens,)
# Add weighted expert output to result
output[token_indices] += expert_weights.unsqueeze(-1) * expert_output
# Reshape back to original shape
output = output.view(original_shape)
return output, router_probs
# Test MoE layer
moe = MixtureOfExperts(
dim=512,
num_experts=8,
k=2,
hidden_dim=2048
)
x = torch.randn(2, 10, 512)
output, router_probs = moe(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Router probs shape: {router_probs.shape}")
# Analyze expert usage
expert_usage = router_probs.sum(dim=0)
print(f"\nExpert usage (total routing probability):")
for i, usage in enumerate(expert_usage):
print(f" Expert {i}: {usage.item():.2f}")
Load Balancing
The Problem
Without constraints, some experts get overused while others are ignored:
def demonstrate_load_imbalance():
"""Show how routers can become imbalanced."""
# Simulate imbalanced routing
num_tokens = 1000
num_experts = 8
# Poorly balanced: most tokens go to expert 0 and 1
imbalanced_probs = torch.zeros(num_tokens, num_experts)
imbalanced_probs[:, 0] = 0.5 # 50% to expert 0
imbalanced_probs[:, 1] = 0.4 # 40% to expert 1
imbalanced_probs[:, 2:] = 0.1 / 6 # Remaining 10% split across others
expert_loads = imbalanced_probs.sum(dim=0)
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.bar(range(num_experts), expert_loads.numpy())
plt.xlabel('Expert Index')
plt.ylabel('Total Load (tokens)')
plt.title('Imbalanced Expert Utilization')
plt.axhline(num_tokens / num_experts, color='r', linestyle='--',
label='Ideal Balance')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
print("Expert utilization (% of tokens):")
for i, load in enumerate(expert_loads):
print(f" Expert {i}: {load / num_tokens * 100:.1f}%")
demonstrate_load_imbalance()
Solution: Auxiliary Loss
Encourage balanced expert usage:
def load_balancing_loss(router_probs, top_k_indices, num_experts):
"""
Compute load balancing auxiliary loss.
Encourages uniform distribution of tokens across experts.
Args:
router_probs: (num_tokens, num_experts) - Router probabilities
top_k_indices: (num_tokens, k) - Selected expert indices
num_experts: Number of experts
Returns:
loss: Scalar load balancing loss
"""
num_tokens = router_probs.shape[0]
# Compute fraction of tokens routed to each expert
expert_counts = torch.zeros(num_experts, device=router_probs.device)
for expert_idx in range(num_experts):
expert_mask = (top_k_indices == expert_idx)
expert_counts[expert_idx] = expert_mask.float().sum()
fraction_routed = expert_counts / (num_tokens * top_k_indices.shape[1])
# Compute average router probability for each expert
avg_router_prob = router_probs.mean(dim=0)
# Switch Transformer-style load balancing loss: it is minimized when both the
# routed fraction and the mean router probability are uniform across experts,
# and grows when a few experts receive both high probability and many tokens
loss = num_experts * (fraction_routed * avg_router_prob).sum()
return loss
# Test load balancing loss
router_probs = torch.randn(100, 8).softmax(dim=-1)
top_k_indices = router_probs.topk(2, dim=-1).indices
lb_loss = load_balancing_loss(router_probs, top_k_indices, num_experts=8)
print(f"Load balancing loss: {lb_loss.item():.4f}")
Switch Transformer Approach:
Google's Switch Transformer uses expert capacity limits: each expert processes at most a fixed number of tokens. Overflow tokens skip the expert computation and pass through unchanged via the residual connection. This keeps per-expert compute bounded but can drop information.
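A minimal sketch of capacity-limited routing in the spirit of the Switch Transformer (not its actual implementation); top-1 routing is assumed, and overflow tokens simply keep their residual value:
def apply_expert_capacity(top_1_indices, num_experts, capacity):
    """
    Drop tokens that exceed each expert's capacity (simplified sketch).
    Args:
        top_1_indices: (num_tokens,) expert chosen for each token (top-1 routing)
        num_experts: Number of experts
        capacity: Maximum tokens each expert may process
    Returns:
        keep_mask: (num_tokens,) True if the token is processed, False if dropped
    """
    keep_mask = torch.zeros_like(top_1_indices, dtype=torch.bool)
    for expert_idx in range(num_experts):
        routed = torch.where(top_1_indices == expert_idx)[0]
        keep_mask[routed[:capacity]] = True  # keep only the first `capacity` tokens
    return keep_mask
# Example: 16 tokens, 4 experts, capacity factor 1.0 -> capacity = 16 / 4 = 4
top_1 = torch.randint(0, 4, (16,))
keep = apply_expert_capacity(top_1, num_experts=4, capacity=4)
print(f"Dropped tokens: {(~keep).sum().item()} / 16")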
Mixtral Approach:
Mixtral doesn't use hard capacity limits. Instead, it relies on the auxiliary loss and careful initialization to maintain balance.
Mistral Architecture
Mistral 7B is a dense model with architectural innovations (GQA, sliding window attention), while Mixtral 8x7B replaces each dense FFN with a sparse MoE layer:
class MistralBlock(nn.Module):
"""
Mistral transformer block.
Combines RMSNorm, SwiGLU, RoPE (like LLaMA) with optional MoE.
"""
def __init__(
self,
dim,
num_heads,
num_kv_heads=None,
hidden_dim=None,
num_experts=None, # If None, use dense FFN; else use MoE
moe_k=2,
dropout=0.0
):
super().__init__()
# Pre-normalization
self.attention_norm = nn.RMSNorm(dim) # nn.RMSNorm requires PyTorch 2.4+
self.ffn_norm = nn.RMSNorm(dim)
# Attention (same as LLaMA with GQA)
self.attention = MultiHeadAttention(dim, num_heads, num_kv_heads)
# Feed-forward: MoE or dense
if num_experts is not None:
# Sparse MoE layer
self.feed_forward = MixtureOfExperts(
dim=dim,
num_experts=num_experts,
k=moe_k,
hidden_dim=hidden_dim,
dropout=dropout
)
self.use_moe = True
else:
# Dense FFN
self.feed_forward = SwiGLU(dim, hidden_dim, dropout)
self.use_moe = False
def forward(self, x, mask=None):
"""
Args:
x: (batch, seq_len, dim)
mask: Attention mask
Returns:
output: (batch, seq_len, dim)
aux_loss: Auxiliary loss (load balancing) if using MoE
"""
# Attention with residual
h = x + self.attention(self.attention_norm(x), mask)
# FFN with residual
if self.use_moe:
ffn_out, router_probs = self.feed_forward(self.ffn_norm(h))
out = h + ffn_out
# Compute auxiliary loss
# (Simplified placeholder; in practice you would call load_balancing_loss
#  with the actual top-k routing decisions)
aux_loss = 0.01 * router_probs.var()
else:
out = h + self.feed_forward(self.ffn_norm(h))
aux_loss = 0.0
return out, aux_loss
# Mistral 7B (dense)
mistral_7b = MistralBlock(
dim=4096,
num_heads=32,
num_kv_heads=8, # GQA
hidden_dim=14336,
num_experts=None # Dense
)
# Mixtral 8x7B (MoE)
mixtral_8x7b = MistralBlock(
dim=4096,
num_heads=32,
num_kv_heads=8,
hidden_dim=14336,
num_experts=8, # 8 experts
moe_k=2 # Route to top-2
)
# Compare parameter counts
def count_params(model):
return sum(p.numel() for p in model.parameters())
print(f"Mistral 7B block: {count_params(mistral_7b) / 1e6:.1f}M parameters")
print(f"Mixtral 8x7B block: {count_params(mixtral_8x7b) / 1e6:.1f}M parameters")
print(f"Mixtral has {count_params(mixtral_8x7b) / count_params(mistral_7b):.1f}x more parameters")
print(f"But only uses ~2/8 experts per token!")
Sliding Window Attention
Mistral also introduces sliding window attention for efficiency:
def create_sliding_window_mask(seq_len, window_size):
"""
Create attention mask for sliding window attention.
Each position can only attend to positions within a local window.
Args:
seq_len: Sequence length
window_size: Size of attention window
Returns:
mask: (seq_len, seq_len) boolean mask
"""
# Create distance matrix
positions = torch.arange(seq_len).unsqueeze(0)
distances = positions.T - positions
# Causal sliding window: attend to self and the previous window_size - 1 positions
mask = (distances >= 0) & (distances < window_size)
return mask
# Visualize sliding window
seq_len = 20
window_size = 4
mask = create_sliding_window_mask(seq_len, window_size)
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
plt.imshow(mask.numpy(), cmap='Blues', aspect='auto')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title(f'Sliding Window Attention (window={window_size})')
plt.colorbar(label='Can Attend')
plt.show()
print(f"Each token can attend to {window_size} previous tokens")
print(f"Receptive field after N layers: {window_size} × N")
Sliding Window Benefits:
- Linear memory: O(n × window) instead of O(n²)
- Long context: Stack layers to increase receptive field
- Local focus: Most attention is local anyway
Mistral 7B uses window_size=4096; with 32 layers this gives a theoretical receptive field of ~131k tokens.
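A quick illustrative calculation of the savings (the counts are attention-score entries only; real memory also depends on batch size, heads, and dtype):
# Attention-score entries for full vs. sliding window attention (illustrative)
seq_len = 32_768
window_size = 4_096
num_layers = 32
full_entries = seq_len * seq_len              # O(n^2)
sliding_entries = seq_len * window_size       # O(n * window)
print(f"Full attention:  {full_entries / 1e9:.2f}B score entries")
print(f"Sliding window:  {sliding_entries / 1e9:.2f}B score entries "
      f"({full_entries // sliding_entries}x fewer)")
print(f"Theoretical receptive field: {window_size * num_layers:,} tokens")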
Training MoE Models
Key considerations:
class MoETrainer:
"""Helper for training MoE models."""
def __init__(self, model, aux_loss_weight=0.01):
self.model = model
self.aux_loss_weight = aux_loss_weight
def training_step(self, batch):
"""
Single training step with auxiliary loss.
Args:
batch: Input batch
Returns:
total_loss: Main loss + auxiliary loss
"""
# Forward pass
logits, aux_losses = self.model(batch['input_ids'])
# Main loss (cross-entropy)
main_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
batch['labels'].view(-1)
)
# Auxiliary loss (load balancing)
aux_loss = sum(aux_losses) if isinstance(aux_losses, list) else aux_losses
# Total loss
total_loss = main_loss + self.aux_loss_weight * aux_loss
return {
'total_loss': total_loss,
'main_loss': main_loss,
'aux_loss': aux_loss
}
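A minimal usage sketch: the tiny ToyMoELM below is a hypothetical stand-in (not Mistral's architecture) whose forward returns (logits, aux_loss) in the shape MoETrainer expects.
class ToyMoELM(nn.Module):
    """Hypothetical toy language model: embedding -> MoE layer -> vocab projection."""
    def __init__(self, vocab_size=100, dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.moe = MixtureOfExperts(dim=dim, num_experts=4, k=2, hidden_dim=128)
        self.lm_head = nn.Linear(dim, vocab_size)
    def forward(self, input_ids):
        h = self.embed(input_ids)
        h, router_probs = self.moe(h)
        # Simplified auxiliary signal, as in the MistralBlock example above
        aux_loss = router_probs.var()
        return self.lm_head(h), aux_loss
model = ToyMoELM()
trainer = MoETrainer(model, aux_loss_weight=0.01)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
batch = {
    'input_ids': torch.randint(0, 100, (2, 16)),
    'labels': torch.randint(0, 100, (2, 16)),
}
optimizer.zero_grad()
losses = trainer.training_step(batch)
losses['total_loss'].backward()
optimizer.step()
print({name: loss.item() for name, loss in losses.items()})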
Summary
Mixture of Experts:
- Multiple expert networks with learned routing
- Sparse activation (only k out of n experts per token)
- Dramatic capacity increase with manageable compute
Key Challenges:
- Load balancing: Prevent expert collapse
- Training instability: Routing can be unstable
- Communication: Expert parallelism requires careful engineering
Mistral Innovations:
- Mistral 7B: Dense model with GQA and sliding window attention
- Mixtral 8x7B: Sparse MoE with 47B total parameters but only ~13B active per token
When to Use MoE:
- Need large model capacity
- Can tolerate training complexity
- Have infrastructure for expert parallelism
- Want better quality/compute tradeoff than dense models