FlashAttention: Fast and Memory-Efficient Attention
FlashAttention revolutionized transformer training by making attention computation 2-4x faster while using dramatically less memory. It computes exactly the same attention output as standard attention; the speedup comes from being IO-aware, i.e. optimizing how data moves through the GPU memory hierarchy.
The Memory Bottleneck
Standard Attention is Memory-Bound
import torch
import torch.nn.functional as F
import math
def standard_attention(Q, K, V, mask=None):
"""
Standard attention implementation.
Args:
Q, K, V: (batch, heads, seq_len, d_k)
Returns:
output: (batch, heads, seq_len, d_k)
"""
d_k = Q.size(-1)
# Step 1: Compute attention scores
# Shape: (batch, heads, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 2: Apply softmax
# Shape: (batch, heads, seq_len, seq_len)
attn_weights = F.softmax(scores, dim=-1)
# Step 3: Apply attention to values
# Shape: (batch, heads, seq_len, d_k)
output = torch.matmul(attn_weights, V)
return output
# Analyze memory usage
batch, heads, seq_len, d_k = 8, 32, 2048, 64
Q = torch.randn(batch, heads, seq_len, d_k, device='cuda')
K = torch.randn(batch, heads, seq_len, d_k, device='cuda')
V = torch.randn(batch, heads, seq_len, d_k, device='cuda')
# Memory required for attention matrix
attn_matrix_size = batch * heads * seq_len * seq_len * 4 # 4 bytes per float32
print(f"Attention matrix memory: {attn_matrix_size / 1e9:.2f} GB")
print(f"For seq_len=2048: {attn_matrix_size / 1e9:.2f} GB")
print(f"For seq_len=4096: {batch * heads * 4096 * 4096 * 4 / 1e9:.2f} GB")
print(f"For seq_len=8192: {batch * heads * 8192 * 8192 * 4 / 1e9:.2f} GB")
The problem: the attention matrix is O(n²) in sequence length. For n=8192 with batch 8 and 32 heads, the float32 score tensor alone is roughly 69 GB, as the printout above shows.
GPU Memory Hierarchy:
- HBM (High Bandwidth Memory): 40-80 GB, relatively slow (~1.5-2 TB/s on an A100)
- SRAM (on-chip memory): ~20 MB total, very fast (~19 TB/s on an A100)
Standard attention stores the full attention matrix in HBM. Reading/writing this matrix is the bottleneck, not the computation itself.
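A rough back-of-the-envelope estimate makes this concrete. The hardware numbers below (HBM bandwidth, peak fp16 throughput) are assumed A100-class figures, and the IO count only covers the n×n score matrix, so treat the result as an order-of-magnitude illustration:
# Back-of-the-envelope estimate: time to move the attention matrix vs time to
# compute it. Hardware numbers are assumptions for an A100-class GPU.
HBM_BANDWIDTH = 1.5e12   # bytes/s (assumed)
PEAK_FLOPS = 312e12      # fp16 FLOP/s (assumed tensor-core peak)
batch, heads, seq_len, d_k = 8, 32, 2048, 64
# FLOPs: ~2*n^2*d for QK^T plus ~2*n^2*d for PV, per head
flops = batch * heads * 4 * seq_len**2 * d_k
# Bytes moved for the n x n scores/probabilities in fp16 (write scores, read for
# softmax, write probabilities, read for PV)
attn_bytes = batch * heads * seq_len**2 * 2 * 4
print(f"Estimated compute time: {flops / PEAK_FLOPS * 1e3:.2f} ms")
print(f"Estimated IO time for the attention matrix: {attn_bytes / HBM_BANDWIDTH * 1e3:.2f} ms")
Even with generous assumptions, just moving the score matrix through HBM takes several times longer than the matrix multiplications that produce it.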
FlashAttention's Key Insight
IO-Aware Algorithm Design
Instead of:
- Compute full attention matrix in HBM
- Apply softmax (requires full matrix)
- Multiply with values
FlashAttention:
- Tile the computation into blocks that fit in SRAM
- Recompute attention scores on the fly (instead of storing)
- Fuse operations to minimize memory reads/writes
The Tradeoff:
FlashAttention does more computation (it recomputes scores in the backward pass) but far less IO. Since attention at these sizes is bound by memory bandwidth on modern GPUs, trading extra FLOPs for fewer HBM accesses is a large net win.
- Compute: O(n² d) → same
- Memory: O(n²) → O(n) ✓
- IO: O(n²) → O(n² d² / M), where M is the SRAM size; since d² is typically much smaller than M, this means far fewer HBM accesses ✓
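Plugging rough numbers into these bounds shows the size of the IO win. This is a sketch: the SRAM size M (in elements) is an assumed value and the big-O constants are dropped, so read the result as an order of magnitude:
# Order-of-magnitude comparison of HBM accesses (in elements), using the
# asymptotic counts from the FlashAttention paper. M (usable SRAM, in elements)
# is an assumed value; constants are ignored.
n, d = 4096, 64
M = 100_000
standard_io = n * d + n**2          # Theta(n*d + n^2): materialize the n x n matrix
flash_io = n**2 * d**2 / M          # Theta(n^2 * d^2 / M)
print(f"Standard attention HBM accesses: ~{standard_io:.2e}")
print(f"FlashAttention HBM accesses:     ~{flash_io:.2e}")
print(f"Reduction: ~{standard_io / flash_io:.0f}x")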
Tiling Strategy
Blockwise Computation
def tiled_attention_conceptual(Q, K, V, block_size=64):
"""
Conceptual tiled attention (simplified).
This is NOT the actual FlashAttention algorithm, but shows the idea
of processing in blocks.
Args:
Q, K, V: (seq_len, d_k)
block_size: Size of blocks to process
Returns:
output: (seq_len, d_k)
"""
seq_len, d_k = Q.shape
output = torch.zeros_like(Q)
# Process in blocks
for i in range(0, seq_len, block_size):
for j in range(0, seq_len, block_size):
# Get blocks
Q_block = Q[i:i+block_size] # (block_size, d_k)
K_block = K[j:j+block_size] # (block_size, d_k)
V_block = V[j:j+block_size] # (block_size, d_k)
# Compute attention for this block
scores = torch.matmul(Q_block, K_block.T) / math.sqrt(d_k)
# scores: (block_size, block_size) - fits in SRAM!
attn = F.softmax(scores, dim=-1)
output_block = torch.matmul(attn, V_block)
            # Accumulate. NOTE: as written this is incorrect - each block is
            # normalized by its own softmax. FlashAttention fixes this with the
            # online softmax rescaling described in the next section.
            output[i:i+block_size] += output_block
return output
# Show memory savings
seq_len = 2048
d_k = 64
block_size = 64
# Standard attention
full_matrix_memory = seq_len * seq_len * 4
print(f"Standard attention matrix: {full_matrix_memory / 1e6:.2f} MB")
# Tiled attention
block_matrix_memory = block_size * block_size * 4
print(f"Tiled attention block: {block_matrix_memory / 1e3:.2f} KB")
print(f"Memory reduction: {full_matrix_memory / block_matrix_memory:.0f}x")
Online Softmax
The challenge: Softmax requires the full row to normalize. FlashAttention solves this with online softmax:
def online_softmax_demo():
"""
Demonstrate online softmax computation.
Key idea: Update running max and sum as we process blocks.
"""
# Example: compute softmax([1, 2, 3, 4]) in two blocks
# Block 1: [1, 2]
block1 = torch.tensor([1.0, 2.0])
m1 = block1.max() # Running max
exp1 = torch.exp(block1 - m1)
sum1 = exp1.sum() # Running sum
print("After block 1:")
print(f" Max: {m1:.4f}")
print(f" Sum: {sum1:.4f}")
print(f" Partial softmax: {exp1 / sum1}")
# Block 2: [3, 4]
block2 = torch.tensor([3.0, 4.0])
m2 = block2.max()
# Update running statistics
m_new = max(m1, m2)
exp1_corrected = exp1 * torch.exp(m1 - m_new) # Correct previous block
exp2 = torch.exp(block2 - m_new)
sum_new = exp1_corrected.sum() + exp2.sum()
print("\nAfter block 2:")
print(f" Max: {m_new:.4f}")
print(f" Sum: {sum_new:.4f}")
# Final softmax
softmax1 = exp1_corrected / sum_new
softmax2 = exp2 / sum_new
print(f"\nFinal softmax:")
print(f" Block 1: {softmax1}")
print(f" Block 2: {softmax2}")
# Verify against standard softmax
full = torch.tensor([1.0, 2.0, 3.0, 4.0])
standard = F.softmax(full, dim=0)
online = torch.cat([softmax1, softmax2])
print(f"\nStandard softmax: {standard}")
print(f"Online softmax: {online}")
print(f"Match: {torch.allclose(standard, online)}")
online_softmax_demo()
Online Softmax Math:
Given partial results with max m₁ and sum s₁, when adding new block with max m₂:
m_new = max(m₁, m₂)
s_new = s₁ × exp(m₁ - m_new) + s₂ × exp(m₂ - m_new)
This allows computing softmax incrementally without storing the full attention matrix!
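The same update rule extends to any number of blocks. Here is a minimal sketch of a generic streaming softmax (the helper name online_softmax is just for illustration), verified against F.softmax:
def online_softmax(x, block_size=3):
    """Streaming softmax: fold in one block at a time, keeping only a
    running max m and running sum s (the update rule shown above)."""
    m = torch.tensor(float('-inf'))  # running max
    s = torch.tensor(0.0)            # running sum of exp(x - m)
    for start in range(0, x.numel(), block_size):
        block = x[start:start + block_size]
        m_new = torch.maximum(m, block.max())
        s = s * torch.exp(m - m_new) + torch.exp(block - m_new).sum()
        m = m_new
    # A final pass produces the normalized probabilities. FlashAttention avoids
    # this second pass by rescaling its running output block by block instead.
    return torch.exp(x - m) / s

x = torch.randn(10)
print(torch.allclose(online_softmax(x), F.softmax(x, dim=0)))  # expect: True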
FlashAttention Algorithm
Forward Pass
def flash_attention_forward_simplified(Q, K, V, block_size=64):
"""
Simplified FlashAttention forward pass.
This is a conceptual implementation showing the key ideas.
The actual implementation uses CUDA and is more optimized.
Args:
Q, K, V: (seq_len, d_k)
block_size: Block size for tiling
Returns:
output: (seq_len, d_k)
"""
seq_len, d_k = Q.shape
num_blocks = (seq_len + block_size - 1) // block_size
# Initialize output and statistics
O = torch.zeros_like(Q)
    l = torch.zeros(seq_len, device=Q.device)  # Running softmax denominator per query row
    m = torch.full((seq_len,), float('-inf'), device=Q.device)  # Running row max
# Iterate over KV blocks (outer loop)
for j in range(num_blocks):
# Load KV block into SRAM
j_start = j * block_size
j_end = min((j + 1) * block_size, seq_len)
K_j = K[j_start:j_end] # (block_size, d_k)
V_j = V[j_start:j_end] # (block_size, d_k)
# Iterate over Q blocks (inner loop)
for i in range(num_blocks):
# Load Q block into SRAM
i_start = i * block_size
i_end = min((i + 1) * block_size, seq_len)
Q_i = Q[i_start:i_end] # (block_size, d_k)
# Compute attention scores for this block
S_ij = torch.matmul(Q_i, K_j.T) / math.sqrt(d_k)
# S_ij: (block_size, block_size)
# Update running statistics (online softmax)
m_old = m[i_start:i_end].clone()
m_new = torch.maximum(m_old, S_ij.max(dim=-1).values)
# Compute exponentials with corrected normalization
exp_old = torch.exp(m_old - m_new)
exp_new = torch.exp(S_ij - m_new.unsqueeze(-1))
# Update running sum
l_old = l[i_start:i_end].clone()
l_new = exp_old * l_old + exp_new.sum(dim=-1)
# Update output
O[i_start:i_end] = (
(exp_old.unsqueeze(-1) * l_old.unsqueeze(-1) * O[i_start:i_end]) +
torch.matmul(exp_new, V_j)
) / l_new.unsqueeze(-1)
# Update statistics
m[i_start:i_end] = m_new
l[i_start:i_end] = l_new
return O
# Test simplified FlashAttention
seq_len = 256
d_k = 64
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)
# Standard attention
standard_output = standard_attention(
Q.unsqueeze(0).unsqueeze(0),
K.unsqueeze(0).unsqueeze(0),
V.unsqueeze(0).unsqueeze(0)
).squeeze()
# Flash attention (simplified)
flash_output = flash_attention_forward_simplified(Q, K, V, block_size=64)
# Compare
print(f"Outputs match: {torch.allclose(standard_output, flash_output, atol=1e-5)}")
print(f"Max difference: {(standard_output - flash_output).abs().max():.6f}")
Backward Pass
FlashAttention's backward pass is even more clever:
"""
FlashAttention Backward Pass Key Ideas:
1. **Recomputation:** Instead of storing attention matrix for backward pass,
recompute it on the fly from saved Q, K, V.
2. **Tiled gradients:** Compute gradients in blocks, accumulating results.
3. **Online statistics:** Use online algorithms for gradient normalization.
Memory savings:
- Standard: Store O(n²) attention matrix
- FlashAttention: Store O(n) statistics, recompute attention
Tradeoff:
- More computation (recompute attention)
- Far less memory (don't store n² matrix)
- Net speedup because memory bandwidth is the bottleneck!
"""
Recomputation Trick:
In standard attention, backward pass needs the attention matrix (n² memory).
FlashAttention: Don't store the attention matrix. Instead, keep Q, K, V, the output O, and the O(n) softmax statistics (the per-row logsumexp), and recompute attention blocks as needed during the backward pass.
Since GPUs are memory-bound, this recomputation is faster than reading the full matrix from HBM!
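A minimal sketch of what that recomputation looks like for one block, assuming the forward pass saved the per-row max m and sum l as in the simplified code above (the real kernel stores a single logsumexp value per row, and the full backward pass also needs dO and a row-sum correction term, omitted here):
def recompute_attention_block(Q_i, K_j, m_i, l_i, d_k):
    """Rebuild one block of softmax probabilities from Q, K and O(n) statistics."""
    S_ij = Q_i @ K_j.T / math.sqrt(d_k)  # recomputed scores, never stored in HBM
    return torch.exp(S_ij - m_i.unsqueeze(-1)) / l_i.unsqueeze(-1)

# Quick check for the first 64x64 block against the full softmax
S = Q @ K.T / math.sqrt(d_k)
m_full = S.max(dim=-1).values
l_full = torch.exp(S - m_full.unsqueeze(-1)).sum(dim=-1)
P_block = recompute_attention_block(Q[:64], K[:64], m_full[:64], l_full[:64], d_k)
print(torch.allclose(P_block, F.softmax(S, dim=-1)[:64, :64], atol=1e-6))  # expect: True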
Performance Benefits
Memory Scaling
import matplotlib.pyplot as plt
import numpy as np
# Memory scaling comparison
seq_lengths = np.array([512, 1024, 2048, 4096, 8192, 16384])
d_k = 64
batch = 8
heads = 32
# Standard attention: O(n²) for attention matrix
standard_memory = batch * heads * seq_lengths**2 * 4 / 1e9 # GB
# FlashAttention: O(n) - only stores outputs and statistics
flash_memory = batch * heads * seq_lengths * d_k * 4 / 1e9 # GB (simplified)
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, standard_memory, 'r-o', linewidth=2, label='Standard Attention')
plt.plot(seq_lengths, flash_memory, 'b-s', linewidth=2, label='FlashAttention')
plt.xlabel('Sequence Length')
plt.ylabel('Memory (GB)')
plt.title('Memory Usage: Standard vs FlashAttention')
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.xscale('log')
plt.show()
# Print comparison table
print("Memory Usage Comparison (batch=8, heads=32):\n")
print(f"{'Seq Len':<10s} {'Standard':<12s} {'Flash':<12s} {'Reduction':<10s}")
print("-" * 50)
for i, n in enumerate(seq_lengths):
std = standard_memory[i]
flash = flash_memory[i]
reduction = std / flash
print(f"{n:<10d} {std:<12.2f} {flash:<12.2f} {reduction:<10.1f}x")
Speed Comparison
# Benchmark on actual GPU (requires flash-attn package)
def benchmark_attention():
"""
Benchmark standard vs FlashAttention.
Requires: pip install flash-attn
"""
try:
from flash_attn import flash_attn_func
has_flash = True
except ImportError:
print("flash-attn not installed. Install with: pip install flash-attn")
has_flash = False
return
import time
batch, heads, seq_len, d_k = 8, 32, 2048, 64
Q = torch.randn(batch, seq_len, heads, d_k, device='cuda', dtype=torch.float16)
K = torch.randn(batch, seq_len, heads, d_k, device='cuda', dtype=torch.float16)
V = torch.randn(batch, seq_len, heads, d_k, device='cuda', dtype=torch.float16)
# Warmup
for _ in range(10):
_ = F.scaled_dot_product_attention(
Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
)
    # Benchmark PyTorch's built-in scaled_dot_product_attention
    # (note: on recent PyTorch versions this may itself select a FlashAttention backend)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = F.scaled_dot_product_attention(
Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
)
torch.cuda.synchronize()
standard_time = time.time() - start
# Benchmark FlashAttention
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = flash_attn_func(Q, K, V)
torch.cuda.synchronize()
flash_time = time.time() - start
print(f"\nBenchmark Results (seq_len={seq_len}):")
print(f" Standard Attention: {standard_time:.4f}s")
print(f" FlashAttention: {flash_time:.4f}s")
print(f" Speedup: {standard_time/flash_time:.2f}x")
# Note: This requires flash-attn to be installed
# benchmark_attention()
Typical results:
- 2-4x speedup for training
- 5-10x memory reduction for long sequences
- Enables longer sequences: Can train with 4x longer context
When FlashAttention Helps Most:
- Long sequences: seq_len > 2048
- Large batch sizes: More tokens to process
- Training: Both forward and backward passes benefit
- Memory-constrained setups: Can fit larger models/longer sequences
Less beneficial for very short sequences (< 512) where the overhead of tiling dominates.
FlashAttention 2 and Beyond
FlashAttention 2 Improvements
"""
FlashAttention 2 (July 2023) improvements:
1. **Better parallelism:**
- Reduced non-matmul operations
- Better GPU utilization (up to 2x faster)
2. **Sequence length parallelism:**
- Split seq_len across thread blocks
- Better scaling on A100/H100
3. **Lower memory:**
- Optimized recomputation strategy
- Reduced SRAM usage
Results:
- 2x faster than FlashAttention 1
- 4-8x faster than standard attention
- Reaches 73% of theoretical max FLOPS on A100
"""
Integration in Models
FlashAttention is now standard in modern LLMs:
# PyTorch 2.0+ includes scaled_dot_product_attention with FlashAttention backend
def use_flash_in_pytorch():
    """
    PyTorch 2.0+ ships a FlashAttention kernel as one of the backends for
    scaled_dot_product_attention and selects it automatically when the inputs
    are supported (e.g. fp16/bf16 on recent NVIDIA GPUs).
    """
    Q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    K = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    V = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
    # Backend (flash / memory-efficient / math) is chosen automatically
    output = F.scaled_dot_product_attention(Q, K, V)
    return output
# Models using FlashAttention:
flash_users = [
"LLaMA 2",
"Mistral",
"Mixtral",
"MPT",
"Falcon",
"GPT-NeoX",
"Many others via PyTorch/HuggingFace"
]
print("Models using FlashAttention:")
for model in flash_users:
print(f" - {model}")
Summary
FlashAttention revolutionized transformer efficiency through IO-aware algorithm design:
Key Ideas:
- Tiling: Process attention in blocks that fit in fast SRAM
- Recomputation: Recompute attention in backward pass instead of storing
- Online algorithms: Compute softmax incrementally without full matrix
- Kernel fusion: Fuse operations to minimize memory transfers
Benefits:
- 2-4x faster training (FlashAttention 2: up to 8x)
- 5-10x less memory for long sequences
- Longer sequences: Enable 4x longer context windows
- Better GPU utilization: Reach 70%+ of theoretical FLOPS
Impact:
- Enabled models with 32k+ context (MPT-StoryWriter, Claude)
- Reduced training costs
- Now standard in PyTorch 2.0+ and most modern LLMs
FlashAttention shows that algorithm design for hardware can yield dramatic improvements without changing the underlying computation.