
QLoRA: Quantized Low-Rank Adaptation

QLoRA combines 4-bit quantization with LoRA to fine-tune massive models on consumer hardware. Learn quantization, NF4 format, and complete implementation.



QLoRA (Quantized LoRA) enables fine-tuning of models as large as 65B parameters on a single 48 GB GPU by combining 4-bit quantization with LoRA. This breakthrough democratized LLM fine-tuning.

The Problem

Even with LoRA's efficiency, fine-tuning large models faces a memory bottleneck:

python
import torch

def calculate_memory_requirements(num_params_billions, dtype='float16'):
    """
    Calculate memory needed for model fine-tuning.

    Args:
        num_params_billions: Model size in billions
        dtype: Data type (float32, float16, etc.)

    Returns:
        Memory requirements in GB
    """
    bytes_per_param = {
        'float32': 4,
        'float16': 2,
        'bfloat16': 2,
        'int8': 1,
        'int4': 0.5,
        'nf4': 0.5
    }

    params = num_params_billions * 1e9
    bytes_pp = bytes_per_param[dtype]

    # Components:
    # 1. Model weights
    # 2. Gradients (same size as weights, but only for trainable params in LoRA)
    # 3. Optimizer states (2x for Adam: momentum + variance)
    # 4. Activations (varies, ~20% of model size)

    model_memory = params * bytes_pp
    # LoRA trains roughly 1% of parameters; in practice the adapters (and their
    # gradients/optimizer states) stay in 16/32-bit even when the base model is
    # quantized, so the terms below are rough approximations.
    gradient_memory = params * bytes_pp * 0.01  # LoRA: ~1% trainable
    optimizer_memory = gradient_memory * 2  # Adam states (momentum + variance)
    activation_memory = model_memory * 0.2  # rough rule of thumb; depends on batch/sequence length

    total_memory = (model_memory + gradient_memory + optimizer_memory + activation_memory) / 1e9

    print(f"\n{num_params_billions}B parameter model with {dtype}:")
    print(f"  Model weights: {model_memory/1e9:.2f} GB")
    print(f"  Gradients (LoRA): {gradient_memory/1e9:.2f} GB")
    print(f"  Optimizer states: {optimizer_memory/1e9:.2f} GB")
    print(f"  Activations: {activation_memory/1e9:.2f} GB")
    print(f"  Total: {total_memory:.2f} GB")

    return total_memory


# Examples
print("Standard LoRA fine-tuning:")
calculate_memory_requirements(7, 'float16')
calculate_memory_requirements(13, 'float16')
calculate_memory_requirements(65, 'float16')

print("\n\nQLoRA fine-tuning (4-bit):")
calculate_memory_requirements(7, 'nf4')
calculate_memory_requirements(13, 'nf4')
calculate_memory_requirements(65, 'nf4')

Memory Comparison:

LLaMA 65B with standard LoRA (FP16):

  • Model: 130 GB
  • Total with training: ~156 GB
  • Requires: Multiple high-end GPUs

LLaMA 65B with QLoRA (4-bit NF4):

  • Model: 32.5 GB
  • Total with training: ~39 GB
  • Requires: Single 48 GB GPU (e.g., RTX A6000); 33B models fit on a 24 GB card like a 3090

QLoRA makes 65B+ model fine-tuning accessible!

QLoRA Components

QLoRA introduces three key innovations:

  1. 4-bit NormalFloat (NF4): Information-theoretically optimal 4-bit quantization
  2. Double Quantization: Quantize the quantization constants
  3. Paged Optimizers: Page optimizer states to CPU RAM via NVIDIA unified memory to absorb memory spikes during training

1. NormalFloat (NF4) Quantization

Standard quantization (uniform bins) wastes precision for normally-distributed weights.

NF4 uses bins optimized for a normal distribution:

python
import torch
import numpy as np

class NF4Quantizer:
    """
    NormalFloat 4-bit quantization.

    Optimized for normally distributed weights.
    """

    def __init__(self):
        """
        Initialize NF4 quantization bins.

        NF4 uses 16 bins (4 bits) positioned to minimize quantization error
        for a standard normal distribution N(0,1).
        """
        # Pre-computed optimal NF4 bins for N(0,1)
        # These are quantiles that minimize expected quantization error
        self.nf4_bins = torch.tensor([
            -1.0,
            -0.6961928009986877,
            -0.5250730514526367,
            -0.39491748809814453,
            -0.28444138169288635,
            -0.18477343022823334,
            -0.09105003625154495,
            0.0,
            0.07958029955625534,
            0.16093020141124725,
            0.24611230194568634,
            0.33791524171829224,
            0.44070982933044434,
            0.5626170039176941,
            0.7229568362236023,
            1.0
        ])

        # Compute midpoints for quantization
        self.compute_quantization_map()

    def compute_quantization_map(self):
        """Compute midpoints between bins for quantization."""
        bins = self.nf4_bins
        self.midpoints = (bins[:-1] + bins[1:]) / 2

    def quantize(self, weights):
        """
        Quantize weights to 4-bit NF4.

        Args:
            weights: Weight tensor (any shape)

        Returns:
            Quantized weights (uint8), scale factor, offset
        """
        # Normalize weights so that ±3 standard deviations map to ±1.
        # (bitsandbytes' NF4 scales each block by its absolute maximum instead;
        # the 3-sigma scale is a simplified stand-in for illustration.)
        mean = weights.mean()
        std = weights.std()

        normalized_weights = (weights - mean) / (3 * std + 1e-8)

        # Clip to [-1, 1]; ±3 sigma covers >99.7% of a normal distribution
        normalized_weights = torch.clamp(normalized_weights, -1.0, 1.0)

        # Find nearest bin for each weight
        # Use broadcasting to compute distances to all bins
        distances = torch.abs(
            normalized_weights.unsqueeze(-1) - self.nf4_bins.to(weights.device)
        )
        quantized_indices = torch.argmin(distances, dim=-1)

        # Store as uint8 (only lower 4 bits used)
        quantized_weights = quantized_indices.to(torch.uint8)

        # Return quantized weights + normalization params
        return quantized_weights, mean, std

    def dequantize(self, quantized_weights, mean, std):
        """
        Dequantize NF4 weights back to float.

        Args:
            quantized_weights: Quantized indices (uint8)
            mean: Original mean
            std: Original std

        Returns:
            Dequantized weights
        """
        # Map indices to NF4 values (cast to long: uint8 tensors are not valid integer indices)
        nf4_values = self.nf4_bins.to(quantized_weights.device)[quantized_weights.long()]

        # Denormalize (inverse of the 3-sigma scaling used in quantize)
        dequantized = nf4_values * (3 * std) + mean

        return dequantized


# Test NF4 quantization
quantizer = NF4Quantizer()

# Simulate normally distributed weights
weights = torch.randn(1000, 768) * 0.02  # Typical LLM weight distribution

print(f"Original weights: mean={weights.mean():.6f}, std={weights.std():.6f}")
print(f"Original size: {weights.numel() * 4 / 1e6:.2f} MB (float32)")

# Quantize
quantized, mean, std = quantizer.quantize(weights)
print(f"\nQuantized size: {quantized.numel() * 0.5 / 1e6:.2f} MB (4-bit)")
print(f"Compression ratio: {8.0:.1f}x")

# Dequantize
dequantized = quantizer.dequantize(quantized, mean, std)

# Measure error
mse = torch.mean((weights - dequantized) ** 2)
relative_error = mse / torch.var(weights)

print(f"\nQuantization error:")
print(f"  MSE: {mse:.8f}")
print(f"  Relative error: {relative_error:.4f}")

# Compare with symmetric uniform 4-bit quantization (16 levels, -8..7)
uniform_scale = weights.abs().max() / 7
uniform_quantized = torch.round(weights / uniform_scale).clamp(-8, 7)
uniform_dequantized = uniform_quantized * uniform_scale
uniform_mse = torch.mean((weights - uniform_dequantized) ** 2)

print(f"\nComparison with uniform quantization:")
print(f"  NF4 MSE: {mse:.8f}")
print(f"  Uniform MSE: {uniform_mse:.8f}")
print(f"  NF4 improvement: {uniform_mse/mse:.2f}x better")

Why NF4 Works Better:

Neural network weights typically follow a normal distribution. NF4 places more quantization bins near zero (where most weights are) and fewer bins at extremes.

  • Uniform quantization: equal spacing, wastes precision on rarely used extreme values
  • NF4 quantization: denser bins near zero, optimal for normally distributed weights

Result: ~2x lower quantization error for same bit width!
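
To see the "denser bins near zero" claim concretely, the short check below prints the width of each NF4 bin interval. It is an illustrative sketch that assumes the NF4Quantizer class defined above is in scope.

python
# Inspect the spacing between consecutive NF4 bins: gaps are narrowest near
# zero, where most normally distributed weights fall, and widest at the tails.
bins = NF4Quantizer().nf4_bins.tolist()
gaps = [b - a for a, b in zip(bins[:-1], bins[1:])]
for a, b, gap in zip(bins[:-1], bins[1:], gaps):
    print(f"[{a:+.3f}, {b:+.3f}]  width = {gap:.3f}")
print(f"\nNarrowest gap (near zero): {min(gaps):.3f}")
print(f"Widest gap (at the tails): {max(gaps):.3f}")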

2. Double Quantization

QLoRA quantizes the quantization constants themselves to save memory:

python
class DoubleQuantization:
    """
    Double quantization: quantize the quantization constants.

    First quantization: Weights -> 4-bit + FP32 constants (per block)
    Second quantization: FP32 constants -> 8-bit + global FP32 constant
    """

    def __init__(self, block_size=64):
        """
        Args:
            block_size: Size of blocks for first quantization
        """
        self.block_size = block_size
        self.nf4 = NF4Quantizer()

    def quantize(self, weights):
        """
        Perform double quantization.

        Args:
            weights: Weight tensor (2D)

        Returns:
            Doubly quantized representation
        """
        # Reshape into blocks
        original_shape = weights.shape
        flat_weights = weights.flatten()

        # Pad to multiple of block_size
        padding = (self.block_size - flat_weights.numel() % self.block_size) % self.block_size
        if padding > 0:
            pad = torch.zeros(padding, dtype=flat_weights.dtype, device=flat_weights.device)
            flat_weights = torch.cat([flat_weights, pad])

        blocks = flat_weights.reshape(-1, self.block_size)
        num_blocks = blocks.shape[0]

        # First quantization: each block -> 4-bit + constants
        quantized_blocks = []
        block_means = []
        block_stds = []

        for block in blocks:
            q_block, mean, std = self.nf4.quantize(block)
            quantized_blocks.append(q_block)
            block_means.append(mean)
            block_stds.append(std)

        # Stack quantized blocks
        quantized_data = torch.stack(quantized_blocks)

        # Stack per-block constants (each is a 0-d tensor) into 1-D tensors
        block_means = torch.stack(block_means)
        block_stds = torch.stack(block_stds)

        # Second quantization: quantize the constants
        # Use 8-bit uniform quantization for constants
        mean_min, mean_max = block_means.min(), block_means.max()
        std_min, std_max = block_stds.min(), block_stds.max()

        # Quantize means to 8-bit
        mean_scale = (mean_max - mean_min) / 255
        quantized_means = torch.round((block_means - mean_min) / mean_scale).to(torch.uint8)

        # Quantize stds to 8-bit
        std_scale = (std_max - std_min) / 255
        quantized_stds = torch.round((block_stds - std_min) / std_scale).to(torch.uint8)

        return {
            'quantized_weights': quantized_data,  # 4-bit per element
            'quantized_means': quantized_means,   # 8-bit per block
            'quantized_stds': quantized_stds,     # 8-bit per block
            'mean_offset': mean_min,              # FP32 (single value)
            'mean_scale': mean_scale,             # FP32 (single value)
            'std_offset': std_min,                # FP32 (single value)
            'std_scale': std_scale,               # FP32 (single value)
            'original_shape': original_shape,
            'padding': padding
        }

    def dequantize(self, quantized_dict):
        """Dequantize back to original weights."""
        # Dequantize constants
        means = quantized_dict['quantized_means'].float() * quantized_dict['mean_scale'] + quantized_dict['mean_offset']
        stds = quantized_dict['quantized_stds'].float() * quantized_dict['std_scale'] + quantized_dict['std_offset']

        # Dequantize blocks
        dequantized_blocks = []
        for i, q_block in enumerate(quantized_dict['quantized_weights']):
            dequantized_block = self.nf4.dequantize(q_block, means[i], stds[i])
            dequantized_blocks.append(dequantized_block)

        # Concatenate blocks
        dequantized_flat = torch.cat(dequantized_blocks)

        # Remove padding
        if quantized_dict['padding'] > 0:
            dequantized_flat = dequantized_flat[:-quantized_dict['padding']]

        # Reshape to original
        dequantized = dequantized_flat.reshape(quantized_dict['original_shape'])

        return dequantized

    def compute_memory_saved(self, weights):
        """Calculate memory savings from double quantization."""
        # Original: FP32
        original_memory = weights.numel() * 4

        # Single quantization: 4-bit + FP32 constants per block
        num_blocks = (weights.numel() + self.block_size - 1) // self.block_size
        single_quant_memory = weights.numel() * 0.5 + num_blocks * 8  # 2 FP32 per block

        # Double quantization: 4-bit + 8-bit constants + 4 global FP32
        double_quant_memory = weights.numel() * 0.5 + num_blocks * 2 + 16  # 2 uint8 per block + 4 FP32

        print(f"\nMemory comparison for {weights.numel():,} parameters:")
        print(f"  Original (FP32): {original_memory/1e6:.2f} MB")
        print(f"  Single quantization: {single_quant_memory/1e6:.2f} MB")
        print(f"  Double quantization: {double_quant_memory/1e6:.2f} MB")
        print(f"  Savings vs single: {(single_quant_memory - double_quant_memory)/1e6:.2f} MB")


# Test double quantization
# (the pure-Python per-block loop is slow on 16.7M weights; expect a short wait)
dq = DoubleQuantization(block_size=64)
weights = torch.randn(4096, 4096) * 0.02

quantized_dict = dq.quantize(weights)
dequantized = dq.dequantize(quantized_dict)

error = torch.mean((weights - dequantized) ** 2)
print(f"Double quantization error: {error:.8f}")

dq.compute_memory_saved(weights)

Double Quantization Savings:

For a 65B parameter model, using the paper's layout of one FP32 constant per 64-weight block:

  • Single quantization: 32.5 GB (4-bit weights) + ~4 GB (FP32 constants) ≈ 36.5 GB
  • Double quantization: 32.5 GB (4-bit weights) + ~1 GB (8-bit constants) ≈ 33.5 GB

Savings: ~3 GB (roughly 0.37 bits per parameter)

May seem small, but crucial for fitting large models on consumer GPUs!
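
The arithmetic behind those numbers is short; the sketch below reproduces it under the paper's layout (one FP32 constant per 64-weight block, re-quantized to 8 bits with a second-level block size of 256). The figures are estimates, not measurements.

python
# Per-parameter overhead of the quantization constants, following the paper's
# block sizes: weight block = 64, second-level constant block = 256.
params = 65e9
weight_block = 64
const_block = 256

single_bits = 32 / weight_block                                      # one FP32 constant per block
double_bits = 8 / weight_block + 32 / (weight_block * const_block)   # 8-bit constant + FP32 per 256 constants

print(f"Single quantization: {single_bits:.3f} bits/param "
      f"= {params * single_bits / 8 / 1e9:.2f} GB")
print(f"Double quantization: {double_bits:.3f} bits/param "
      f"= {params * double_bits / 8 / 1e9:.2f} GB")
print(f"Savings: {params * (single_bits - double_bits) / 8 / 1e9:.2f} GB")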

3. Complete QLoRA Implementation

python
class QLoRALinear(torch.nn.Module):
    """
    Linear layer with QLoRA: 4-bit quantized base weights + LoRA adapters.
    """

    def __init__(
        self,
        in_features,
        out_features,
        rank=8,
        alpha=16,
        dropout=0.1,
        quantize_base=True
    ):
        """
        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: LoRA rank
            alpha: LoRA alpha
            dropout: Dropout
            quantize_base: Whether to quantize base weights
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = alpha / rank

        # Base weights (quantized if enabled)
        if quantize_base:
            self.quantizer = DoubleQuantization(block_size=64)
            # Initialize with random weights, then quantize. The quantized
            # representation is kept as a plain attribute (a dict of tensors),
            # not as a Parameter, so it is frozen and never receives gradients.
            base_weights = torch.randn(out_features, in_features) * 0.02
            self.quantized_weight = self.quantizer.quantize(base_weights)
        else:
            self.base_weight = torch.nn.Parameter(
                torch.randn(out_features, in_features) * 0.02
            )
            self.base_weight.requires_grad = False

        # LoRA adapters (always FP16/FP32)
        self.lora_A = torch.nn.Parameter(torch.randn(rank, in_features))
        self.lora_B = torch.nn.Parameter(torch.zeros(out_features, rank))
        self.dropout = torch.nn.Dropout(p=dropout)

        # Initialize LoRA A
        torch.nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))

        self.quantize_base = quantize_base

    def get_base_weight(self):
        """Get dequantized base weight."""
        if self.quantize_base:
            return self.quantizer.dequantize(self.quantized_weight)
        else:
            return self.base_weight

    def forward(self, x):
        """
        Forward pass: y = Wx + (α/r)BAx

        Base weight W is quantized (4-bit), LoRA is full precision.

        Args:
            x: Input tensor (..., in_features)

        Returns:
            Output tensor (..., out_features)
        """
        # Base weight computation (dequantized on-the-fly)
        base_weight = self.get_base_weight()
        base_output = torch.nn.functional.linear(x, base_weight)

        # LoRA computation
        x_dropout = self.dropout(x)
        lora_output = (x_dropout @ self.lora_A.T) @ self.lora_B.T
        lora_output = lora_output * self.scaling

        return base_output + lora_output


# Test QLoRA layer
print("Testing QLoRA Linear Layer:")

# Create layers
standard_linear = torch.nn.Linear(4096, 4096)
qlora_linear = QLoRALinear(4096, 4096, rank=8, quantize_base=True)

# Count memory (parameters plus the 4-bit quantized base weights, which live
# outside of model.parameters())
def get_model_memory(model):
    """Estimate model memory usage, including quantized base weights."""
    total_mem = 0
    for param in model.parameters():
        total_mem += param.numel() * param.element_size()
    # The quantized base is not a Parameter, so count it separately. Indices are
    # held one-per-uint8 in this demo, but a packed kernel stores two 4-bit
    # values per byte, so count them at 0.5 bytes each.
    if hasattr(model, 'quantized_weight'):
        q = model.quantized_weight
        total_mem += q['quantized_weights'].numel() * 0.5  # 4-bit weights
        total_mem += q['quantized_means'].numel()          # 8-bit constants
        total_mem += q['quantized_stds'].numel()           # 8-bit constants
        total_mem += 4 * 4                                  # global FP32 constants
    return total_mem

standard_mem = get_model_memory(standard_linear)
qlora_mem = get_model_memory(qlora_linear)

print(f"\nStandard Linear: {standard_mem/1e6:.2f} MB")
print(f"QLoRA Linear: {qlora_mem/1e6:.2f} MB")
print(f"Memory reduction: {standard_mem/qlora_mem:.1f}x")

# Test forward pass
x = torch.randn(2, 10, 4096)
output = qlora_linear(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")

Training with QLoRA

python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def create_qlora_model(model_name, lora_rank=8, lora_alpha=16):
    """
    Create a model with QLoRA using HuggingFace's bitsandbytes.

    Args:
        model_name: Model name
        lora_rank: LoRA rank
        lora_alpha: LoRA alpha

    Returns:
        Model with QLoRA
    """
    # 4-bit quantization config
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",  # Use NormalFloat 4-bit
        bnb_4bit_use_double_quant=True,  # Double quantization
        bnb_4bit_compute_dtype=torch.bfloat16  # Compute in bfloat16
    )

    # Load model with quantization
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",  # Automatic device placement
        trust_remote_code=True
    )

    # Add LoRA adapters using PEFT library
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # Configure LoRA
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["q_proj", "v_proj"],  # Apply to Q and V
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    # Add LoRA
    model = get_peft_model(model, lora_config)

    # Print trainable parameters
    model.print_trainable_parameters()

    return model


# Example usage (commented out - requires actual model)
# model = create_qlora_model("meta-llama/Llama-2-7b-hf", lora_rank=16)

# Training would use standard PyTorch/HuggingFace training loop
# Only LoRA parameters are updated; quantized base stays frozen
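
The code above covers innovations 1 and 2; paged optimizers (innovation 3) are enabled when you configure training. Below is a minimal sketch assuming the HuggingFace Trainer API and a tokenized train_dataset you have prepared separately; the optim string selects bitsandbytes' paged AdamW, which pages optimizer states to CPU RAM when GPU memory spikes.

python
from transformers import Trainer, TrainingArguments

# Hypothetical training configuration; output_dir and the dataset are placeholders.
training_args = TrainingArguments(
    output_dir="./qlora-out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    num_train_epochs=1,
    bf16=True,                    # match bnb_4bit_compute_dtype
    optim="paged_adamw_32bit",    # paged optimizer (QLoRA innovation #3)
    logging_steps=10,
)

# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()  # updates only the LoRA adapters; the 4-bit base stays frozen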

QLoRA Training Considerations:

  1. Slower than FP16: Dequantization adds ~20-30% overhead
  2. Requires bitsandbytes and peft: install with pip install bitsandbytes peft
  3. GPU compatibility: bfloat16 compute requires Ampere or newer GPUs; older cards can fall back to a float16 compute dtype
  4. Slight accuracy loss: 4-bit causes minor degradation (<1% typically)
  5. Memory-compute tradeoff: Saves memory but slightly slower

Worth it for enabling fine-tuning that wouldn't otherwise fit!

QLoRA vs Standard LoRA

python
import pandas as pd

comparison = pd.DataFrame({
    'Aspect': [
        'Base weights',
        'LoRA adapters',
        'Memory (7B model)',
        'Memory (65B model)',
        'Training speed',
        'Accuracy',
        'GPU required (65B)',
        'Typical use case'
    ],
    'Standard LoRA': [
        'FP16/BF16',
        'FP16/BF16',
        '~14 GB',
        '~130 GB',
        'Fast',
        'Best',
        'Multiple A100 80GB',
        'Smaller models (&lt;13B)'
    ],
    'QLoRA': [
        'NF4 (4-bit)',
        'FP16/BF16',
        '~5 GB',
        '~33 GB',
        'Slower (~1.3x)',
        'Near-best (< 1% drop)',
        'Single 48 GB GPU (A6000)',
        'Large models (65B+)'
    ]
})

print(comparison.to_string(index=False))

Summary

QLoRA democratizes LLM fine-tuning through three innovations:

  1. NF4 Quantization: 4-bit quantization optimized for normal distributions
  2. Double Quantization: Quantize the quantization constants to save more memory
  3. Paged Optimizers: Page optimizer states to CPU RAM (via NVIDIA unified memory) to absorb transient memory spikes

Result: Fine-tune 65B models on consumer GPUs with minimal accuracy loss.

QLoRA opened fine-tuning to individual researchers and small companies, accelerating LLM research and applications.