
FlashAttention: Fast and Memory-Efficient Attention

Understand how FlashAttention achieves 2-4x speedup and massive memory savings through IO-aware algorithm design and tiling techniques.



FlashAttention revolutionized transformer training by making attention computation 2-4x faster while using dramatically less memory. It achieves this not by approximating attention (the output is exactly the same), but by being IO-aware: optimizing how data moves through the GPU memory hierarchy.

The Memory Bottleneck

Standard Attention is Memory-Bound

python
import torch
import torch.nn.functional as F
import math

def standard_attention(Q, K, V, mask=None):
    """
    Standard attention implementation.

    Args:
        Q, K, V: (batch, heads, seq_len, d_k)

    Returns:
        output: (batch, heads, seq_len, d_k)
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    # Shape: (batch, heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 2: Apply softmax
    # Shape: (batch, heads, seq_len, seq_len)
    attn_weights = F.softmax(scores, dim=-1)

    # Step 3: Apply attention to values
    # Shape: (batch, heads, seq_len, d_k)
    output = torch.matmul(attn_weights, V)

    return output


# Analyze memory usage
batch, heads, seq_len, d_k = 8, 32, 2048, 64

Q = torch.randn(batch, heads, seq_len, d_k, device='cuda')
K = torch.randn(batch, heads, seq_len, d_k, device='cuda')
V = torch.randn(batch, heads, seq_len, d_k, device='cuda')

# Memory required for attention matrix
attn_matrix_size = batch * heads * seq_len * seq_len * 4  # 4 bytes per float32
print(f"Attention matrix memory: {attn_matrix_size / 1e9:.2f} GB")
print(f"For seq_len=2048: {attn_matrix_size / 1e9:.2f} GB")
print(f"For seq_len=4096: {batch * heads * 4096 * 4096 * 4 / 1e9:.2f} GB")
print(f"For seq_len=8192: {batch * heads * 8192 * 8192 * 4 / 1e9:.2f} GB")

The problem: the attention matrix is O(n²) in sequence length. With batch 8 and 32 heads, the score matrices at n=8192 already need roughly 69 GB!

GPU Memory Hierarchy:

  • HBM (High Bandwidth Memory): 40-80 GB, slow (~1.5 TB/s)
  • SRAM (On-chip memory): ~20 MB, fast (~19 TB/s)

Standard attention stores the full attention matrix in HBM. Reading/writing this matrix is the bottleneck, not the computation itself.
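
To make "memory-bound" concrete, here is a rough back-of-the-envelope estimate comparing the time to move the score matrices through HBM against the time to compute them. The bandwidth and throughput figures below are approximate A100 numbers, used only for illustration:

python
# Back-of-the-envelope: time to move the score matrix vs time to compute it.
# Hardware numbers are rough A100 figures (assumptions, for illustration only).
batch, heads, seq_len, d_k = 8, 32, 2048, 64

n_scores = batch * heads * seq_len * seq_len   # elements in all score matrices
# Scores are written once, read+written by softmax, then read by the PV matmul
hbm_bytes = 4 * n_scores * 4                   # ~4 passes over the matrix, fp32

matmul_flops = 2 * (2 * batch * heads * seq_len**2 * d_k)  # QK^T and PV matmuls

hbm_bandwidth = 1.5e12       # ~1.5 TB/s HBM bandwidth (assumed)
tensor_core_flops = 156e12   # ~156 TFLOP/s TF32 tensor-core peak (assumed)

print(f"HBM traffic for scores: {hbm_bytes / 1e9:.1f} GB "
      f"-> ~{hbm_bytes / hbm_bandwidth * 1e3:.1f} ms to move")
print(f"Matmul compute: {matmul_flops / 1e12:.2f} TFLOP "
      f"-> ~{matmul_flops / tensor_core_flops * 1e3:.1f} ms to compute")

Under these assumptions, moving the score matrices takes several times longer than computing them: the GPU spends most of its time waiting on memory.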

FlashAttention's Key Insight

IO-Aware Algorithm Design

Instead of:

  1. Compute full attention matrix in HBM
  2. Apply softmax (requires full matrix)
  3. Multiply with values

FlashAttention:

  1. Tile the computation into blocks that fit in SRAM
  2. Recompute attention scores on the fly (instead of storing)
  3. Fuse operations to minimize memory reads/writes

The Tradeoff:

FlashAttention does more computation (it recomputes scores in the backward pass) but far less IO. Since attention at these sizes is limited by memory bandwidth rather than arithmetic, this is a huge win.

  • Compute: O(n² d) → same
  • Memory: O(n²) → O(n) ✓
  • IO: O(n²) → O(n²d²/M), where M = SRAM size ✓
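
Plugging concrete numbers into these expressions shows the scale of the IO reduction. The SRAM working-set size M below is an assumed figure, chosen only for illustration:

python
# Illustrative HBM-access counts (in matrix elements), following the
# asymptotic expressions above; M is an assumed usable SRAM working set.
n, d = 4096, 64
M = 100_000  # elements

standard_io = n * n                   # materialize/read the full score matrix
flash_io = (n ** 2) * (d ** 2) // M   # O(n^2 d^2 / M) HBM accesses

print(f"Standard attention IO: ~{standard_io:,} elements")
print(f"FlashAttention IO:     ~{flash_io:,} elements")
print(f"Reduction:             ~{standard_io / flash_io:.0f}x")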

Tiling Strategy

Blockwise Computation

python
def tiled_attention_conceptual(Q, K, V, block_size=64):
    """
    Conceptual tiled attention (simplified).

    This is NOT the actual FlashAttention algorithm, but shows the idea
    of processing in blocks.

    Args:
        Q, K, V: (seq_len, d_k)
        block_size: Size of blocks to process

    Returns:
        output: (seq_len, d_k)
    """
    seq_len, d_k = Q.shape
    output = torch.zeros_like(Q)

    # Process in blocks
    for i in range(0, seq_len, block_size):
        for j in range(0, seq_len, block_size):
            # Get blocks
            Q_block = Q[i:i+block_size]  # (block_size, d_k)
            K_block = K[j:j+block_size]  # (block_size, d_k)
            V_block = V[j:j+block_size]  # (block_size, d_k)

            # Compute attention for this block
            scores = torch.matmul(Q_block, K_block.T) / math.sqrt(d_k)
            # scores: (block_size, block_size) - fits in SRAM!

            attn = F.softmax(scores, dim=-1)
            output_block = torch.matmul(attn, V_block)

            # Accumulate (in reality, needs careful normalization)
            output[i:i+block_size] += output_block

    return output


# Show memory savings
seq_len = 2048
d_k = 64
block_size = 64

# Standard attention
full_matrix_memory = seq_len * seq_len * 4
print(f"Standard attention matrix: {full_matrix_memory / 1e6:.2f} MB")

# Tiled attention
block_matrix_memory = block_size * block_size * 4
print(f"Tiled attention block: {block_matrix_memory / 1e3:.2f} KB")
print(f"Memory reduction: {full_matrix_memory / block_matrix_memory:.0f}x")

Online Softmax

The challenge: Softmax requires the full row to normalize. FlashAttention solves this with online softmax:

python
def online_softmax_demo():
    """
    Demonstrate online softmax computation.

    Key idea: Update running max and sum as we process blocks.
    """
    # Example: compute softmax([1, 2, 3, 4]) in two blocks

    # Block 1: [1, 2]
    block1 = torch.tensor([1.0, 2.0])
    m1 = block1.max()  # Running max
    exp1 = torch.exp(block1 - m1)
    sum1 = exp1.sum()  # Running sum

    print("After block 1:")
    print(f"  Max: {m1:.4f}")
    print(f"  Sum: {sum1:.4f}")
    print(f"  Partial softmax: {exp1 / sum1}")

    # Block 2: [3, 4]
    block2 = torch.tensor([3.0, 4.0])
    m2 = block2.max()

    # Update running statistics
    m_new = max(m1, m2)
    exp1_corrected = exp1 * torch.exp(m1 - m_new)  # Correct previous block
    exp2 = torch.exp(block2 - m_new)

    sum_new = exp1_corrected.sum() + exp2.sum()

    print("\nAfter block 2:")
    print(f"  Max: {m_new:.4f}")
    print(f"  Sum: {sum_new:.4f}")

    # Final softmax
    softmax1 = exp1_corrected / sum_new
    softmax2 = exp2 / sum_new

    print(f"\nFinal softmax:")
    print(f"  Block 1: {softmax1}")
    print(f"  Block 2: {softmax2}")

    # Verify against standard softmax
    full = torch.tensor([1.0, 2.0, 3.0, 4.0])
    standard = F.softmax(full, dim=0)
    online = torch.cat([softmax1, softmax2])

    print(f"\nStandard softmax: {standard}")
    print(f"Online softmax:   {online}")
    print(f"Match: {torch.allclose(standard, online)}")

online_softmax_demo()

Online Softmax Math:

Given partial results with max m₁ and sum s₁, when adding a new block with max m₂ and local sum s₂ (computed relative to m₂):

m_new = max(m₁, m₂)
s_new = s₁ × exp(m₁ - m_new) + s₂ × exp(m₂ - m_new)

This allows computing softmax incrementally without storing the full attention matrix!
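
The same recurrence works for any number of blocks. Below is a minimal sketch (the helper name and block size are arbitrary choices) that streams over a 1-D tensor keeping only the running max and sum. Unlike real FlashAttention, it keeps the blocks around for a second normalization pass; the kernel instead rescales an output accumulator, as shown in the forward pass below.

python
def blockwise_softmax(x, block_size=3):
    """Softmax of a 1-D tensor computed by streaming over blocks,
    carrying only a running max (m) and running sum (s)."""
    m = torch.tensor(float('-inf'))
    s = torch.tensor(0.0)
    blocks = []
    for start in range(0, x.numel(), block_size):
        b = x[start:start + block_size]
        m_new = torch.maximum(m, b.max())
        # Rescale the old sum to the new max, then add this block's contribution
        s = s * torch.exp(m - m_new) + torch.exp(b - m_new).sum()
        m = m_new
        blocks.append(b)
    # Second pass: normalize each block with the final statistics
    return torch.cat([torch.exp(b - m) / s for b in blocks])

x = torch.randn(10)
print(torch.allclose(blockwise_softmax(x), F.softmax(x, dim=0)))  # True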

FlashAttention Algorithm

Forward Pass

python
def flash_attention_forward_simplified(Q, K, V, block_size=64):
    """
    Simplified FlashAttention forward pass.

    This is a conceptual implementation showing the key ideas.
    The actual implementation uses CUDA and is more optimized.

    Args:
        Q, K, V: (seq_len, d_k)
        block_size: Block size for tiling

    Returns:
        output: (seq_len, d_k)
    """
    seq_len, d_k = Q.shape
    num_blocks = (seq_len + block_size - 1) // block_size

    # Initialize output and statistics
    O = torch.zeros_like(Q)
    l = torch.zeros(seq_len, device=Q.device)  # Running softmax denominator
    m = torch.full((seq_len,), float('-inf'), device=Q.device)  # Running max

    # Iterate over KV blocks (outer loop)
    for j in range(num_blocks):
        # Load KV block into SRAM
        j_start = j * block_size
        j_end = min((j + 1) * block_size, seq_len)

        K_j = K[j_start:j_end]  # (block_size, d_k)
        V_j = V[j_start:j_end]  # (block_size, d_k)

        # Iterate over Q blocks (inner loop)
        for i in range(num_blocks):
            # Load Q block into SRAM
            i_start = i * block_size
            i_end = min((i + 1) * block_size, seq_len)

            Q_i = Q[i_start:i_end]  # (block_size, d_k)

            # Compute attention scores for this block
            S_ij = torch.matmul(Q_i, K_j.T) / math.sqrt(d_k)
            # S_ij: (block_size, block_size)

            # Update running statistics (online softmax)
            m_old = m[i_start:i_end].clone()
            m_new = torch.maximum(m_old, S_ij.max(dim=-1).values)

            # Compute exponentials with corrected normalization
            exp_old = torch.exp(m_old - m_new)
            exp_new = torch.exp(S_ij - m_new.unsqueeze(-1))

            # Update running sum
            l_old = l[i_start:i_end].clone()
            l_new = exp_old * l_old + exp_new.sum(dim=-1)

            # Update output
            O[i_start:i_end] = (
                (exp_old.unsqueeze(-1) * l_old.unsqueeze(-1) * O[i_start:i_end]) +
                torch.matmul(exp_new, V_j)
            ) / l_new.unsqueeze(-1)

            # Update statistics
            m[i_start:i_end] = m_new
            l[i_start:i_end] = l_new

    return O


# Test simplified FlashAttention
seq_len = 256
d_k = 64

Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

# Standard attention
standard_output = standard_attention(
    Q.unsqueeze(0).unsqueeze(0),
    K.unsqueeze(0).unsqueeze(0),
    V.unsqueeze(0).unsqueeze(0)
).squeeze()

# Flash attention (simplified)
flash_output = flash_attention_forward_simplified(Q, K, V, block_size=64)

# Compare
print(f"Outputs match: {torch.allclose(standard_output, flash_output, atol=1e-5)}")
print(f"Max difference: {(standard_output - flash_output).abs().max():.6f}")

Backward Pass

FlashAttention's backward pass is even more clever:

python
"""
FlashAttention Backward Pass Key Ideas:

1. **Recomputation:** Instead of storing attention matrix for backward pass,
   recompute it on the fly from saved Q, K, V.

2. **Tiled gradients:** Compute gradients in blocks, accumulating results.

3. **Online statistics:** Use online algorithms for gradient normalization.

Memory savings:
- Standard: Store O(n²) attention matrix
- FlashAttention: Store O(n) statistics, recompute attention

Tradeoff:
- More computation (recompute attention)
- Far less memory (don't store n² matrix)
- Net speedup because memory bandwidth is the bottleneck!
"""

Recomputation Trick:

In standard attention, the backward pass needs the attention matrix (O(n²) memory).

FlashAttention doesn't store the attention matrix. Instead, it keeps Q, K, V plus the per-row softmax statistics (O(n) memory) and recomputes attention blocks as needed during the backward pass.

Since the kernel is bound by memory bandwidth, recomputing blocks in fast SRAM is cheaper than reading the full matrix back from HBM!
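
The same compute-for-memory trade can be mimicked at the PyTorch level with gradient checkpointing. This is not how FlashAttention implements it (the recomputation happens per block inside the fused CUDA kernel), but it illustrates the principle:

python
import torch.utils.checkpoint as checkpoint

def attention_block(Q, K, V):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    return torch.matmul(F.softmax(scores, dim=-1), V)

Q = torch.randn(1, 8, 1024, 64, requires_grad=True)
K = torch.randn(1, 8, 1024, 64, requires_grad=True)
V = torch.randn(1, 8, 1024, 64, requires_grad=True)

# Normal call: autograd saves the (seq_len x seq_len) softmax output for backward
out = attention_block(Q, K, V)

# Checkpointed call: only the inputs are saved; the attention matrix is
# recomputed during backward, trading extra FLOPs for far less stored memory
out_ckpt = checkpoint.checkpoint(attention_block, Q, K, V, use_reentrant=False)
out_ckpt.sum().backward()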

Performance Benefits

Memory Scaling

python
import matplotlib.pyplot as plt
import numpy as np

# Memory scaling comparison
seq_lengths = np.array([512, 1024, 2048, 4096, 8192, 16384])
d_k = 64
batch = 8
heads = 32

# Standard attention: O(n²) for attention matrix
standard_memory = batch * heads * seq_lengths**2 * 4 / 1e9  # GB

# FlashAttention: O(n) - only stores outputs and statistics
flash_memory = batch * heads * seq_lengths * d_k * 4 / 1e9  # GB (simplified)

plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, standard_memory, 'r-o', linewidth=2, label='Standard Attention')
plt.plot(seq_lengths, flash_memory, 'b-s', linewidth=2, label='FlashAttention')
plt.xlabel('Sequence Length')
plt.ylabel('Memory (GB)')
plt.title('Memory Usage: Standard vs FlashAttention')
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.xscale('log')
plt.show()

# Print comparison table
print("Memory Usage Comparison (batch=8, heads=32):\n")
print(f"{'Seq Len':<10s} {'Standard':<12s} {'Flash':<12s} {'Reduction':<10s}")
print("-" * 50)
for i, n in enumerate(seq_lengths):
    std = standard_memory[i]
    flash = flash_memory[i]
    reduction = std / flash
    print(f"{n:<10d} {std:<12.2f} {flash:<12.2f} {reduction:<10.1f}x")

Speed Comparison

python
# Benchmark on actual GPU (requires flash-attn package)
def benchmark_attention():
    """
    Benchmark standard vs FlashAttention.

    Requires: pip install flash-attn
    """
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        print("flash-attn not installed. Install with: pip install flash-attn")
        return

    import time

    batch, heads, seq_len, d_k = 8, 32, 2048, 64

    Q = torch.randn(batch, seq_len, heads, d_k, device='cuda', dtype=torch.float16)
    K = torch.randn(batch, seq_len, heads, d_k, device='cuda', dtype=torch.float16)
    V = torch.randn(batch, seq_len, heads, d_k, device='cuda', dtype=torch.float16)

    # Force the unfused "math" backend for the baseline; by default,
    # F.scaled_dot_product_attention may itself dispatch to a fused kernel.
    # (Older-style context manager; newer PyTorch also offers
    # torch.nn.attention.sdpa_kernel.)
    from torch.backends.cuda import sdp_kernel

    with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        # Warmup
        for _ in range(10):
            _ = F.scaled_dot_product_attention(
                Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
            )

        # Benchmark standard attention
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(100):
            _ = F.scaled_dot_product_attention(
                Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
            )
        torch.cuda.synchronize()
        standard_time = time.time() - start

    # Warmup FlashAttention
    for _ in range(10):
        _ = flash_attn_func(Q, K, V)

    # Benchmark FlashAttention
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        _ = flash_attn_func(Q, K, V)
    torch.cuda.synchronize()
    flash_time = time.time() - start

    print(f"\nBenchmark Results (seq_len={seq_len}):")
    print(f"  Standard Attention: {standard_time:.4f}s")
    print(f"  FlashAttention:     {flash_time:.4f}s")
    print(f"  Speedup:           {standard_time/flash_time:.2f}x")

# Note: This requires flash-attn to be installed
# benchmark_attention()

Typical results:

  • 2-4x speedup for training
  • 5-10x memory reduction for long sequences
  • Enables longer sequences: Can train with 4x longer context

When FlashAttention Helps Most:

  1. Long sequences: seq_len > 2048
  2. Large batch sizes: More tokens to process
  3. Training: Both forward and backward passes benefit
  4. Memory-constrained setups: Can fit larger models/longer sequences

Less beneficial for very short sequences (< 512) where the overhead of tiling dominates.

FlashAttention 2 and Beyond

FlashAttention 2 Improvements

python
"""
FlashAttention 2 (July 2023) improvements:

1. **Better parallelism:**
   - Reduced non-matmul operations
   - Better GPU utilization (up to 2x faster)

2. **Sequence length parallelism:**
   - Split seq_len across thread blocks
   - Better scaling on A100/H100

3. **Lower memory:**
   - Optimized recomputation strategy
   - Reduced SRAM usage

Results:
- 2x faster than FlashAttention 1
- 4-8x faster than standard attention
- Reaches 73% of theoretical max FLOPS on A100
"""

Integration in Models

FlashAttention is now standard in modern LLMs:

python
# PyTorch 2.0+ includes scaled_dot_product_attention with FlashAttention backend
def use_flash_in_pytorch():
    """
    PyTorch 2.0+ ships its own FlashAttention kernel and selects it
    automatically when the inputs allow it (CUDA tensors in fp16/bf16).
    """
    Q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    K = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    V = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

    # Dispatches to the FlashAttention backend when supported;
    # otherwise falls back to the memory-efficient or math backend
    output = F.scaled_dot_product_attention(Q, K, V)

    return output

# Models using FlashAttention:
flash_users = [
    "LLaMA 2",
    "Mistral",
    "Mixtral",
    "MPT",
    "Falcon",
    "GPT-NeoX",
    "Many others via PyTorch/HuggingFace"
]

print("Models using FlashAttention:")
for model in flash_users:
    print(f"  - {model}")

Summary

FlashAttention revolutionized transformer efficiency through IO-aware algorithm design:

Key Ideas:

  1. Tiling: Process attention in blocks that fit in fast SRAM
  2. Recomputation: Recompute attention in backward pass instead of storing
  3. Online algorithms: Compute softmax incrementally without full matrix
  4. Kernel fusion: Fuse operations to minimize memory transfers

Benefits:

  • 2-4x faster training (FlashAttention 2: up to 8x)
  • 5-10x less memory for long sequences
  • Longer sequences: Enable 4x longer context windows
  • Better GPU utilization: Reach 70%+ of theoretical FLOPS

Impact:

  • Enabled models with 32k+ context (MPT-StoryWriter, Claude)
  • Reduced training costs
  • Now standard in PyTorch 2.0+ and most modern LLMs

FlashAttention shows that designing algorithms around the hardware's memory hierarchy can yield dramatic improvements without changing what is actually computed.