Back
advanced
Cutting-Edge Topics

Mixture of Experts Deep Dive

Understand Mixture of Experts architecture, expert routing, and implementation in modern LLMs

30 min read· moe· architecture· experts· routing

Mixture of Experts Deep Dive

Master the Mixture of Experts (MoE) architecture behind models such as Mixtral and Switch Transformer (and, reportedly, GPT-4), which reach massive scale through sparse activation.

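What You'll Learn: MoE models activate only a subset of their parameters for each token. Total capacity can therefore grow dramatically while per-token compute stays close to that of a much smaller dense model; all experts must still be held in memory. This is how models such as Switch Transformer scale past a trillion parameters.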

Understanding Mixture of Experts

Core Concepts

python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import numpy as np

class SimpleExpert(nn.Module):
    """One feed-forward expert: Linear -> GELU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the intermediate layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()

        # Registration order (fc1, fc2, activation) is kept so parameter
        # initialisation and state_dict ordering stay the same.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(gelu(fc1(x)))`` applied over the last dimension of ``x``."""
        return self.fc2(self.activation(self.fc1(x)))

class SimpleMoE(nn.Module):
    """
    Simple Mixture of Experts implementation

    Instead of one large FFN, we have N smaller expert FFNs.
    A router decides which experts process each token.

    Every token is processed by its ``top_k`` highest-scoring experts, and
    their outputs are mixed with a softmax taken over those top-k logits.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2
    ):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k

        # Create experts
        self.experts = nn.ModuleList([
            SimpleExpert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])

        # Router network: one logit per expert for every token
        self.router = nn.Linear(input_dim, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with expert routing

        Args:
            x: Input tensor [batch_size, seq_len, input_dim]

        Returns:
            output: Mixed expert outputs [batch_size, seq_len, output_dim]
            router_logits: Router logits [batch_size * seq_len, num_experts],
                e.g. for ``load_balancing_loss``
        """

        batch_size, seq_len, input_dim = x.shape

        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs,
        # such as a transposed tensor.
        x_flat = x.reshape(-1, input_dim)  # [batch_size * seq_len, input_dim]

        # Compute router logits
        router_logits = self.router(x_flat)  # [batch_size * seq_len, num_experts]

        # Keep the top-k experts per token and renormalise their logits
        routing_weights, selected_experts = torch.topk(
            router_logits,
            self.top_k,
            dim=-1
        )
        routing_weights = F.softmax(routing_weights, dim=-1)

        # Accumulate in the router's compute dtype rather than the default
        # float32, so a double()/half() model does not hit a dtype mismatch
        # when expert outputs are written into this buffer.
        output = torch.zeros(
            batch_size * seq_len,
            self.experts[0].fc2.out_features,
            dtype=router_logits.dtype,
            device=x.device
        )

        # Run each expert once on every token that selected it (in any of
        # the top-k slots), then scatter-add the weighted results.
        for expert_num, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(selected_experts == expert_num)
            if token_idx.numel() == 0:
                continue

            expert_output = expert(x_flat[token_idx])
            weights = routing_weights[token_idx, slot_idx].unsqueeze(1)
            output.index_add_(0, token_idx, (weights * expert_output).to(output.dtype))

        # Reshape output
        output = output.view(batch_size, seq_len, -1)

        return output, router_logits

    def load_balancing_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Compute load balancing loss to encourage equal expert usage

        The loss is ``num_experts * CV^2`` of the mean router probability per
        expert (CV = coefficient of variation). It is 0 when every expert
        receives the same average probability.

        Args:
            router_logits: Router output [..., num_experts]; leading dims
                are flattened into a token dimension.

        Returns:
            Scalar load balancing loss. Returns 0 with fewer than two
            experts, where the variance is undefined.
        """

        logits = router_logits.reshape(-1, self.num_experts)
        if self.num_experts < 2:
            return logits.new_zeros(())

        # Mean router probability assigned to each expert
        expert_usage = F.softmax(logits, dim=-1).mean(dim=0)  # [num_experts]

        # Coefficient of variation squared
        cv_squared = expert_usage.var() / (expert_usage.mean() ** 2)

        return cv_squared * self.num_experts

# Example usage
def demonstrate_moe():
    """Build a small SimpleMoE, run one random batch, and print shapes, loss and usage."""

    dim, batch, tokens = 512, 2, 10

    moe = SimpleMoE(
        input_dim=dim,
        hidden_dim=2048,
        output_dim=dim,
        num_experts=8,
        top_k=2
    )
    x = torch.randn(batch, tokens, dim)

    output, router_logits = moe(x)
    for label, tensor in (("Input", x), ("Output", output), ("Router logits", router_logits)):
        print(f"{label} shape: {tensor.shape}")

    print(f"Load balancing loss: {moe.load_balancing_loss(router_logits).item():.4f}")

    # Mean softmax probability per expert across all tokens
    usage = F.softmax(router_logits, dim=-1).mean(dim=0)
    print("\nExpert usage distribution:")
    for idx, share in enumerate(usage.tolist()):
        print(f"  Expert {idx}: {share*100:.2f}%")


demonstrate_moe()

Advanced Router Mechanisms

Router Design: The router is critical in MoE - it determines which experts see which inputs. Different routing strategies offer tradeoffs between quality, efficiency, and load balancing.

Top-K Routing with Load Balancing

python
class AdvancedRouter(nn.Module):
    """Top-k router with optional noisy gating and an auxiliary-loss helper.

    Args:
        input_dim: Size of the last dimension of the router input.
        num_experts: Number of experts scored per token.
        top_k: Number of experts selected per token.
        noise_std: Std of the Gaussian noise added to the logits when training.
        expert_capacity_factor: Stored on the module; this router does not
            enforce a capacity limit itself.
    """

    def __init__(
        self,
        input_dim: int,
        num_experts: int,
        top_k: int = 2,
        noise_std: float = 0.1,
        expert_capacity_factor: float = 1.25
    ):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.expert_capacity_factor = expert_capacity_factor

        self.router = nn.Linear(input_dim, num_experts)

    def forward(
        self,
        x: torch.Tensor,
        training: Optional[bool] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Select the top-k experts for every input row.

        Args:
            x: Router input [..., input_dim].
            training: Whether to add exploration noise. ``None`` (default)
                follows the module's ``train()``/``eval()`` mode, so
                ``router.eval()`` disables the noise without extra arguments.

        Returns:
            routing_weights: Softmax over the selected experts' logits [..., top_k]
            selected_experts: Indices of selected experts [..., top_k]
            router_logits: Router outputs [..., num_experts], including the
                noise when it was added
        """
        if training is None:
            training = self.training

        router_logits = self.router(x)

        # Add noise during training for exploration
        if training and self.noise_std > 0:
            noise = torch.randn_like(router_logits) * self.noise_std
            router_logits = router_logits + noise

        routing_weights, selected_experts = torch.topk(
            router_logits,
            self.top_k,
            dim=-1
        )

        # Softmax over selected experts only
        routing_weights = F.softmax(routing_weights, dim=-1)

        return routing_weights, selected_experts, router_logits

    def compute_auxiliary_loss(
        self,
        router_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for load balancing

        Combines:
        1. Load balancing loss: ``num_experts * sum(importance * load)``,
           which is smallest when probability mass and top-k assignments
           are spread evenly over the experts.
        2. Router z-loss (weight 0.01): mean squared log-sum-exp of the
           logits, which penalises large logits to keep routing
           numerically stable.

        Args:
            router_logits: Router output [..., num_experts]; leading dims
                are flattened into a token dimension.

        Returns:
            Scalar auxiliary loss (0 for an empty input).
        """
        logits = router_logits.reshape(-1, self.num_experts)
        if logits.shape[0] == 0:
            return logits.new_zeros(())

        # Importance: mean router probability per expert (differentiable)
        importance = F.softmax(logits, dim=-1).mean(dim=0)

        # Load: fraction of top-k assignments per expert (not differentiable)
        _, selected = torch.topk(logits, self.top_k, dim=-1)
        load = torch.bincount(
            selected.flatten(),
            minlength=self.num_experts
        ).to(importance.dtype) / selected.numel()

        load_balancing_loss = (importance * load).sum() * self.num_experts

        # Router z-loss: penalise large logits for numerical stability
        router_z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()

        return load_balancing_loss + 0.01 * router_z_loss

class SwitchRouter(nn.Module):
    """
    Switch Transformer routing (top-1 with capacity)

    Switch routing uses only top-1 expert per token but with
    capacity constraints to prevent overloading popular experts.
    Tokens beyond an expert's capacity are dropped: their dispatch and
    combine rows are all zero. Earlier rows of ``x`` get priority.

    Args:
        input_dim: Size of the last dimension of the router input.
        num_experts: Number of experts.
        capacity_factor: Multiplier on the even share ``num_tokens / num_experts``.
        jitter_noise: Std of Gaussian noise added to the logits when training.
    """

    def __init__(
        self,
        input_dim: int,
        num_experts: int,
        capacity_factor: float = 1.25,
        jitter_noise: float = 0.01
    ):
        super().__init__()

        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.jitter_noise = jitter_noise

        self.router = nn.Linear(input_dim, num_experts, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        training: Optional[bool] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Switch routing with capacity constraints

        Args:
            x: Token representations [num_tokens, input_dim].
            training: Whether to add jitter noise. ``None`` (default)
                follows the module's ``train()``/``eval()`` mode.

        Returns:
            dispatch_tensor: 0/1 mask [num_tokens, num_experts]; 1 where the
                token is processed by that expert
            combine_tensor: Gate weights [num_tokens, num_experts]; the
                router probability of the chosen expert where dispatched,
                else 0
            aux_loss: Auxiliary loss for load balancing
        """
        if training is None:
            training = self.training

        num_tokens = x.shape[0]

        router_logits = self.router(x)

        # Add jitter noise during training
        if training and self.jitter_noise > 0:
            router_logits = router_logits + torch.randn_like(router_logits) * self.jitter_noise

        # At least one slot per expert: int() alone truncates to 0 whenever
        # capacity_factor * num_tokens < num_experts, which would drop
        # every single token.
        expert_capacity = max(
            1,
            int(self.capacity_factor * num_tokens / self.num_experts)
        )

        # Get top-1 expert per token
        router_probs = F.softmax(router_logits, dim=-1)
        expert_weights, expert_indices = torch.max(router_probs, dim=-1)

        # First-come-first-served capacity, vectorised: a token's 1-based
        # position in its expert's queue is the running count of tokens
        # routed to that expert so far (0 for experts it did not pick).
        one_hot = F.one_hot(expert_indices, num_classes=self.num_experts)
        position_in_expert = torch.cumsum(one_hot, dim=0) * one_hot
        kept = (position_in_expert > 0) & (position_in_expert <= expert_capacity)

        dispatch_tensor = kept.to(router_probs.dtype)
        combine_tensor = dispatch_tensor * expert_weights.unsqueeze(-1)

        aux_loss = self._compute_load_balancing_loss(router_probs, expert_indices)

        return dispatch_tensor, combine_tensor, aux_loss

    def _compute_load_balancing_loss(
        self,
        router_probs: torch.Tensor,
        expert_indices: torch.Tensor
    ) -> torch.Tensor:
        """Return ``num_experts * sum_i(f_i * P_i)``.

        Here ``f_i`` is the fraction of tokens whose top-1 choice is expert
        ``i`` (before capacity dropping), and ``P_i`` is the mean router
        probability for ``i``. Returns 0 for an empty input.
        """

        num_tokens = router_probs.shape[0]
        if num_tokens == 0:
            return router_probs.new_zeros(())

        # Fraction of tokens routed to each expert
        tokens_per_expert = torch.bincount(
            expert_indices,
            minlength=self.num_experts
        ).to(router_probs.dtype) / num_tokens

        # Fraction of router probability for each expert
        prob_per_expert = router_probs.mean(dim=0)

        return self.num_experts * (tokens_per_expert * prob_per_expert).sum()

# Example usage
def demonstrate_routing():
    """Compare the top-k AdvancedRouter with the top-1 SwitchRouter on random tokens."""

    input_dim, num_experts, num_tokens = 512, 8, 100
    x = torch.randn(num_tokens, input_dim)

    # --- Top-k routing ---
    print("Top-K Routing:")
    topk_router = AdvancedRouter(input_dim, num_experts, top_k=2)
    weights, experts, logits = topk_router(x, training=True)
    topk_aux = topk_router.compute_auxiliary_loss(logits)

    print(f"  Selected experts shape: {experts.shape}")
    print(f"  Routing weights shape: {weights.shape}")
    print(f"  Auxiliary loss: {topk_aux.item():.4f}")

    # --- Switch (top-1) routing ---
    print("\nSwitch Routing:")
    switch_router = SwitchRouter(input_dim, num_experts)
    dispatch, combine, switch_aux = switch_router(x, training=True)

    print(f"  Dispatch tensor shape: {dispatch.shape}")
    print(f"  Combine tensor shape: {combine.shape}")
    print(f"  Auxiliary loss: {switch_aux.item():.4f}")

    # Column sums of the dispatch mask = tokens each expert actually receives
    print("\n  Tokens per expert:")
    for idx, count in enumerate(dispatch.sum(dim=0).tolist()):
        print(f"    Expert {idx}: {count:.0f} tokens")


demonstrate_routing()

Switch Transformer Implementation

Switch Transformers: Google's Switch Transformer routes each token to a single expert (top-1 routing). This let it scale to over a trillion parameters while keeping per-token compute low through sparse activation.

python
class SwitchTransformerLayer(nn.Module):
    """
    Complete Switch Transformer layer with MoE FFN

    Pre-norm block: ``x + Attn(LN(x))``, then ``x + MoE(LN(x))``, where
    the MoE sends each token to at most one expert (top-1). Tokens dropped
    by the router's capacity limit get a zero MoE output and so pass
    through on the residual path only.
    """

    def __init__(
        self,
        d_model: int,
        num_experts: int,
        expert_hidden_dim: int,
        num_heads: int = 8,
        capacity_factor: float = 1.25,
        dropout: float = 0.1
    ):
        super().__init__()

        self.d_model = d_model
        self.num_experts = num_experts

        # Self-attention (standard, no mask)
        self.self_attn = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # MoE FFN
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, expert_hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(expert_hidden_dim, d_model)
            )
            for _ in range(num_experts)
        ])

        # Router
        self.router = SwitchRouter(
            d_model,
            num_experts,
            capacity_factor=capacity_factor
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        training: Optional[bool] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through Switch Transformer layer

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            training: Passed to the router (controls jitter noise).
                ``None`` (default) follows the module's ``train()``/``eval()``
                mode. Dropout always follows the module mode.

        Returns:
            output: Layer output [batch_size, seq_len, d_model]
            aux_loss: Auxiliary loss for load balancing
        """
        if training is None:
            training = self.training

        batch_size, seq_len, d_model = x.shape

        # Self-attention block
        residual = x
        x = self.norm1(x)
        attn_output, _ = self.self_attn(x, x, x)
        x = residual + self.dropout(attn_output)

        # MoE FFN block
        residual = x
        x = self.norm2(x)

        # Flatten for routing
        x_flat = x.reshape(-1, d_model)  # [batch_size * seq_len, d_model]

        dispatch_tensor, combine_tensor, aux_loss = self.router(
            x_flat,
            training=training
        )

        # Top-1 routing puts every token on at most one expert, so each
        # expert's weighted output is written straight into its own rows.
        # This avoids materialising a dense [tokens, d_model, num_experts]
        # buffer that would defeat the point of sparse activation.
        combined_output = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            expert_mask = dispatch_tensor[:, expert_idx] > 0
            if not expert_mask.any():
                continue

            expert_output = expert(x_flat[expert_mask])
            gate = combine_tensor[expert_mask, expert_idx].unsqueeze(-1)
            combined_output[expert_mask] = (expert_output * gate).to(combined_output.dtype)

        combined_output = combined_output.view(batch_size, seq_len, d_model)

        # Add residual
        output = residual + self.dropout(combined_output)

        return output, aux_loss

class SwitchTransformer(nn.Module):
    """Complete Switch Transformer model

    Token + learned position embeddings, a stack of ``SwitchTransformerLayer``,
    a final LayerNorm and a linear output head. The per-layer auxiliary
    losses are summed and scaled by ``aux_loss_weight``.

    Args:
        capacity_factor: Expert capacity factor forwarded to every layer.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_layers: int = 6,
        num_experts: int = 8,
        expert_hidden_dim: int = 2048,
        num_heads: int = 8,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        aux_loss_weight: float = 0.01,
        capacity_factor: float = 1.25
    ):
        super().__init__()

        self.aux_loss_weight = aux_loss_weight

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer layers
        self.layers = nn.ModuleList([
            SwitchTransformerLayer(
                d_model,
                num_experts,
                expert_hidden_dim,
                num_heads,
                capacity_factor=capacity_factor,
                dropout=dropout
            )
            for _ in range(num_layers)
        ])

        # Output head
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        training: Optional[bool] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            training: Passed to every layer's router (jitter noise).
                ``None`` (default) follows the module's ``train()``/``eval()``
                mode.

        Returns:
            logits: Output logits [batch_size, seq_len, vocab_size]
            total_aux_loss: Weighted sum of the layers' auxiliary losses

        Raises:
            IndexError: if ``seq_len`` exceeds the number of position embeddings.
        """
        if training is None:
            training = self.training

        _, seq_len = input_ids.shape

        max_seq_len = self.position_embedding.num_embeddings
        if seq_len > max_seq_len:
            raise IndexError(
                f"seq_len={seq_len} exceeds max_seq_len={max_seq_len} "
                "supported by the learned position embeddings"
            )

        # Create embeddings; position embeddings broadcast over the batch
        token_embeds = self.token_embedding(input_ids)
        positions = torch.arange(seq_len, device=input_ids.device)
        position_embeds = self.position_embedding(positions)

        x = token_embeds + position_embeds

        # Start from a tensor so the result is a tensor even with zero layers
        total_aux_loss = torch.zeros((), device=input_ids.device)

        for layer in self.layers:
            x, aux_loss = layer(x, training=training)
            total_aux_loss = total_aux_loss + aux_loss

        # Output projection
        x = self.norm(x)
        logits = self.output(x)

        return logits, total_aux_loss * self.aux_loss_weight

# Example usage
def train_switch_transformer():
    """Run one forward pass of a small SwitchTransformer and print its losses."""

    vocab_size, batch_size, seq_len = 10000, 4, 32

    model = SwitchTransformer(
        vocab_size=vocab_size,
        d_model=512,
        num_layers=4,
        num_experts=8,
        expert_hidden_dim=2048
    )

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    logits, aux_loss = model(input_ids, training=True)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {logits.shape}")
    print(f"Auxiliary loss: {aux_loss.item():.4f}")

    # Random placeholder targets; the total objective adds the aux loss
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    flat_logits = logits.reshape(-1, logits.size(-1))
    ce_loss = F.cross_entropy(flat_logits, targets.reshape(-1))
    total_loss = ce_loss + aux_loss

    print(f"Cross-entropy loss: {ce_loss.item():.4f}")
    print(f"Total loss: {total_loss.item():.4f}")


train_switch_transformer()

Expert Specialization Analysis

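Expert Specialization: In trained MoE models, routing is often correlated with properties of the input. That specialization is frequently syntactic or token-level rather than cleanly by topic such as math or code; the Mixtral authors, for example, report no obvious domain-specific experts. Measuring it helps with debugging and optimization. The model in the example below is untrained, so its routing only illustrates the analysis tools.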

python
class ExpertAnalyzer:
    """Analyze expert behavior and specialization

    Expects a model exposing ``model.layers``, where each layer has a
    ``router`` submodule returning ``(dispatch_tensor, combine_tensor, aux_loss)``
    and a ``num_experts`` attribute. The model's forward must accept
    ``training=``.
    """

    def __init__(self, model: "SwitchTransformer"):
        self.model = model
        self.num_experts = model.layers[0].num_experts
        # Per-expert buckets; not populated by the methods below.
        self.expert_activations = {i: [] for i in range(self.num_experts)}
        self.expert_inputs = {i: [] for i in range(self.num_experts)}

    def _capture_dispatch(self, input_ids: torch.Tensor) -> List[torch.Tensor]:
        """Run one no-grad eval-mode forward and return each router's dispatch tensor, in call order."""
        routing_decisions = []

        def routing_hook(module, inputs, output):
            dispatch_tensor, _, _ = output
            routing_decisions.append(dispatch_tensor)

        hooks = [
            layer.router.register_forward_hook(routing_hook)
            for layer in self.model.layers
        ]
        # eval() disables dropout so the analysis is deterministic; the
        # caller's train/eval mode is restored afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                self.model(input_ids, training=False)
        finally:
            # Always detach hooks, even if the forward pass raises
            for hook in hooks:
                hook.remove()
            self.model.train(was_training)

        return routing_decisions

    def analyze_routing(
        self,
        input_ids: torch.Tensor,
        input_texts: List[str]
    ):
        """Print, per layer, how many tokens each expert received.

        Args:
            input_ids: Token IDs passed to the model's forward.
            input_texts: Human-readable texts for the batch; currently unused.
        """
        routing_decisions = self._capture_dispatch(input_ids)

        for layer_idx, dispatch in enumerate(routing_decisions):
            print(f"\nLayer {layer_idx} routing:")

            # For SwitchTransformer, dispatch is [batch_size * seq_len, num_experts]
            tokens_per_expert = dispatch.sum(dim=0)

            for expert_idx, count in enumerate(tokens_per_expert):
                if count > 0:
                    print(f"  Expert {expert_idx}: {count.item():.0f} tokens")

    def analyze_expert_specialization(
        self,
        dataset: List[Tuple[torch.Tensor, str, str]],
        layer_idx: int = 0
    ):
        """
        Analyze what types of inputs each expert specializes in

        Args:
            dataset: List of (input_ids, text, category) tuples; each
                ``input_ids`` is a single unbatched sequence.
            layer_idx: Which layer's routing to inspect (default: first).

        Returns:
            Dict mapping expert index to the list of categories of the
            examples that dispatched at least one token to that expert.
        """
        from collections import Counter

        num_experts = self.model.layers[layer_idx].num_experts
        expert_categories = {i: [] for i in range(num_experts)}

        for input_ids, _text, category in dataset:
            dispatch = self._capture_dispatch(input_ids.unsqueeze(0))[layer_idx]
            active_experts = dispatch.sum(dim=0) > 0

            for expert_idx in torch.nonzero(active_experts).flatten().tolist():
                expert_categories[expert_idx].append(category)

        print("\nExpert Specialization Analysis:")
        for expert_idx, categories in expert_categories.items():
            if not categories:
                continue

            category_counts = Counter(categories)

            print(f"\nExpert {expert_idx}:")
            for cat, count in category_counts.most_common(3):
                pct = count / len(categories) * 100
                print(f"  {cat}: {pct:.1f}%")

        return expert_categories

    def visualize_expert_usage(self, routing_matrix: torch.Tensor):
        """Plot a bar chart of tokens per expert and return the counts as a NumPy array.

        Args:
            routing_matrix: Dispatch-style matrix [num_tokens, num_experts].
        """

        import matplotlib.pyplot as plt

        expert_usage = routing_matrix.sum(dim=0).cpu().numpy()

        plt.figure(figsize=(10, 6))
        plt.bar(range(len(expert_usage)), expert_usage)
        plt.xlabel('Expert Index')
        plt.ylabel('Number of Tokens')
        plt.title('Expert Usage Distribution')
        plt.show()

        return expert_usage

# Example usage
model = SwitchTransformer(vocab_size=10000, num_experts=8)
analyzer = ExpertAnalyzer(model=model)

# Analyze routing for a random batch of 2 sequences x 20 tokens
input_ids = torch.randint(low=0, high=10000, size=(2, 20))
texts = [f"Sample text {n}" for n in (1, 2)]
analyzer.analyze_routing(input_ids=input_ids, input_texts=texts)

Production Optimization

python
class OptimizedMoE(nn.Module):
    """Top-k MoE whose experts are stored as stacked weight tensors.

    Expert ``e`` computes ``gelu(x @ expert_weights_1[e]) @ expert_weights_2[e]``
    (no biases). Tokens are grouped per expert so each expert runs one dense
    matmul over all of its tokens. This avoids gathering a per-token copy of
    the weight matrices, which costs O(tokens * input_dim * hidden_dim) memory.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Expert hidden width.
        output_dim: Size of the last dimension of the output.
        num_experts: Number of experts.
        top_k: Experts selected per token.
        use_kernel_fusion: Accepted for API compatibility; currently unused.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        use_kernel_fusion: bool = True
    ):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.use_kernel_fusion = use_kernel_fusion

        # Scale by 1/sqrt(fan_in) so activations keep roughly unit variance;
        # unscaled randn (std 1) makes outputs grow like sqrt(in * hidden).
        self.expert_weights_1 = nn.Parameter(
            torch.randn(num_experts, input_dim, hidden_dim) * input_dim ** -0.5
        )
        self.expert_weights_2 = nn.Parameter(
            torch.randn(num_experts, hidden_dim, output_dim) * hidden_dim ** -0.5
        )

        self.router = nn.Linear(input_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and mix their outputs.

        Args:
            x: Input [batch_size, seq_len, input_dim].

        Returns:
            Output [batch_size, seq_len, output_dim].
        """

        batch_size, seq_len, input_dim = x.shape

        # reshape also handles non-contiguous inputs
        x_flat = x.reshape(-1, input_dim)
        router_logits = self.router(x_flat)

        routing_weights, selected_experts = torch.topk(
            router_logits,
            self.top_k,
            dim=-1
        )
        routing_weights = F.softmax(routing_weights, dim=-1)

        output = torch.zeros(
            x_flat.shape[0],
            self.expert_weights_2.shape[-1],
            dtype=router_logits.dtype,
            device=x.device
        )

        # One dense matmul per expert over all tokens (any top-k slot) routed to it
        for expert_id in range(self.num_experts):
            token_idx, slot_idx = torch.where(selected_experts == expert_id)
            if token_idx.numel() == 0:
                continue

            hidden = F.gelu(x_flat[token_idx] @ self.expert_weights_1[expert_id])
            expert_out = hidden @ self.expert_weights_2[expert_id]

            weights = routing_weights[token_idx, slot_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, (expert_out * weights).to(output.dtype))

        return output.view(batch_size, seq_len, -1)

# Benchmark
def benchmark_moe(num_iters: int = 100, warmup: int = 10) -> float:
    """Benchmark the average OptimizedMoE forward latency.

    Times inference-style forward passes (``torch.no_grad``), so no autograd
    graph is built. Uses ``time.perf_counter`` for a monotonic,
    high-resolution clock.

    Args:
        num_iters: Number of timed forward passes.
        warmup: Untimed forward passes run first.

    Returns:
        Average milliseconds per forward pass (also printed).

    Raises:
        ValueError: if ``num_iters`` is not positive.
    """
    import time

    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    model = OptimizedMoE(512, 2048, 512, num_experts=8)
    model.eval()

    x = torch.randn(4, 100, 512)

    with torch.no_grad():
        for _ in range(warmup):
            model(x)

        start = time.perf_counter()
        for _ in range(num_iters):
            model(x)
        elapsed = time.perf_counter() - start

    avg_ms = elapsed / num_iters * 1000
    print(f"Average time per forward pass: {avg_ms:.2f}ms")
    return avg_ms

benchmark_moe()

Quiz

Test your understanding of Mixture of Experts:

Summary

In this lesson, you learned:

  • MoE fundamentals: Sparse activation and expert routing
  • Router mechanisms: Top-k routing, load balancing, and capacity constraints
  • Switch Transformers: Google's efficient top-1 MoE architecture
  • Expert specialization: How to measure which inputs each expert handles
  • Production optimization: Efficient MoE implementation techniques

Mixture of Experts is a key technique enabling the next generation of massive language models while keeping inference costs manageable.