Advanced Transformer Concepts

Paper: Scaling Laws for Neural Language Models

Deep dive into the OpenAI paper that revealed how language model performance scales predictably with model size, dataset size, and compute.

25 min read · Paper · Scaling Laws · Research · Training


Scaling Laws for Neural Language Models

Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, Dario Amodei


In January 2020, OpenAI published research that transformed how we think about training large language models. Instead of guesswork, they provided mathematical laws predicting how performance scales with model size, data, and compute.

Motivation and Context

The Questions

Before this paper, LLM development faced key uncertainties:

  1. How big should we make our model?
  2. How much data do we need?
  3. How much will it cost to achieve X performance?
  4. Should we train a bigger model or a smaller one for longer?

The paper's answer: All of these follow predictable power laws.
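Question 3 shows why this matters: a power law can be inverted to estimate the compute needed for a target loss. A minimal sketch, using the paper's fitted compute constants (Cc ≈ 3.1 × 10^8 PF-days, αC ≈ 0.050); the function name is ours, and the toy numbers only illustrate the shape of the tradeoff:

```python
def compute_for_target_loss(target_loss, Cc=3.1e8, alpha_C=0.050):
    """Invert L(C) = (Cc / C)^alpha_C to get required compute in PF-days."""
    return Cc * target_loss ** (-1.0 / alpha_C)

# Small loss improvements demand enormous compute increases:
c_at_3 = compute_for_target_loss(3.0)
c_at_2 = compute_for_target_loss(2.0)
print(f"Compute for loss 3.0: {c_at_3:.2e} PF-days")
print(f"Compute for loss 2.0: {c_at_2:.2e} PF-days")
print(f"Ratio: {c_at_2 / c_at_3:.0f}x")
```

With αC = 0.050, going from loss 3.0 to 2.0 multiplies the required compute by (3/2)^20, i.e. over 3000×: the small exponent is exactly why scaling is so expensive.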

Power Laws in Nature:

Power laws appear throughout nature and science:

  • City populations
  • Earthquake magnitudes
  • Word frequencies (Zipf's law)

The paper showed LLM performance is another instance of this fundamental pattern.

Experimental Setup

Models Tested

The authors trained hundreds of models varying in size:

  • Smallest: ~768 non-embedding parameters
  • Largest: 1.5 billion parameters (GPT-2 scale)
  • Architecture: Transformer decoder (GPT-style)
  • Total: over 300 training runs

Training Details

python
"""
Paper's experimental configuration:

Architectures tested:
- n_layers: 2 to 64
- d_model: 128 to 1024
- n_heads: 1 to 128
- Batch size: 256 to 524,288 tokens

Dataset:
- WebText2 (OpenAI's web scrape)
- Filtered and deduplicated
- Diverse range of internet text

Compute:
- Total: ~2.5 × 10^23 FLOPs
- Hardware: V100 GPUs
"""

# Example: Recreate their model configurations
def paper_model_config(scale='small'):
    """Model configurations from the paper."""
    configs = {
        'tiny': {
            'n_layers': 4,
            'd_model': 128,
            'n_heads': 2,
            'd_ff': 512,
            'params': 0.8e6,  # ~1M
        },
        'small': {
            'n_layers': 12,
            'd_model': 768,
            'n_heads': 12,
            'd_ff': 3072,
            'params': 117e6,  # ~117M
        },
        'medium': {
            'n_layers': 24,
            'd_model': 1024,
            'n_heads': 16,
            'd_ff': 4096,
            'params': 350e6,  # ~350M
        },
        'large': {
            'n_layers': 36,
            'd_model': 1280,
            'n_heads': 20,
            'd_ff': 5120,
            'params': 762e6,  # ~762M
        },
        'xl': {
            'n_layers': 48,
            'd_model': 1600,
            'n_heads': 25,
            'd_ff': 6400,
            'params': 1.5e9,  # ~1.5B
        }
    }
    return configs[scale]


for scale in ['tiny', 'small', 'medium', 'large', 'xl']:
    config = paper_model_config(scale)
    print(f"{scale.upper():8s}: {config['params']/1e6:6.0f}M params, "
          f"{config['n_layers']} layers, d_model={config['d_model']}")

Key Finding 1: Power Law Scaling

The Core Result

Loss scales as a power law with each of three factors:

python
import numpy as np
import matplotlib.pyplot as plt

class ScalingLawsFromPaper:
    """
    Implement the exact formulas from the paper.
    """

    def __init__(self):
        # Fitted constants (Table 1 from paper)
        self.Nc = 8.8e13      # Critical model size (non-embedding parameters)
        self.Dc = 5.4e13      # Critical dataset size (tokens)
        self.Cc = 3.1e8       # Critical compute (PF-days)

        # Power law exponents
        self.alpha_N = 0.076  # Model size
        self.alpha_D = 0.095  # Dataset size
        self.alpha_C = 0.050  # Compute

    def L_N(self, N):
        """
        Loss as a function of model size (parameters).

        Formula from paper:
        L(N) = (Nc / N)^αN for N >> Nc

        In practice, includes a constant floor:
        L(N) ≈ (Nc / N)^0.076 + L_floor
        """
        if N < 1e3:
            return float('inf')

        return (self.Nc / N) ** self.alpha_N

    def L_D(self, D):
        """
        Loss as a function of dataset size (tokens).

        Formula: L(D) = (Dc / D)^αD
        """
        if D < 1e3:
            return float('inf')

        return (self.Dc / D) ** self.alpha_D

    def L_C(self, C):
        """
        Loss as a function of compute budget, C in FLOPs.

        Formula: L(C) = (Cc / C)^αC

        The paper measures compute in PF-days, so convert FLOPs
        to PF-days before applying the fitted constants.
        """
        if C < 1e6:
            return float('inf')

        # 1 PF-day = 10^15 FLOP/s × 86,400 s ≈ 8.64e19 FLOPs
        C_pf_days = C / 8.64e19
        return (self.Cc / C_pf_days) ** self.alpha_C


# Reproduce Figure 1 from the paper
scaling = ScalingLawsFromPaper()

# Parameter sweep
N_range = np.logspace(3, 12, 100)
losses_N = [scaling.L_N(N) for N in N_range]

# Data sweep
D_range = np.logspace(3, 12, 100)
losses_D = [scaling.L_D(D) for D in D_range]

# Compute sweep
C_range = np.logspace(12, 24, 100)
losses_C = [scaling.L_C(C) for C in C_range]

# Plot (log-log scale shows power law as straight line)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Model size
axes[0].loglog(N_range, losses_N, 'b-', linewidth=2)
axes[0].set_xlabel('Parameters (N)', fontsize=12)
axes[0].set_ylabel('Test Loss', fontsize=12)
axes[0].set_title('Scaling with Model Size\nL(N) ∝ N^-0.076', fontsize=14)
axes[0].grid(True, alpha=0.3)
axes[0].axvline(117e6, color='r', linestyle='--', alpha=0.5, label='GPT-2 Small')
axes[0].axvline(1.5e9, color='g', linestyle='--', alpha=0.5, label='GPT-2 XL')
axes[0].legend()

# Plot 2: Dataset size
axes[1].loglog(D_range, losses_D, 'g-', linewidth=2)
axes[1].set_xlabel('Dataset Size (tokens)', fontsize=12)
axes[1].set_ylabel('Test Loss', fontsize=12)
axes[1].set_title('Scaling with Data\nL(D) ∝ D^-0.095', fontsize=14)
axes[1].grid(True, alpha=0.3)

# Plot 3: Compute
axes[2].loglog(C_range, losses_C, 'r-', linewidth=2)
axes[2].set_xlabel('Compute (FLOPs)', fontsize=12)
axes[2].set_ylabel('Test Loss', fontsize=12)
axes[2].set_title('Scaling with Compute\nL(C) ∝ C^-0.050', fontsize=14)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('scaling_laws_figure1.png', dpi=150, bbox_inches='tight')
plt.show()

print("Power law exponents:")
print(f"  Model size (N): α = {scaling.alpha_N}")
print(f"  Data size (D):  α = {scaling.alpha_D}")
print(f"  Compute (C):    α = {scaling.alpha_C}")

Reading Log-Log Plots:

On a log-log plot, a power law appears as a straight line. The slope of the line is the exponent α. Steeper slopes mean faster scaling.
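That straight-line behavior is easy to verify numerically: generate losses from the paper's L(N) form, fit a line in log-log space with `np.polyfit`, and the slope recovers −αN:

```python
import numpy as np

# Synthetic losses from L(N) = (Nc / N)^alpha_N, constants from the paper
Nc, alpha_N = 8.8e13, 0.076
N = np.logspace(6, 10, 50)
L = (Nc / N) ** alpha_N

# In log-log space the power law is a line; its slope is -alpha_N
slope, intercept = np.polyfit(np.log10(N), np.log10(L), deg=1)
print(f"Fitted slope: {slope:.3f}  (expected {-alpha_N})")
```

The same fit applied to real training-run losses is how the paper extracted its exponents; with noisy data the recovered slope is approximate rather than exact.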

Key Finding 2: No Overfitting Plateau

Surprising Result

Models don't overfit in the traditional sense, even when trained far beyond "convergence."

python
def demonstrate_no_overfitting():
    """
    Show that test loss continues to improve with more training,
    unlike traditional overfitting curves.
    """
    # Traditional overfitting (NOT what the paper found)
    epochs_traditional = np.linspace(0, 100, 100)
    train_loss_traditional = 3.0 * np.exp(-epochs_traditional / 20)
    test_loss_traditional = 3.0 * np.exp(-epochs_traditional / 20) + \
                           0.5 * (epochs_traditional / 100) ** 2

    # LLM behavior (what the paper found)
    epochs_llm = np.linspace(0, 100, 100)
    train_loss_llm = 3.0 * np.exp(-epochs_llm / 30)
    test_loss_llm = 3.0 * np.exp(-epochs_llm / 30) + 0.01  # Small gap, no upturn

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Traditional overfitting
    ax1.plot(epochs_traditional, train_loss_traditional, 'b-', label='Train Loss')
    ax1.plot(epochs_traditional, test_loss_traditional, 'r-', label='Test Loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.set_title('Traditional Overfitting\n(What we expected)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # LLM behavior
    ax2.plot(epochs_llm, train_loss_llm, 'b-', label='Train Loss')
    ax2.plot(epochs_llm, test_loss_llm, 'r-', label='Test Loss')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Loss')
    ax2.set_title('LLM Behavior\n(What the paper found)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

demonstrate_no_overfitting()

Implication: You can train much longer than you think without hurting generalization.

Why No Overfitting?

The paper hypothesizes:

  1. Massive capacity: Models are underparameterized relative to data complexity
  2. Language structure: Natural language has immense complexity that models haven't saturated
  3. Implicit regularization: Gradient descent implicitly regularizes

This changed the paradigm from "stop before overfitting" to "let the compute budget, not overfitting, decide when to stop."
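The paper also quantifies when data, rather than capacity, becomes the bottleneck, combining both dependencies into a single law (its Equation 1.5): L(N, D) = [(Nc/N)^(αN/αD) + Dc/D]^αD. A sketch using the same fitted constants as above; the function name is ours:

```python
def L_ND(N, D, Nc=8.8e13, Dc=5.4e13, alpha_N=0.076, alpha_D=0.095):
    """Combined scaling law (paper Eq. 1.5):
    L(N, D) = [(Nc/N)^(aN/aD) + Dc/D]^aD
    """
    return ((Nc / N) ** (alpha_N / alpha_D) + Dc / D) ** alpha_D

# Fix the dataset and grow the model: loss bottoms out at the data term
for N in [1e7, 1e9, 1e11]:
    print(f"N={N:.0e}, D=1e9 tokens -> L = {L_ND(N, 1e9):.3f}")
```

With D held at 10^9 tokens, the first 100× increase in N buys a large loss drop, while the next 100× barely moves it: the data term now dominates, which is the overfitting regime the single-variable laws cannot express.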

Key Finding 3: Optimal Compute Allocation

The Compute Equation

For a given compute budget C, how should you allocate between model size N and data D?

Finding: There's an optimal N*(C) and D*(C) that minimizes loss.

python
def optimal_compute_allocation(C):
    """
    Given compute budget C, find optimal model size and dataset size.

    From the paper (Equation 5.1):
    N_opt(C) ∝ C^a
    D_opt(C) ∝ C^b

    where a ≈ 0.73, b ≈ 0.27

    Note: Later work (Chinchilla) revised these to a ≈ b ≈ 0.5
    """
    # Constants from the paper
    a = 0.73  # Exponent for parameters
    b = 0.27  # Exponent for data

    # FLOPs per token: C ≈ 6ND
    # Solve for N and D given C and the power law relationships

    # Simplified optimal allocation
    N_opt = (C / 6) ** a
    D_opt = C / (6 * N_opt)

    return N_opt, D_opt


# Compare different compute budgets
compute_budgets = [1e18, 1e20, 1e22, 1e24]

print("Optimal allocation per compute budget:")
print(f"{'Compute (FLOPs)':<20s} {'Params':<15s} {'Tokens':<15s} {'Ratio (T/P)':<10s}")
print("-" * 70)

for C in compute_budgets:
    N_opt, D_opt = optimal_compute_allocation(C)
    ratio = D_opt / N_opt
    print(f"{C:<20.1e} {N_opt:<15.2e} {D_opt:<15.2e} {ratio:<10.1f}")

# Key insight from table
print("\nKey insight: As compute increases, parameters should grow faster than data")
print("This was later challenged by Chinchilla (2022)!")

Key Finding 4: Convergence is Fast

Early Stopping is Optimal

Result: Most of the loss improvement comes early in training; beyond that, continued training yields diminishing returns.

python
def training_efficiency_over_time():
    """
    Show how loss improves during training (Figure 3 from paper).
    """
    # Simulate training curve
    steps = np.linspace(0, 1000, 1000)

    # Loss improves quickly then plateaus
    # L(S) ≈ L_∞ + (L_0 - L_∞) * (Sc / (S + Sc))
    L_0 = 3.5        # Initial loss
    L_inf = 2.2      # Final loss
    Sc = 100         # Critical step count

    loss = L_inf + (L_0 - L_inf) * (Sc / (steps + Sc))

    plt.figure(figsize=(10, 6))
    plt.plot(steps, loss, 'b-', linewidth=2)
    plt.axhline(L_inf, color='r', linestyle='--', alpha=0.5, label='Final Loss')
    plt.axvline(Sc, color='g', linestyle='--', alpha=0.5, label='Critical Steps')
    plt.xlabel('Training Steps', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Training Dynamics: Most Progress is Early', fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    # Calculate how much improvement at different points
    def loss_at_step(S):
        return L_inf + (L_0 - L_inf) * (Sc / (S + Sc))

    total_improvement = L_0 - L_inf
    improvement_at_100 = L_0 - loss_at_step(100)
    improvement_at_1000 = L_0 - loss_at_step(1000)

    print(f"Total possible improvement: {total_improvement:.2f}")
    print(f"Improvement by step 100: {improvement_at_100:.2f} ({improvement_at_100/total_improvement*100:.1f}%)")
    print(f"Improvement by step 1000: {improvement_at_1000:.2f} ({improvement_at_1000/total_improvement*100:.1f}%)")

training_efficiency_over_time()

Training Implication:

The paper suggests training compute-efficiently: don't train beyond the point where loss improvements slow dramatically. This informed the "one epoch" training common in modern LLMs.
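The paper fits this tradeoff directly with a joint law for loss as a function of model size and optimization steps (its Equation 1.6): L(N, S) = (Nc/N)^αN + (Sc/S)^αS, with αS ≈ 0.76 and Sc ≈ 2.1 × 10^3 (steps taken at the critical batch size). A hedged sketch; the constants come from the paper's summary of fits and should be treated as approximate:

```python
def L_NS(N, S, Nc=8.8e13, Sc=2.1e3, alpha_N=0.076, alpha_S=0.76):
    """Loss vs. model size N and training steps S (paper Eq. 1.6)."""
    return (Nc / N) ** alpha_N + (Sc / S) ** alpha_S

# For a fixed ~100M-param model, each 10x in steps buys less and less
for S in [1e3, 1e4, 1e5]:
    print(f"S={S:.0e}: L = {L_NS(1e8, S):.3f}")
```

The step term decays with a much larger exponent (0.76) than the model-size term (0.076), which is why loss curves flatten quickly: once (Sc/S)^αS is small, only growing N helps.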

Key Finding 5: Shape Doesn't Matter Much

Width vs. Depth

Result: Model shape (depth vs. width) matters much less than total parameter count.

python
def compare_shapes(total_params=125e6):
    """
    Compare different model shapes with same parameter count.

    The paper showed these perform nearly identically.
    """
    shapes = [
        {'name': 'Wide & Shallow', 'layers': 6, 'd_model': 1280},
        {'name': 'Balanced', 'layers': 12, 'd_model': 896},
        {'name': 'Deep & Narrow', 'layers': 24, 'd_model': 640},
    ]

    scaling = ScalingLawsFromPaper()

    print(f"All models ~{total_params/1e6:.0f}M parameters:\n")
    print(f"{'Shape':<20s} {'Layers':<10s} {'d_model':<10s} {'Predicted Loss':<15s}")
    print("-" * 65)

    for shape in shapes:
        # All have same N, so same predicted loss
        predicted_loss = scaling.L_N(total_params)
        print(f"{shape['name']:<20s} {shape['layers']:<10d} {shape['d_model']:<10d} {predicted_loss:<15.4f}")

    print("\nConclusion: Shape matters little - total parameters dominate!")

compare_shapes()

Impact and Legacy

What Changed

Before the paper:

  • Model development was mostly guesswork
  • Unclear whether to scale depth, width, or data
  • No way to predict performance without training

After the paper:

  • Systematic approach to model design
  • Predictable performance scaling
  • Informed GPT-3, Gopher, and other large models
  • Later refined by Chinchilla (2022)

The GPT-3 Connection

python
def predict_gpt3_performance():
    """
    The scaling laws paper directly informed GPT-3's design.

    GPT-3 paper (2020) cited scaling laws to justify the 175B model.
    """
    scaling = ScalingLawsFromPaper()

    # GPT-3 configuration
    gpt3_params = 175e9
    gpt3_tokens = 300e9
    gpt3_compute = 6 * gpt3_params * gpt3_tokens

    # Predictions
    loss_from_params = scaling.L_N(gpt3_params)
    loss_from_data = scaling.L_D(gpt3_tokens)
    loss_from_compute = scaling.L_C(gpt3_compute)

    print("GPT-3 Scaling Laws Predictions:")
    print(f"  Parameters: {gpt3_params/1e9:.0f}B")
    print(f"  Tokens: {gpt3_tokens/1e9:.0f}B")
    print(f"  Compute: {gpt3_compute:.2e} FLOPs\n")

    print("Predicted Loss:")
    print(f"  From parameters: {loss_from_params:.4f}")
    print(f"  From data: {loss_from_data:.4f}")
    print(f"  From compute: {loss_from_compute:.4f}")

    print("\nNote: These predictions are approximate.")
    print("Actual GPT-3 performance was better due to improved architecture.")

predict_gpt3_performance()

Historical Context:

This paper enabled the "scaling is all you need" era of 2020-2022, leading to GPT-3, Gopher, Megatron-Turing NLG, and other massive models. It made training 100B+ parameter models scientifically justified rather than a gamble.

Limitations and Critiques

What the Paper Got Wrong

1. Optimal compute allocation was revised: The paper suggested N should scale faster than D (exponents 0.73 vs 0.27). Chinchilla (2022) showed they should scale equally (both ~0.5).

2. Architecture improvements: The power laws assume fixed architecture. Improvements like Flash Attention, better activations, etc. can shift the curves.

3. Emergent abilities: The smooth power laws don't predict discrete capability jumps that appear at certain scales.

python
# The Chinchilla correction
def compare_allocations(compute_budget):
    """Compare original scaling laws vs. Chinchilla allocation."""

    # Original paper
    N_orig = (compute_budget / 6) ** 0.73
    D_orig = compute_budget / (6 * N_orig)

    # Chinchilla
    N_chinch = (compute_budget / 6) ** 0.50
    D_chinch = compute_budget / (6 * N_chinch)

    print(f"For compute budget {compute_budget:.2e} FLOPs:\n")
    print(f"{'Method':<20s} {'Parameters':<20s} {'Tokens':<20s} {'Token/Param':<10s}")
    print("-" * 75)
    print(f"{'Original (2020)':<20s} {N_orig:<20.2e} {D_orig:<20.2e} {D_orig/N_orig:<10.1f}")
    print(f"{'Chinchilla (2022)':<20s} {N_chinch:<20.2e} {D_chinch:<20.2e} {D_chinch/N_chinch:<10.1f}")

    print("\nChinchilla uses smaller models trained longer!")

compare_allocations(3.14e23)  # GPT-3 compute budget

Summary

The "Scaling Laws for Neural Language Models" paper established that:

  1. Performance scales predictably with parameters, data, and compute as power laws
  2. No overfitting plateau - models can train much longer than expected
  3. Optimal allocation exists for compute budgets (later refined by Chinchilla)
  4. Convergence is fast - most improvement comes early in training
  5. Shape matters little - total parameters dominate over architecture choices

Legacy:

  • Enabled the scaling era (GPT-3, Gopher, etc.)
  • Provided scientific foundation for $10M+ training runs
  • Later refined but core insights remain valid

Quote from the paper:

"Larger models are significantly more sample-efficient, such that optimally compute-efficient training involves training very large models on a relatively modest amount of data and stopping significantly before convergence."

This insight shaped modern LLM development, even as later work (Chinchilla) refined the exact ratios.