QLoRA: Quantized Low-Rank Adaptation
QLoRA (Quantized LoRA) enables fine-tuning of 65B-parameter models on a single 48 GB GPU by combining 4-bit quantization of the frozen base model with LoRA adapters. This breakthrough democratized LLM fine-tuning.
The Problem
Even with LoRA's efficiency, fine-tuning large models faces a memory bottleneck:
import torch
def calculate_memory_requirements(num_params_billions, dtype='float16'):
"""
Calculate memory needed for model fine-tuning.
Args:
num_params_billions: Model size in billions
dtype: Data type (float32, float16, etc.)
Returns:
Memory requirements in GB
"""
bytes_per_param = {
'float32': 4,
'float16': 2,
'bfloat16': 2,
'int8': 1,
'int4': 0.5,
'nf4': 0.5
}
params = num_params_billions * 1e9
bytes_pp = bytes_per_param[dtype]
# Components:
# 1. Model weights
# 2. Gradients (same size as weights, but only for trainable params in LoRA)
# 3. Optimizer states (2x for Adam: momentum + variance)
# 4. Activations (varies, ~20% of model size)
model_memory = params * bytes_pp
    gradient_memory = params * bytes_pp * 0.01 # LoRA: ~1% of params trainable (rough estimate; adapter gradients actually stay in 16/32-bit)
optimizer_memory = gradient_memory * 2 # Adam states
activation_memory = model_memory * 0.2
total_memory = (model_memory + gradient_memory + optimizer_memory + activation_memory) / 1e9
print(f"\n{num_params_billions}B parameter model with {dtype}:")
print(f" Model weights: {model_memory/1e9:.2f} GB")
print(f" Gradients (LoRA): {gradient_memory/1e9:.2f} GB")
print(f" Optimizer states: {optimizer_memory/1e9:.2f} GB")
print(f" Activations: {activation_memory/1e9:.2f} GB")
print(f" Total: {total_memory:.2f} GB")
return total_memory
# Examples
print("Standard LoRA fine-tuning:")
calculate_memory_requirements(7, 'float16')
calculate_memory_requirements(13, 'float16')
calculate_memory_requirements(65, 'float16')
print("\n\nQLoRA fine-tuning (4-bit):")
calculate_memory_requirements(7, 'nf4')
calculate_memory_requirements(13, 'nf4')
calculate_memory_requirements(65, 'nf4')
Memory Comparison:
LLaMA 65B with standard LoRA (FP16):
- Model: 130 GB
- Total with training: ~156 GB
- Requires: Multiple high-end GPUs
LLaMA 65B with QLoRA (4-bit NF4):
- Model: 32.5 GB
- Total with training: ~39 GB
- Requires: Single 48 GB GPU (e.g., A6000)
QLoRA makes 65B+ model fine-tuning accessible!
QLoRA Components
QLoRA introduces three key innovations:
- 4-bit NormalFloat (NF4): Information-theoretically optimal 4-bit quantization
- Double Quantization: Quantize the quantization constants
- Paged Optimizers: Page optimizer states to CPU RAM via NVIDIA unified memory so that memory spikes don't cause out-of-memory errors
1. NormalFloat (NF4) Quantization
Standard quantization (uniform bins) wastes precision for normally-distributed weights.
NF4 uses bins optimized for a normal distribution:
import torch
import numpy as np
from scipy import stats
class NF4Quantizer:
"""
NormalFloat 4-bit quantization.
Optimized for normally distributed weights.
"""
def __init__(self):
"""
Initialize NF4 quantization bins.
NF4 uses 16 bins (4 bits) positioned to minimize quantization error
for a standard normal distribution N(0,1).
"""
# Pre-computed optimal NF4 bins for N(0,1)
# These are quantiles that minimize expected quantization error
self.nf4_bins = torch.tensor([
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0
])
        # Precompute bin midpoints (kept for reference; quantize() below uses a direct nearest-bin search)
self.compute_quantization_map()
def compute_quantization_map(self):
"""Compute midpoints between bins for quantization."""
bins = self.nf4_bins
self.midpoints = (bins[:-1] + bins[1:]) / 2
def quantize(self, weights):
"""
Quantize weights to 4-bit NF4.
Args:
weights: Weight tensor (any shape)
Returns:
Quantized weights (uint8), scale factor, offset
"""
        # Normalize weights so that ±3σ maps to ±1, matching the NF4 bin range
        mean = weights.mean()
        std = weights.std()
        normalized_weights = (weights - mean) / (3 * std + 1e-8)
        # Clip to [-1, 1]; ±3σ covers >99.7% of a normal distribution
        normalized_weights = torch.clamp(normalized_weights, -1.0, 1.0)
# Find nearest bin for each weight
# Use broadcasting to compute distances to all bins
distances = torch.abs(
normalized_weights.unsqueeze(-1) - self.nf4_bins.to(weights.device)
)
quantized_indices = torch.argmin(distances, dim=-1)
# Store as uint8 (only lower 4 bits used)
quantized_weights = quantized_indices.to(torch.uint8)
# Return quantized weights + normalization params
return quantized_weights, mean, std
def dequantize(self, quantized_weights, mean, std):
"""
Dequantize NF4 weights back to float.
Args:
quantized_weights: Quantized indices (uint8)
mean: Original mean
std: Original std
Returns:
Dequantized weights
"""
        # Map indices to NF4 values (cast to long for index lookup)
        nf4_values = self.nf4_bins.to(quantized_weights.device)[quantized_weights.long()]
        # Undo the 3σ normalization
        dequantized = nf4_values * (3 * std) + mean
        return dequantized
# Test NF4 quantization
quantizer = NF4Quantizer()
# Simulate normally distributed weights
weights = torch.randn(1000, 768) * 0.02 # Typical LLM weight distribution
print(f"Original weights: mean={weights.mean():.6f}, std={weights.std():.6f}")
print(f"Original size: {weights.numel() * 4 / 1e6:.2f} MB (float32)")
# Quantize
quantized, mean, std = quantizer.quantize(weights)
print(f"\nQuantized size: {quantized.numel() * 0.5 / 1e6:.2f} MB (4-bit)")
print(f"Compression ratio: {8.0:.1f}x")
# Dequantize
dequantized = quantizer.dequantize(quantized, mean, std)
# Measure error
mse = torch.mean((weights - dequantized) ** 2)
relative_error = mse / torch.var(weights)
print(f"\nQuantization error:")
print(f" MSE: {mse:.8f}")
print(f" Relative error: {relative_error:.4f}")
# Compare with symmetric uniform 4-bit quantization (16 levels: -8..7)
uniform_scale = weights.abs().max() / 7
uniform_quantized = torch.round(weights / uniform_scale).clamp(-8, 7)
uniform_dequantized = uniform_quantized * uniform_scale
uniform_mse = torch.mean((weights - uniform_dequantized) ** 2)
print(f"\nComparison with uniform quantization:")
print(f" NF4 MSE: {mse:.8f}")
print(f" Uniform MSE: {uniform_mse:.8f}")
print(f" NF4 improvement: {uniform_mse/mse:.2f}x better")
Why NF4 Works Better:
Neural network weights typically follow a normal distribution. NF4 places more quantization bins near zero (where most weights are) and fewer bins at extremes.
Uniform quantization: equal spacing across the whole range, wasting precision on rarely used extreme values.
NF4 quantization: denser bins near zero, matching the shape of a normal distribution.
Result: meaningfully lower quantization error at the same bit width!
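To see where bin values like the ones hard-coded above come from, here is a minimal sketch that derives NF4-style code values from evenly spaced quantiles of N(0,1), rescaled to [-1, 1]. The exact construction in the QLoRA paper is a bit more involved (it treats the positive and negative halves asymmetrically so that zero is represented exactly), so treat this as an approximation.
import numpy as np
from scipy import stats
def approximate_nf4_bins(num_levels=16):
    """Approximate NF4 code values from quantiles of a standard normal."""
    # Evenly spaced probabilities, avoiding the exact 0 and 1 endpoints
    probs = np.linspace(0.5 / num_levels, 1 - 0.5 / num_levels, num_levels)
    # Corresponding quantiles of N(0,1), rescaled so the extremes sit at ±1
    quantiles = stats.norm.ppf(probs)
    return quantiles / np.abs(quantiles).max()
print(approximate_nf4_bins())
The printed values cluster near zero and spread out toward ±1, the same qualitative pattern as the hard-coded NF4 bins.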
2. Double Quantization
QLoRA quantizes the quantization constants themselves to save memory:
class DoubleQuantization:
"""
Double quantization: quantize the quantization constants.
First quantization: Weights -> 4-bit + FP32 constants (per block)
Second quantization: FP32 constants -> 8-bit + global FP32 constant
"""
def __init__(self, block_size=64):
"""
Args:
block_size: Size of blocks for first quantization
"""
self.block_size = block_size
self.nf4 = NF4Quantizer()
def quantize(self, weights):
"""
Perform double quantization.
Args:
weights: Weight tensor (2D)
Returns:
Doubly quantized representation
"""
# Reshape into blocks
original_shape = weights.shape
flat_weights = weights.flatten()
# Pad to multiple of block_size
padding = (self.block_size - flat_weights.numel() % self.block_size) % self.block_size
if padding > 0:
            flat_weights = torch.cat([flat_weights, torch.zeros(padding, dtype=flat_weights.dtype, device=flat_weights.device)])
blocks = flat_weights.reshape(-1, self.block_size)
num_blocks = blocks.shape[0]
# First quantization: each block -> 4-bit + constants
quantized_blocks = []
block_means = []
block_stds = []
for block in blocks:
q_block, mean, std = self.nf4.quantize(block)
quantized_blocks.append(q_block)
block_means.append(mean)
block_stds.append(std)
# Stack quantized blocks
quantized_data = torch.stack(quantized_blocks)
# Convert constants to tensors
block_means = torch.tensor(block_means)
block_stds = torch.tensor(block_stds)
# Second quantization: quantize the constants
# Use 8-bit uniform quantization for constants
mean_min, mean_max = block_means.min(), block_means.max()
std_min, std_max = block_stds.min(), block_stds.max()
        # Quantize means to 8-bit (epsilon guards against a zero range)
        mean_scale = (mean_max - mean_min) / 255 + 1e-12
        quantized_means = torch.round((block_means - mean_min) / mean_scale).clamp(0, 255).to(torch.uint8)
        # Quantize stds to 8-bit
        std_scale = (std_max - std_min) / 255 + 1e-12
        quantized_stds = torch.round((block_stds - std_min) / std_scale).clamp(0, 255).to(torch.uint8)
return {
'quantized_weights': quantized_data, # 4-bit per element
'quantized_means': quantized_means, # 8-bit per block
'quantized_stds': quantized_stds, # 8-bit per block
'mean_offset': mean_min, # FP32 (single value)
'mean_scale': mean_scale, # FP32 (single value)
'std_offset': std_min, # FP32 (single value)
'std_scale': std_scale, # FP32 (single value)
'original_shape': original_shape,
'padding': padding
}
def dequantize(self, quantized_dict):
"""Dequantize back to original weights."""
# Dequantize constants
means = quantized_dict['quantized_means'].float() * quantized_dict['mean_scale'] + quantized_dict['mean_offset']
stds = quantized_dict['quantized_stds'].float() * quantized_dict['std_scale'] + quantized_dict['std_offset']
# Dequantize blocks
dequantized_blocks = []
for i, q_block in enumerate(quantized_dict['quantized_weights']):
dequantized_block = self.nf4.dequantize(q_block, means[i], stds[i])
dequantized_blocks.append(dequantized_block)
# Concatenate blocks
dequantized_flat = torch.cat(dequantized_blocks)
# Remove padding
if quantized_dict['padding'] > 0:
dequantized_flat = dequantized_flat[:-quantized_dict['padding']]
# Reshape to original
dequantized = dequantized_flat.reshape(quantized_dict['original_shape'])
return dequantized
def compute_memory_saved(self, weights):
"""Calculate memory savings from double quantization."""
# Original: FP32
original_memory = weights.numel() * 4
# Single quantization: 4-bit + FP32 constants per block
num_blocks = (weights.numel() + self.block_size - 1) // self.block_size
single_quant_memory = weights.numel() * 0.5 + num_blocks * 8 # 2 FP32 per block
# Double quantization: 4-bit + 8-bit constants + 4 global FP32
double_quant_memory = weights.numel() * 0.5 + num_blocks * 2 + 16 # 2 uint8 per block + 4 FP32
print(f"\nMemory comparison for {weights.numel():,} parameters:")
print(f" Original (FP32): {original_memory/1e6:.2f} MB")
print(f" Single quantization: {single_quant_memory/1e6:.2f} MB")
print(f" Double quantization: {double_quant_memory/1e6:.2f} MB")
print(f" Savings vs single: {(single_quant_memory - double_quant_memory)/1e6:.2f} MB")
# Test double quantization
dq = DoubleQuantization(block_size=64)
weights = torch.randn(4096, 4096) * 0.02
quantized_dict = dq.quantize(weights)
dequantized = dq.dequantize(quantized_dict)
error = torch.mean((weights - dequantized) ** 2)
print(f"Double quantization error: {error:.8f}")
dq.compute_memory_saved(weights)
Double Quantization Savings:
For a 65B parameter model (one FP32 constant per 64-weight block, as in the QLoRA paper):
- Single quantization: 32.5 GB (4-bit weights) + ~4 GB (FP32 constants) ≈ 36.5 GB
- Double quantization: 32.5 GB (4-bit weights) + ~1 GB (8-bit constants plus second-level constants) ≈ 33.5 GB
Savings: ~3 GB (about 0.37 bits per parameter)
May seem small, but every gigabyte counts when fitting a large model on a single GPU!
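As a quick sanity check on these figures, here is a small sketch of the bits-per-parameter arithmetic, assuming the paper's block sizes (64 weights per first-level block and 256 first-level constants per second-level block):
# Overhead of the quantization constants, in bits per parameter
params = 65e9
block_size, second_block = 64, 256
single = 32 / block_size                                    # one FP32 constant per 64-weight block
double = 8 / block_size + 32 / (block_size * second_block)  # uint8 per block + FP32 per 256 blocks
print(f"Single quantization overhead: {single:.3f} bits/param ({params * single / 8 / 1e9:.2f} GB)")
print(f"Double quantization overhead: {double:.3f} bits/param ({params * double / 8 / 1e9:.2f} GB)")
print(f"Saved on a 65B model: {params * (single - double) / 8 / 1e9:.2f} GB")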
3. Complete QLoRA Implementation
class QLoRALinear(torch.nn.Module):
"""
Linear layer with QLoRA: 4-bit quantized base weights + LoRA adapters.
"""
def __init__(
self,
in_features,
out_features,
rank=8,
alpha=16,
dropout=0.1,
quantize_base=True
):
"""
Args:
in_features: Input dimension
out_features: Output dimension
rank: LoRA rank
alpha: LoRA alpha
dropout: Dropout
quantize_base: Whether to quantize base weights
"""
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scaling = alpha / rank
# Base weights (quantized if enabled)
if quantize_base:
self.quantizer = DoubleQuantization(block_size=64)
# Initialize with random weights, then quantize
base_weights = torch.randn(out_features, in_features) * 0.02
self.quantized_weight = self.quantizer.quantize(base_weights)
# Don't store as parameter (not trainable)
self.register_buffer('_base_weight_quantized', torch.tensor(0)) # Placeholder
else:
self.base_weight = torch.nn.Parameter(
torch.randn(out_features, in_features) * 0.02
)
self.base_weight.requires_grad = False
# LoRA adapters (always FP16/FP32)
self.lora_A = torch.nn.Parameter(torch.randn(rank, in_features))
self.lora_B = torch.nn.Parameter(torch.zeros(out_features, rank))
self.dropout = torch.nn.Dropout(p=dropout)
# Initialize LoRA A
torch.nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
self.quantize_base = quantize_base
def get_base_weight(self):
"""Get dequantized base weight."""
if self.quantize_base:
return self.quantizer.dequantize(self.quantized_weight)
else:
return self.base_weight
def forward(self, x):
"""
Forward pass: y = Wx + (α/r)BAx
Base weight W is quantized (4-bit), LoRA is full precision.
Args:
x: Input tensor (..., in_features)
Returns:
Output tensor (..., out_features)
"""
# Base weight computation (dequantized on-the-fly)
base_weight = self.get_base_weight()
base_output = torch.nn.functional.linear(x, base_weight)
# LoRA computation
x_dropout = self.dropout(x)
lora_output = (x_dropout @ self.lora_A.T) @ self.lora_B.T
lora_output = lora_output * self.scaling
return base_output + lora_output
# Test QLoRA layer
print("Testing QLoRA Linear Layer:")
# Create layers
standard_linear = torch.nn.Linear(4096, 4096)
qlora_linear = QLoRALinear(4096, 4096, rank=8, quantize_base=True)
# Count memory
def get_model_memory(model):
"""Estimate model memory usage."""
total_mem = 0
for param in model.parameters():
total_mem += param.numel() * param.element_size()
return total_mem
standard_mem = get_model_memory(standard_linear)
# The 4-bit base weights live in a plain dict here (not nn.Parameters), so count them by hand:
# 0.5 bytes per packed 4-bit weight plus one byte each for the per-block mean/std constants
qd = qlora_linear.quantized_weight
base_mem = qd['quantized_weights'].numel() * 0.5 + qd['quantized_means'].numel() + qd['quantized_stds'].numel()
qlora_mem = get_model_memory(qlora_linear) + base_mem
print(f"\nStandard Linear: {standard_mem/1e6:.2f} MB")
print(f"QLoRA Linear: {qlora_mem/1e6:.2f} MB")
print(f"Memory reduction: {standard_mem/qlora_mem:.1f}x")
# Test forward pass
x = torch.randn(2, 10, 4096)
output = qlora_linear(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
Training with QLoRA
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def create_qlora_model(model_name, lora_rank=8, lora_alpha=16):
"""
Create a model with QLoRA using HuggingFace's bitsandbytes.
Args:
model_name: Model name
lora_rank: LoRA rank
lora_alpha: LoRA alpha
Returns:
Model with QLoRA
"""
# 4-bit quantization config
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # Use NormalFloat 4-bit
bnb_4bit_use_double_quant=True, # Double quantization
bnb_4bit_compute_dtype=torch.bfloat16 # Compute in bfloat16
)
# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto", # Automatic device placement
trust_remote_code=True
)
# Add LoRA adapters using PEFT library
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
# Configure LoRA
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["q_proj", "v_proj"], # Apply to Q and V
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Add LoRA
model = get_peft_model(model, lora_config)
# Print trainable parameters
model.print_trainable_parameters()
return model
# Example usage (commented out - requires actual model)
# model = create_qlora_model("meta-llama/Llama-2-7b-hf", lora_rank=16)
# Training would use standard PyTorch/HuggingFace training loop
# Only LoRA parameters are updated; quantized base stays frozen
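The third QLoRA ingredient, paged optimizers, is usually enabled through the optimizer choice rather than the quantization config. Below is a minimal sketch of a training setup, assuming a recent transformers version that exposes the paged AdamW optimizer variants; train_dataset and data_collator are hypothetical objects you would supply.
from transformers import TrainingArguments, Trainer
# Hypothetical training configuration; `model` would come from create_qlora_model above
training_args = TrainingArguments(
    output_dir="qlora-output",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    num_train_epochs=3,
    bf16=True,
    optim="paged_adamw_8bit",  # paged optimizer: spills optimizer state to CPU RAM on memory spikes
    logging_steps=10,
)
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,  # assumed to exist
#     data_collator=data_collator,  # assumed to exist
# )
# trainer.train()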
QLoRA Training Considerations:
- Slower than FP16: Dequantization adds ~20-30% overhead
- Requires bitsandbytes: Install with pip install bitsandbytes
- GPU compatibility: Requires modern NVIDIA GPUs (Ampere/Ada/Hopper)
- Slight accuracy loss: 4-bit causes minor degradation (typically <1%)
- Memory-compute tradeoff: Saves memory but trains slightly slower
Worth it for enabling fine-tuning that wouldn't otherwise fit!
QLoRA vs Standard LoRA
import pandas as pd
comparison = pd.DataFrame({
'Aspect': [
'Base weights',
'LoRA adapters',
'Memory (7B model)',
'Memory (65B model)',
'Training speed',
'Accuracy',
'GPU required (65B)',
'Typical use case'
],
'Standard LoRA': [
'FP16/BF16',
'FP16/BF16',
'~14 GB',
'~130 GB',
'Fast',
'Best',
'Multiple A100 80GB',
'Smaller models (<13B)'
],
'QLoRA': [
'NF4 (4-bit)',
'FP16/BF16',
'~5 GB',
'~33 GB',
'Slower (~1.3x)',
'Near-best (< 1% drop)',
        'Single 48 GB GPU (e.g., A6000)',
'Large models (65B+)'
]
})
print(comparison.to_string(index=False))
Summary
QLoRA democratizes LLM fine-tuning through three innovations:
- NF4 Quantization: 4-bit quantization optimized for normal distributions
- Double Quantization: Quantize the quantization constants to save more memory
- Paged Optimizers: Page optimizer states to CPU RAM so that memory spikes don't cause out-of-memory errors
Result: Fine-tune 65B models on a single 48 GB GPU with minimal accuracy loss.
QLoRA opened fine-tuning to individual researchers and small companies, accelerating LLM research and applications.