
Mistral and Mixture of Experts

Understand Mixture of Experts (MoE) architecture and the Mistral model family. Complete implementation of sparse expert routing and load balancing.

30 min read · MoE · Mistral · Sparse Models · Expert Routing

Mixture of Experts (MoE) is a powerful technique that dramatically increases model capacity while keeping computational cost manageable. Mistral AI released one of the first production-ready open MoE models, Mixtral 8x7B, which reaches GPT-3.5-level performance on many benchmarks while activating only a fraction of its parameters per token.

The MoE Concept

The Core Idea

Instead of one feed-forward network, use multiple expert networks and route each token to a subset of experts.

python
# Traditional FFN: every token passes through the same network
output = FFN(x)

# Mixture of Experts (schematic pseudocode)
router_weights = Router(x)  # Score each expert for this token
top_k_experts = select_top_k(router_weights, k=2)  # Keep the best-scoring experts
output = weighted_sum([Expert_i(x) for i in top_k_experts])

Benefits:

  • More capacity: 8 experts ≈ 8x the feed-forward parameters (attention and embeddings are shared)
  • Similar compute: only 2 of the 8 experts run for each token
  • Specialization: experts learn different patterns

Sparse Activation:

MoE models have many parameters but only activate a fraction per token. Mixtral 8x7B has 47B total parameters but only uses ~13B per token (similar to a 13B dense model in compute).
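
To make those numbers concrete, here is a rough back-of-the-envelope count using Mixtral's published hyperparameters (dim 4096, 32 layers, SwiGLU hidden size 14336, 8 KV heads, 32k vocabulary). Norms and the router are tiny and omitted, so treat the results as approximate:

python
# Approximate parameter count for Mixtral 8x7B (norms and router omitted)
dim, hidden_dim, n_layers, vocab = 4096, 14336, 32, 32000
n_experts, k = 8, 2

expert_params = 3 * dim * hidden_dim          # SwiGLU expert: gate, up, down projections
attn_params = 2 * dim * dim + 2 * dim * 1024  # Q/O full width; K/V use 8 KV heads x 128 dims (GQA)
embed_params = 2 * vocab * dim                # input embeddings + output head

total = n_layers * (n_experts * expert_params + attn_params) + embed_params
active = n_layers * (k * expert_params + attn_params) + embed_params

print(f"Total parameters:       ~{total / 1e9:.1f}B")   # ~46.7B
print(f"Active per token (k=2): ~{active / 1e9:.1f}B")  # ~12.9B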

Components of MoE

1. Expert Networks

Each expert is typically a standard feed-forward network:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """
    Single expert network (feed-forward network).
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        """
        Args:
            dim: Model dimension
            hidden_dim: Hidden layer dimension (usually 4 * dim)
            dropout: Dropout rate
        """
        super().__init__()

        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, dim) or (num_tokens, dim)

        Returns:
            Same shape as input
        """
        # Standard FFN: up-projection, activation, down-projection
        h = F.gelu(self.fc1(x))
        h = self.dropout(h)
        return self.fc2(h)


# Create multiple experts
num_experts = 8
dim = 512
hidden_dim = 2048

experts = nn.ModuleList([
    Expert(dim, hidden_dim) for _ in range(num_experts)
])

print(f"Created {num_experts} experts")
print(f"Each expert has {sum(p.numel() for p in experts[0].parameters()):,} parameters")
print(f"Total: {sum(p.numel() for p in experts.parameters()):,} parameters")

2. Router Network

The router decides which experts process each token:

python
class Router(nn.Module):
    """
    Router network for expert selection.

    Learns to route each token to the most relevant experts.
    """

    def __init__(self, dim, num_experts):
        """
        Args:
            dim: Model dimension
            num_experts: Number of experts to choose from
        """
        super().__init__()
        self.num_experts = num_experts

        # Linear layer to compute expert scores
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x):
        """
        Compute routing weights for each token.

        Args:
            x: (batch, seq_len, dim) or (num_tokens, dim)

        Returns:
            router_logits: (num_tokens, num_experts)
        """
        # Flatten to (num_tokens, dim)
        x = x.view(-1, x.shape[-1])

        # Compute logits for each expert
        router_logits = self.gate(x)

        return router_logits


# Test router
router = Router(dim=512, num_experts=8)
x = torch.randn(2, 10, 512)  # (batch=2, seq_len=10, dim=512)

router_logits = router(x)
print(f"Input shape: {x.shape}")
print(f"Router logits shape: {router_logits.shape}")  # (20, 8)

# Convert to probabilities
router_probs = F.softmax(router_logits, dim=-1)
print(f"\nExample routing probabilities (token 0):")
print(router_probs[0])

3. Top-K Gating

Select the top-k experts for each token:

python
def top_k_gating(router_logits, k=2, use_softmax=True, use_noise=False):
    """
    Select top-k experts for each token.

    Args:
        router_logits: (num_tokens, num_experts)
        k: Number of experts to select
        use_softmax: Whether to apply softmax to selected experts
        use_noise: Add noise during training (for exploration)

    Returns:
        indices: (num_tokens, k) - Selected expert indices
        weights: (num_tokens, k) - Routing weights
    """
    num_tokens, num_experts = router_logits.shape

    # Add noise during training (encourages exploration)
    if use_noise and router_logits.requires_grad:
        noise = torch.randn_like(router_logits) * 0.1
        router_logits = router_logits + noise

    # Get top-k expert indices and values
    top_k_logits, top_k_indices = torch.topk(router_logits, k, dim=-1)
    # top_k_indices: (num_tokens, k)
    # top_k_logits: (num_tokens, k)

    # Compute routing weights
    if use_softmax:
        # Normalize only over selected experts
        weights = F.softmax(top_k_logits, dim=-1)
    else:
        # Use raw logits (sometimes used with auxiliary losses)
        weights = top_k_logits

    return top_k_indices, weights


# Test top-k gating
router_logits = torch.randn(5, 8)  # 5 tokens, 8 experts

indices, weights = top_k_gating(router_logits, k=2)

print("Top-k Gating (k=2):")
print(f"Selected expert indices:\n{indices}")
print(f"\nRouting weights:\n{weights}")
print(f"\nWeights sum to 1: {weights.sum(dim=-1)}")

Load Balancing Challenge:

Without constraints, the router might send all tokens to the same few experts, wasting the other experts. Load balancing losses encourage even distribution across experts.
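
One simple way to monitor this is to count how many routed token slots each expert actually receives. A minimal sketch (the helper name is just for illustration) using the top_k_gating function defined above:

python
def expert_load_fractions(top_k_indices, num_experts):
    """Fraction of routed token slots assigned to each expert."""
    counts = torch.bincount(top_k_indices.flatten(), minlength=num_experts).float()
    return counts / top_k_indices.numel()

# A random (untrained) router spreads load roughly uniformly (~1/8 per expert);
# a collapsed router would concentrate most of the mass on a few experts.
router_logits = torch.randn(1000, 8)
indices, _ = top_k_gating(router_logits, k=2)
print(expert_load_fractions(indices, num_experts=8))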

Complete MoE Layer

python
class MixtureOfExperts(nn.Module):
    """
    Complete Mixture of Experts layer.

    Routes each token to top-k experts and combines their outputs.
    """

    def __init__(
        self,
        dim,
        num_experts=8,
        expert_capacity=None,
        k=2,
        hidden_dim=None,
        dropout=0.0
    ):
        """
        Args:
            dim: Model dimension
            num_experts: Number of expert networks
            expert_capacity: Max tokens per expert (for load balancing)
            k: Number of experts to route each token to
            hidden_dim: Expert hidden dimension (default: 4 * dim)
            dropout: Dropout rate
        """
        super().__init__()

        self.dim = dim
        self.num_experts = num_experts
        self.k = k
        self.expert_capacity = expert_capacity

        if hidden_dim is None:
            hidden_dim = 4 * dim

        # Create experts
        self.experts = nn.ModuleList([
            Expert(dim, hidden_dim, dropout) for _ in range(num_experts)
        ])

        # Router
        self.router = Router(dim, num_experts)

    def forward(self, x):
        """
        Route tokens to experts and combine outputs.

        Args:
            x: (batch, seq_len, dim)

        Returns:
            output: (batch, seq_len, dim)
            router_probs: Expert selection probabilities (for auxiliary loss)
        """
        batch_size, seq_len, dim = x.shape
        original_shape = x.shape

        # Reshape to (num_tokens, dim)
        x_flat = x.view(-1, dim)
        num_tokens = x_flat.shape[0]

        # Route tokens to experts
        router_logits = self.router(x)  # (num_tokens, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts
        top_k_indices, top_k_weights = top_k_gating(router_logits, self.k)
        # top_k_indices: (num_tokens, k)
        # top_k_weights: (num_tokens, k)

        # Initialize output
        output = torch.zeros_like(x_flat)

        # Process each expert
        for expert_idx in range(self.num_experts):
            # Find tokens routed to this expert
            expert_mask = (top_k_indices == expert_idx)  # (num_tokens, k)

            # Get token indices and their routing weights
            token_indices, k_indices = torch.where(expert_mask)

            if len(token_indices) == 0:
                continue  # No tokens for this expert

            # Get tokens for this expert
            expert_input = x_flat[token_indices]  # (num_routed_tokens, dim)

            # Apply expert
            expert_output = self.experts[expert_idx](expert_input)

            # Get routing weights for these tokens
            expert_weights = top_k_weights[token_indices, k_indices]  # (num_routed_tokens,)

            # Add weighted expert output to result
            output[token_indices] += expert_weights.unsqueeze(-1) * expert_output

        # Reshape back to original shape
        output = output.view(original_shape)

        return output, router_probs


# Test MoE layer
moe = MixtureOfExperts(
    dim=512,
    num_experts=8,
    k=2,
    hidden_dim=2048
)

x = torch.randn(2, 10, 512)
output, router_probs = moe(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Router probs shape: {router_probs.shape}")

# Analyze expert usage
expert_usage = router_probs.sum(dim=0)
print(f"\nExpert usage (total routing probability):")
for i, usage in enumerate(expert_usage):
    print(f"  Expert {i}: {usage.item():.2f}")

Load Balancing

The Problem

Without constraints, some experts get overused while others are ignored:

python
def demonstrate_load_imbalance():
    """Show how routers can become imbalanced."""
    # Simulate imbalanced routing
    num_tokens = 1000
    num_experts = 8

    # Poorly balanced: most tokens go to expert 0 and 1
    imbalanced_probs = torch.zeros(num_tokens, num_experts)
    imbalanced_probs[:, 0] = 0.5  # 50% to expert 0
    imbalanced_probs[:, 1] = 0.4  # 40% to expert 1
    imbalanced_probs[:, 2:] = 0.1 / 6  # Remaining 10% split across others

    expert_loads = imbalanced_probs.sum(dim=0)

    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.bar(range(num_experts), expert_loads.numpy())
    plt.xlabel('Expert Index')
    plt.ylabel('Total Load (tokens)')
    plt.title('Imbalanced Expert Utilization')
    plt.axhline(num_tokens / num_experts, color='r', linestyle='--',
                label='Ideal Balance')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    print("Expert utilization (% of tokens):")
    for i, load in enumerate(expert_loads):
        print(f"  Expert {i}: {load / num_tokens * 100:.1f}%")

demonstrate_load_imbalance()

Solution: Auxiliary Loss

Encourage balanced expert usage:

python
def load_balancing_loss(router_probs, top_k_indices, num_experts):
    """
    Compute load balancing auxiliary loss.

    Encourages uniform distribution of tokens across experts.

    Args:
        router_probs: (num_tokens, num_experts) - Router probabilities
        top_k_indices: (num_tokens, k) - Selected expert indices
        num_experts: Number of experts

    Returns:
        loss: Scalar load balancing loss
    """
    num_tokens = router_probs.shape[0]

    # Compute fraction of tokens routed to each expert
    expert_counts = torch.zeros(num_experts, device=router_probs.device)
    for expert_idx in range(num_experts):
        expert_mask = (top_k_indices == expert_idx)
        expert_counts[expert_idx] = expert_mask.float().sum()

    fraction_routed = expert_counts / (num_tokens * top_k_indices.shape[1])

    # Compute average router probability for each expert
    avg_router_prob = router_probs.mean(dim=0)

    # Switch-style load balancing loss: N * sum_i(f_i * P_i)
    # Minimized (value 1) when tokens are spread uniformly across experts;
    # grows toward N when routing collapses onto a few experts
    loss = num_experts * (fraction_routed * avg_router_prob).sum()

    return loss


# Test load balancing loss
router_probs = torch.randn(100, 8).softmax(dim=-1)
top_k_indices = router_probs.topk(2, dim=-1).indices

lb_loss = load_balancing_loss(router_probs, top_k_indices, num_experts=8)
print(f"Load balancing loss: {lb_loss.item():.4f}")

Switch Transformer Approach:

Google's Switch Transformer uses expert capacity limits: each expert processes at most a fixed number of tokens per batch. Excess tokens skip the MoE layer and pass through unchanged via the residual connection. This bounds per-expert compute but can drop information.
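
A minimal sketch of how a hard capacity limit could work, here with the top-1 routing that Switch Transformer uses (the function name and the simple first-come ordering are illustrative; real implementations order and batch tokens more carefully):

python
def capacity_limited_dispatch(router_logits, capacity_factor=1.25):
    """Top-1 routing with a per-expert capacity limit (Switch-style sketch)."""
    num_tokens, num_experts = router_logits.shape
    capacity = int(capacity_factor * num_tokens / num_experts)

    expert_choice = router_logits.argmax(dim=-1)  # top-1 expert per token
    keep = torch.zeros(num_tokens, dtype=torch.bool)

    for expert_idx in range(num_experts):
        token_ids = torch.where(expert_choice == expert_idx)[0]
        # Keep only the first `capacity` tokens; the rest overflow and are dropped
        keep[token_ids[:capacity]] = True

    return expert_choice, keep


logits = torch.randn(1000, 8)
choice, keep = capacity_limited_dispatch(logits)
print(f"Dropped tokens: {(~keep).sum().item()} / {len(keep)}")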

Mixtral Approach:

Mixtral doesn't use hard capacity limits. Instead, it relies on the auxiliary loss and careful initialization to maintain balance.
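
To connect the pieces, here is a sketch of computing the Switch-style auxiliary loss for the MixtureOfExperts layer defined earlier (the top-k indices are recovered from the returned probabilities, which works because softmax preserves the ranking of the logits):

python
x = torch.randn(2, 10, 512)
output, router_probs = moe(x)  # MoE layer from the earlier example

# Recover the routing decisions and compute the auxiliary loss
top_k_indices = router_probs.topk(2, dim=-1).indices
aux = load_balancing_loss(router_probs, top_k_indices, num_experts=8)
print(f"Auxiliary load balancing loss: {aux.item():.4f}")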

Mistral Architecture

Mistral 7B is a dense model with innovations, while Mixtral 8x7B uses MoE:

python
class MistralBlock(nn.Module):
    """
    Mistral transformer block.

    Combines RMSNorm, SwiGLU, RoPE (like LLaMA) with optional MoE.
    """

    def __init__(
        self,
        dim,
        num_heads,
        num_kv_heads=None,
        hidden_dim=None,
        num_experts=None,  # If None, use dense FFN; else use MoE
        moe_k=2,
        dropout=0.0
    ):
        super().__init__()

        # Pre-normalization (nn.RMSNorm requires PyTorch 2.4+)
        self.attention_norm = nn.RMSNorm(dim)
        self.ffn_norm = nn.RMSNorm(dim)

        # Attention with grouped-query attention, as in LLaMA
        # (MultiHeadAttention is assumed to be defined elsewhere)
        self.attention = MultiHeadAttention(dim, num_heads, num_kv_heads)

        # Feed-forward: MoE or dense
        if num_experts is not None:
            # Sparse MoE layer
            self.feed_forward = MixtureOfExperts(
                dim=dim,
                num_experts=num_experts,
                k=moe_k,
                hidden_dim=hidden_dim,
                dropout=dropout
            )
            self.use_moe = True
        else:
            # Dense FFN (SwiGLU is assumed to be defined elsewhere)
            self.feed_forward = SwiGLU(dim, hidden_dim, dropout)
            self.use_moe = False

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, dim)
            mask: Attention mask

        Returns:
            output: (batch, seq_len, dim)
            aux_loss: Auxiliary loss (load balancing) if using MoE
        """
        # Attention with residual
        h = x + self.attention(self.attention_norm(x), mask)

        # FFN with residual
        if self.use_moe:
            ffn_out, router_probs = self.feed_forward(self.ffn_norm(h))
            out = h + ffn_out

            # Auxiliary load-balancing loss (simplified proxy: penalize deviation of
            # the mean routing probability from uniform; in practice the Switch-style
            # load_balancing_loss above would be used with the actual routing decisions)
            num_experts = router_probs.shape[-1]
            aux_loss = ((router_probs.mean(dim=0) - 1.0 / num_experts) ** 2).sum()
        else:
            out = h + self.feed_forward(self.ffn_norm(h))
            aux_loss = 0.0

        return out, aux_loss


# Mistral 7B (dense)
mistral_7b = MistralBlock(
    dim=4096,
    num_heads=32,
    num_kv_heads=8,  # GQA
    hidden_dim=14336,
    num_experts=None  # Dense
)

# Mixtral 8x7B (MoE)
mixtral_8x7b = MistralBlock(
    dim=4096,
    num_heads=32,
    num_kv_heads=8,
    hidden_dim=14336,
    num_experts=8,  # 8 experts
    moe_k=2  # Route to top-2
)

# Compare parameter counts
def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(f"Mistral 7B block: {count_params(mistral_7b) / 1e6:.1f}M parameters")
print(f"Mixtral 8x7B block: {count_params(mixtral_8x7b) / 1e6:.1f}M parameters")
print(f"Mixtral has {count_params(mixtral_8x7b) / count_params(mistral_7b):.1f}x more parameters")
print(f"But only uses ~2/8 experts per token!")

Sliding Window Attention

Mistral also introduces sliding window attention for efficiency:

python
def create_sliding_window_mask(seq_len, window_size):
    """
    Create attention mask for sliding window attention.

    Each position can only attend to positions within a local window.

    Args:
        seq_len: Sequence length
        window_size: Size of attention window

    Returns:
        mask: (seq_len, seq_len) boolean mask
    """
    # Distance from each query to each key: distances[i, j] = i - j
    positions = torch.arange(seq_len)
    distances = positions.unsqueeze(1) - positions.unsqueeze(0)

    # Causal + local: attend only to the current token and the
    # previous `window_size` tokens
    mask = (distances >= 0) & (distances <= window_size)

    return mask


# Visualize sliding window
seq_len = 20
window_size = 4

mask = create_sliding_window_mask(seq_len, window_size)

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
plt.imshow(mask.numpy(), cmap='Blues', aspect='auto')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title(f'Sliding Window Attention (window={window_size})')
plt.colorbar(label='Can Attend')
plt.show()

print(f"Each token can attend to {window_size} previous tokens")
print(f"Receptive field after N layers: {window_size} × N")

Sliding Window Benefits:

  1. Linear memory: O(n × window) instead of O(n²)
  2. Long context: Stack layers to increase receptive field
  3. Local focus: Most attention is local anyway

Mistral 7B uses window_size=4096; with 32 layers this gives a theoretical attention span of roughly 131k tokens.
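
A quick sanity check of those numbers (the sequence length below is an arbitrary example):

python
window_size, num_layers, seq_len = 4096, 32, 32768

# Receptive field grows by one window per layer
print(f"Theoretical receptive field: {window_size * num_layers:,} tokens")  # 131,072

# Attention score memory: O(seq_len * window) instead of O(seq_len^2)
full, sliding = seq_len * seq_len, seq_len * window_size
print(f"Memory reduction at {seq_len:,} tokens: {full / sliding:.0f}x")  # 8x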

Training MoE Models

Key considerations:

python
class MoETrainer:
    """Helper for training MoE models."""

    def __init__(self, model, aux_loss_weight=0.01):
        self.model = model
        self.aux_loss_weight = aux_loss_weight

    def training_step(self, batch):
        """
        Single training step with auxiliary loss.

        Args:
            batch: Input batch

        Returns:
            total_loss: Main loss + auxiliary loss
        """
        # Forward pass
        logits, aux_losses = self.model(batch['input_ids'])

        # Main loss (cross-entropy)
        main_loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            batch['labels'].view(-1)
        )

        # Auxiliary loss (load balancing)
        aux_loss = sum(aux_losses) if isinstance(aux_losses, list) else aux_losses

        # Total loss
        total_loss = main_loss + self.aux_loss_weight * aux_loss

        return {
            'total_loss': total_loss,
            'main_loss': main_loss,
            'aux_loss': aux_loss
        }
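
The auxiliary loss weight is typically kept small, on the order of 0.01 (the value used in the Switch Transformer paper), so that load balancing nudges the router without competing with the language modeling objective.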

Summary

Mixture of Experts:

  • Multiple expert networks with learned routing
  • Sparse activation (only k out of n experts per token)
  • Dramatic capacity increase with manageable compute

Key Challenges:

  1. Load balancing: Prevent expert collapse
  2. Training instability: Routing can be unstable
  3. Communication: Expert parallelism requires careful engineering

Mistral Innovations:

  • Mistral 7B: Dense model with GQA and sliding window attention
  • Mixtral 8x7B: Sparse MoE with 47B total parameters while activating only ~13B per token

When to Use MoE:

  • Need large model capacity
  • Can tolerate training complexity
  • Have infrastructure for expert parallelism
  • Want better quality/compute tradeoff than dense models