LoRA: Low-Rank Adaptation
LoRA (Low-Rank Adaptation) has become the gold standard for parameter-efficient fine-tuning. It achieves results comparable to full fine-tuning while updating less than 1% of a model's parameters.
The Core Idea
LoRA is based on a simple insight: fine-tuning weight updates have low intrinsic rank.
Instead of updating the full weight matrix W, LoRA represents the update as a low-rank decomposition:
W' = W + ΔW
where ΔW = BA
W: d × d (frozen)
B: d × r (trainable)
A: r × d (trainable)
r ≪ d (e.g., r=8, d=768)
Why This Works:
During fine-tuning, the weight update ΔW often has low intrinsic rank: it can be well approximated by a low-rank matrix. Instead of learning d × d parameters, we learn d × r + r × d = 2dr parameters, where r ≪ d.
For a 768-dimensional layer with r=8:
- Full update: 768 × 768 = 589,824 parameters
- LoRA update: 768 × 8 + 8 × 768 = 12,288 parameters
- 48x reduction!
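The arithmetic is easy to check directly; a minimal sketch (same dimensions as the example above) that just counts parameters:
d, r = 768, 8
full_update = d * d            # full-rank update: 589,824
lora_update = d * r + r * d    # B (d×r) plus A (r×d): 12,288
print(full_update // lora_update)  # 48x reduction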
Mathematical Foundation
Weight Update Decomposition
For a pre-trained weight matrix W₀ ∈ ℝ^(d×k), the modified forward pass is:
h = W₀x + ΔWx = W₀x + BAx
where:
- W₀ is frozen
- B ∈ ℝ^(d×r) is trainable
- A ∈ ℝ^(r×k) is trainable
- r ≪ min(d, k)
Initialization
- A: Random initialization (Gaussian in the paper; the reference implementation uses Kaiming uniform, as in the code below)
- B: Zero initialization (so ΔW = BA = 0 at start)
This ensures the adapted model starts at the pre-trained state.
Scaling Factor
LoRA uses a scaling factor α/r to control the adaptation magnitude:
h = W₀x + (α/r)BAx
Typical: α = r (so scaling = 1) or α = 2r
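A few lines of PyTorch make the zero-initialization property concrete: with B = 0 the adapted forward pass reproduces the frozen layer exactly, regardless of α and r (the dimensions below are illustrative):
import torch
d, k, r, alpha = 16, 16, 4, 8
W0 = torch.randn(d, k)           # frozen pre-trained weight
A = torch.randn(r, k) * 0.01     # A: random init
B = torch.zeros(d, r)            # B: zero init, so BA = 0
x = torch.randn(k)
h_base = W0 @ x
h_lora = W0 @ x + (alpha / r) * (B @ (A @ x))
print(torch.allclose(h_base, h_lora))  # True: adaptation starts at the pre-trained model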
Complete Implementation
import torch
import torch.nn as nn
import math
class LoRALayer(nn.Module):
"""
LoRA: Low-Rank Adaptation layer.
Adds trainable low-rank decomposition to a frozen weight matrix.
"""
def __init__(
self,
in_features,
out_features,
rank=8,
alpha=16,
dropout=0.1
):
"""
Args:
in_features: Input dimension
out_features: Output dimension
rank: Rank of decomposition (r)
alpha: Scaling factor
dropout: Dropout probability
"""
super().__init__()
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank
# LoRA matrices
# A: random init
self.lora_A = nn.Parameter(torch.randn(rank, in_features))
# B: zero init
self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
# Dropout
self.dropout = nn.Dropout(p=dropout)
# Initialize A with Kaiming uniform
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
def forward(self, x):
"""
Compute LoRA update: (α/r) * B @ A @ x
Args:
x: Input tensor (..., in_features)
Returns:
LoRA update (..., out_features)
"""
# Apply dropout to input
x = self.dropout(x)
# Compute low-rank update: B @ (A @ x)
# This is more efficient than computing (B @ A) @ x
result = (x @ self.lora_A.T) # (..., rank)
result = (result @ self.lora_B.T) # (..., out_features)
# Apply scaling
result = result * self.scaling
return result
class LinearWithLoRA(nn.Module):
"""
Linear layer with LoRA adaptation.
Combines frozen pre-trained weights with trainable LoRA parameters.
"""
def __init__(
self,
in_features,
out_features,
rank=8,
alpha=16,
dropout=0.1,
bias=True
):
"""
Args:
in_features: Input dimension
out_features: Output dimension
rank: LoRA rank
alpha: LoRA scaling factor
dropout: Dropout probability
bias: Whether to include bias
"""
super().__init__()
# Frozen pre-trained linear layer
self.linear = nn.Linear(in_features, out_features, bias=bias)
# Freeze it
for param in self.linear.parameters():
param.requires_grad = False
# LoRA adaptation
self.lora = LoRALayer(in_features, out_features, rank, alpha, dropout)
        # Whether the LoRA update has been merged into the base weights
        self.merged = False
def forward(self, x):
"""
Forward pass: W₀x + BAx
Args:
x: Input tensor
Returns:
Output tensor
"""
if self.merged:
# LoRA has been merged into weights
return self.linear(x)
else:
# Separate LoRA computation
return self.linear(x) + self.lora(x)
def merge_lora(self):
"""
Merge LoRA weights into the base linear layer for inference.
After merging: W' = W + BA
No inference overhead!
"""
if not self.merged:
# Compute ΔW = (α/r) * B @ A
delta_w = (self.lora.lora_B @ self.lora.lora_A) * self.lora.scaling
# Add to base weights
self.linear.weight.data += delta_w
self.merged = True
print("LoRA weights merged!")
def unmerge_lora(self):
"""
Unmerge LoRA weights (useful for switching tasks).
"""
if self.merged:
delta_w = (self.lora.lora_B @ self.lora.lora_A) * self.lora.scaling
self.linear.weight.data -= delta_w
self.merged = False
print("LoRA weights unmerged!")
# Test LoRA layer
in_dim, out_dim = 768, 768
x = torch.randn(2, 10, in_dim) # (batch, seq, dim)
# Standard linear layer
standard_linear = nn.Linear(in_dim, out_dim)
standard_params = sum(p.numel() for p in standard_linear.parameters())
# LoRA linear layer
lora_linear = LinearWithLoRA(in_dim, out_dim, rank=8)
trainable_params = sum(p.numel() for p in lora_linear.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in lora_linear.parameters())
print(f"Standard Linear: {standard_params:,} parameters")
print(f"LoRA Linear: {trainable_params:,} trainable ({100*trainable_params/total_params:.2f}%)")
print(f"Reduction: {standard_params/trainable_params:.1f}x")
# Forward pass
output = lora_linear(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
# Test merging: give B non-zero values (as if trained) so the merge actually changes W,
# and switch to eval mode so dropout does not make the outputs stochastic
nn.init.normal_(lora_linear.lora.lora_B, std=0.02)
lora_linear.eval()
print("\nBefore merge:")
output_before = lora_linear(x)
lora_linear.merge_lora()
output_after = lora_linear(x)
print(f"Outputs equal after merge: {torch.allclose(output_before, output_after, atol=1e-5)}")
LoRA Key Parameters:
- Rank (r): Higher rank = more capacity but more parameters (see the sketch after this list)
  - Typical: r = 4, 8, 16, 32
  - Start with r=8 for most tasks
- Alpha (α): Controls update magnitude
  - Typical: α = r or α = 2r
  - Higher α = stronger adaptation
- Dropout: Prevents overfitting
  - Typical: 0.05 - 0.1
  - Higher for smaller datasets
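To make the rank trade-off concrete, a small sketch that counts LoRA parameters for the typical rank values above, using the 768-dimensional layer from earlier:
d = 768  # hidden dimension of the adapted layer
for r in (4, 8, 16, 32):
    lora_params = d * r + r * d          # A (r×d) plus B (d×r)
    reduction = (d * d) / lora_params    # versus a full d×d update
    print(f"r={r:2d}: {lora_params:,} trainable params ({reduction:.0f}x fewer than a full update)")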
Applying LoRA to Transformers
from transformers import GPT2LMHeadModel
from transformers.pytorch_utils import Conv1D  # GPT-2 implements c_attn/c_proj as Conv1D, not nn.Linear
class GPT2WithLoRA(nn.Module):
"""
GPT-2 with LoRA applied to attention layers.
Typically apply LoRA to:
- Query and Value projections (Q, V)
    - Or all four: Query, Key, Value, Output (Q, K, V, O)
    Note: GPT-2 fuses Q, K, and V into a single c_attn projection, so targeting
    c_attn adapts all three at once; c_proj covers the output projections.
    """
def __init__(
self,
model_name='gpt2',
lora_rank=8,
lora_alpha=16,
lora_dropout=0.1,
        target_modules=['c_attn', 'c_proj']  # fused Q/K/V plus the output projections (attention and MLP)
):
"""
Args:
model_name: Base GPT-2 model
lora_rank: LoRA rank
lora_alpha: LoRA alpha
lora_dropout: LoRA dropout
target_modules: Which modules to apply LoRA to
"""
super().__init__()
# Load base model
self.base_model = GPT2LMHeadModel.from_pretrained(model_name)
# Apply LoRA to target modules
self.apply_lora_to_model(
self.base_model,
target_modules,
lora_rank,
lora_alpha,
lora_dropout
)
self.print_trainable_parameters()
def apply_lora_to_model(self, model, target_modules, rank, alpha, dropout):
"""
Replace target Linear layers with LinearWithLoRA.
Args:
model: The model to modify
target_modules: Names of modules to replace
rank, alpha, dropout: LoRA parameters
"""
# Freeze all parameters first
for param in model.parameters():
param.requires_grad = False
        # Find and replace target modules.
        # Note: GPT-2 stores c_attn/c_proj as transformers Conv1D modules (weight shape
        # (in_features, out_features)), not nn.Linear, so both module types are handled.
        for name, module in list(model.named_modules()):
            # Check if this module should have LoRA
            if any(target in name for target in target_modules):
                if isinstance(module, (nn.Linear, Conv1D)):
                    is_conv1d = isinstance(module, Conv1D)
                    if is_conv1d:
                        in_features, out_features = module.weight.shape
                    else:
                        in_features, out_features = module.in_features, module.out_features
                    # Get parent module and attribute name
                    parent_name, _, attr_name = name.rpartition('.')
                    parent = model.get_submodule(parent_name) if parent_name else model
                    # Create LoRA layer
                    lora_layer = LinearWithLoRA(
                        in_features,
                        out_features,
                        rank=rank,
                        alpha=alpha,
                        dropout=dropout,
                        bias=module.bias is not None
                    )
                    # Copy pre-trained weights (Conv1D stores its weight transposed relative to nn.Linear)
                    weight = module.weight.data
                    if is_conv1d:
                        lora_layer.linear.weight.data = weight.t().contiguous()
                    else:
                        lora_layer.linear.weight.data = weight.clone()
                    if module.bias is not None:
                        lora_layer.linear.bias.data = module.bias.data.clone()
                    # Replace module
                    setattr(parent, attr_name, lora_layer)
                    print(f"Applied LoRA to: {name}")
def print_trainable_parameters(self):
"""Print trainable parameter statistics."""
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"\n{'='*50}")
print(f"Trainable parameters: {trainable:,}")
print(f"Total parameters: {total:,}")
print(f"Trainable: {100 * trainable / total:.3f}%")
print(f"{'='*50}\n")
def forward(self, input_ids, attention_mask=None, labels=None):
"""Forward pass through LoRA-adapted model."""
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
def merge_and_save(self, path):
"""
        Merge LoRA weights and save the model.
        Note: because target modules were replaced with LinearWithLoRA, the saved
        state dict uses the wrapper's parameter names (e.g. ...c_attn.linear.weight)
        rather than the original GPT-2 layout.
Args:
path: Path to save merged model
"""
# Merge all LoRA layers
for module in self.modules():
if isinstance(module, LinearWithLoRA):
module.merge_lora()
# Save merged model
self.base_model.save_pretrained(path)
print(f"Merged model saved to {path}")
# Example: Create GPT-2 with LoRA
print("Creating GPT-2 with LoRA...")
model = GPT2WithLoRA(
model_name='gpt2',
lora_rank=8,
lora_alpha=16,
lora_dropout=0.1,
target_modules=['c_attn', 'c_proj']
)
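As a quick sanity check after wrapping, the only tensors left trainable should be the LoRA matrices (the parameter names below assume the LoRALayer/LinearWithLoRA classes defined earlier):
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable_names)} trainable tensors")
print(trainable_names[:2])  # e.g. ...attn.c_attn.lora.lora_A, ...attn.c_attn.lora.lora_B
assert all('lora_A' in n or 'lora_B' in n for n in trainable_names)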
Training with LoRA
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, get_linear_schedule_with_warmup
class LoRATrainer:
"""
Trainer for LoRA fine-tuning.
"""
def __init__(
self,
model,
tokenizer,
learning_rate=1e-4,
weight_decay=0.01,
warmup_ratio=0.1
):
"""
Args:
model: Model with LoRA layers
tokenizer: Tokenizer
learning_rate: Learning rate
weight_decay: Weight decay
warmup_ratio: Warmup ratio
"""
        self.model = model
        self.tokenizer = tokenizer
        self.warmup_ratio = warmup_ratio
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
# Only optimize LoRA parameters
lora_params = [p for p in model.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(
lora_params,
lr=learning_rate,
weight_decay=weight_decay
)
print(f"Optimizing {len(lora_params)} parameter groups")
def train(
self,
train_loader,
val_loader,
epochs=3,
gradient_accumulation_steps=1,
max_grad_norm=1.0
):
"""
Train the model with LoRA.
Args:
train_loader: Training data loader
val_loader: Validation data loader
epochs: Number of epochs
gradient_accumulation_steps: Gradient accumulation
max_grad_norm: Gradient clipping
"""
total_steps = len(train_loader) * epochs // gradient_accumulation_steps
# Learning rate scheduler
scheduler = get_linear_schedule_with_warmup(
self.optimizer,
            num_warmup_steps=int(self.warmup_ratio * total_steps),
num_training_steps=total_steps
)
best_val_loss = float('inf')
global_step = 0
for epoch in range(epochs):
# Training
self.model.train()
train_loss = 0
self.optimizer.zero_grad()
for step, batch in enumerate(train_loader):
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
# Forward pass
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss / gradient_accumulation_steps
loss.backward()
train_loss += loss.item() * gradient_accumulation_steps
# Update weights
if (step + 1) % gradient_accumulation_steps == 0:
# Clip gradients
torch.nn.utils.clip_grad_norm_(
[p for p in self.model.parameters() if p.requires_grad],
max_grad_norm
)
self.optimizer.step()
scheduler.step()
self.optimizer.zero_grad()
global_step += 1
if global_step % 100 == 0:
avg_loss = train_loss / (step + 1)
lr = scheduler.get_last_lr()[0]
print(f"Step {global_step}: Loss={avg_loss:.4f}, LR={lr:.2e}")
avg_train_loss = train_loss / len(train_loader)
# Validation
val_loss = self.validate(val_loader)
print(f"\nEpoch {epoch + 1}/{epochs}")
print(f" Train Loss: {avg_train_loss:.4f}")
print(f" Val Loss: {val_loss:.4f}")
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
self.save_checkpoint('best_lora_model')
print(" Saved best model!")
def validate(self, val_loader):
"""Validate the model."""
self.model.eval()
total_loss = 0
with torch.no_grad():
for batch in val_loader:
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
total_loss += outputs.loss.item()
return total_loss / len(val_loader)
def save_checkpoint(self, path):
"""Save LoRA checkpoint."""
        # Save only LoRA parameters (as detached CPU copies, keeping the checkpoint small and portable)
        lora_state = {
            name: param.detach().cpu()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
torch.save({
'lora_state_dict': lora_state,
'optimizer_state_dict': self.optimizer.state_dict()
}, f'{path}.pt')
print(f"Checkpoint saved to {path}.pt")
# Example training setup
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# trainer = LoRATrainer(model, tokenizer, learning_rate=1e-4)
# trainer.train(train_loader, val_loader, epochs=3)
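The trainer expects each batch to be a dict with input_ids, attention_mask, and labels. One way to produce such batches from raw text is sketched below; the TextDataset name and the labels-equal-inputs convention for causal LM training are illustrative, not part of the trainer itself:
from torch.utils.data import Dataset
class TextDataset(Dataset):
    """Tokenizes raw strings into fixed-length examples for causal LM fine-tuning."""
    def __init__(self, texts, tokenizer, max_length=128):
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding='max_length',
            return_tensors='pt'
        )
    def __len__(self):
        return self.encodings['input_ids'].size(0)
    def __getitem__(self, idx):
        input_ids = self.encodings['input_ids'][idx]
        attention_mask = self.encodings['attention_mask'][idx]
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padded positions in the loss
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}
# train_loader = DataLoader(TextDataset(train_texts, tokenizer), batch_size=8, shuffle=True)
# val_loader = DataLoader(TextDataset(val_texts, tokenizer), batch_size=8)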
LoRA Training Tips:
- Higher learning rate: LoRA can use higher LR than full fine-tuning (1e-4 vs 1e-5)
- Rank selection: Start with r=8, increase if underfitting
- Target modules: Apply to attention (Q, V) or all projections (Q, K, V, O)
- Regularization: LoRA inherently regularizes, but dropout still helps
- Batch size: Can use larger batches due to memory savings
Inference and Deployment
class LoRAInference:
"""
LoRA inference utilities.
"""
@staticmethod
def load_lora_checkpoint(base_model, checkpoint_path):
"""
Load LoRA checkpoint into model.
Args:
base_model: Base model with LoRA layers
checkpoint_path: Path to LoRA checkpoint
"""
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
lora_state = checkpoint['lora_state_dict']
# Load only LoRA parameters
model_dict = base_model.state_dict()
model_dict.update(lora_state)
base_model.load_state_dict(model_dict)
print("LoRA checkpoint loaded!")
@staticmethod
def merge_for_deployment(model):
"""
Merge LoRA weights for deployment.
After merging, no inference overhead!
"""
for module in model.modules():
if isinstance(module, LinearWithLoRA):
module.merge_lora()
print("All LoRA layers merged for deployment!")
@staticmethod
def generate_text(model, tokenizer, prompt, max_length=100):
"""
Generate text with LoRA model.
Args:
model: LoRA model
tokenizer: Tokenizer
prompt: Input prompt
max_length: Maximum length
Returns:
Generated text
"""
model.eval()
device = next(model.parameters()).device
# Encode prompt
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
# Generate
with torch.no_grad():
output_ids = model.base_model.generate(
input_ids,
max_length=max_length,
num_return_sequences=1,
temperature=0.7,
top_p=0.9,
do_sample=True
)
# Decode
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return generated_text
# Example inference
# model = GPT2WithLoRA('gpt2', lora_rank=8)
# LoRAInference.load_lora_checkpoint(model, 'best_lora_model.pt')
# LoRAInference.merge_for_deployment(model)
# text = LoRAInference.generate_text(model, tokenizer, "Once upon a time")
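Because the base weights stay frozen, a common deployment pattern is to keep one base model in memory and swap small per-task LoRA checkpoints; a sketch using the helpers defined above (the checkpoint paths are placeholders):
def switch_task(model, checkpoint_path):
    """Load a different LoRA adapter into the same frozen base model."""
    for module in model.modules():
        if isinstance(module, LinearWithLoRA):
            module.unmerge_lora()  # undo any previous merge so the base weights are pristine
    LoRAInference.load_lora_checkpoint(model, checkpoint_path)
# switch_task(model, 'summarization_lora.pt')
# summary = LoRAInference.generate_text(model, tokenizer, "Summarize: ...")
# switch_task(model, 'dialogue_lora.pt')
# reply = LoRAInference.generate_text(model, tokenizer, "User: Hello!")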
Summary
LoRA achieves remarkable parameter efficiency through low-rank decomposition:
- Core idea: Weight updates ΔW = BA where r ≪ d
- Benefits: 10-100x parameter reduction, no inference overhead after merging
- Application: Attention layers (Q, V or Q, K, V, O)
- Training: Higher learning rates, faster convergence
- Deployment: Merge weights or swap LoRA modules for multi-task
LoRA has become the standard for LLM fine-tuning due to its simplicity and effectiveness.