LoRA: Low-Rank Adaptation

A deep dive into LoRA (Low-Rank Adaptation), one of the most popular parameter-efficient fine-tuning methods, with a from-scratch implementation covering training and inference.

LoRA (Low-Rank Adaptation) has become one of the most widely used methods for parameter-efficient fine-tuning. It often achieves results comparable to full fine-tuning while training only a small fraction of the parameters, frequently less than 1%.

The Core Idea

LoRA rests on a simple hypothesis: the weight updates learned during fine-tuning have low intrinsic rank.

Instead of updating the full weight matrix W, LoRA represents the update as a low-rank decomposition:

W' = W + ΔW
where ΔW = BA

W: d × d (frozen)
B: d × r (trainable)
A: r × d (trainable)
r ≪ d (e.g., r=8, d=768)

Why This Works:

During fine-tuning, the weight update ΔW can often be well approximated by a low-rank matrix. So instead of learning all d × d entries of ΔW, we learn d × r + r × d = 2dr parameters, with r ≪ d.

For a 768-dimensional layer with r=8:

  • Full update: 768 × 768 = 589,824 parameters
  • LoRA update: 768 × 8 + 8 × 768 = 12,288 parameters
  • 48x reduction!
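The same arithmetic holds for any d × k weight: a dense update costs dk parameters, while a rank-r update costs r(d + k). The small sketch below (the helper names are ours, for illustration) reproduces the 768 × 768 numbers above and shows how the reduction shrinks as r grows:

```python
def full_param_count(d: int, k: int) -> int:
    """Parameters in a dense d x k update ΔW."""
    return d * k


def lora_param_count(d: int, k: int, r: int) -> int:
    """Parameters in B (d x r) plus A (r x k)."""
    return d * r + r * k


d = k = 768
for r in (4, 8, 16, 32):
    full = full_param_count(d, k)
    lora = lora_param_count(d, k, r)
    print(f"r={r:>2}: {lora:>6,} vs {full:,} params ({full / lora:.0f}x fewer)")
```

For d = k = 768, this gives reductions of 96x, 48x, 24x and 12x for r = 4, 8, 16 and 32. Halving the rank doubles the savings.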

Mathematical Foundation

Weight Update Decomposition

For a pre-trained weight matrix W₀ ∈ ℝ^(d×k), the modified forward pass is:

h = W₀x + ΔWx = W₀x + BAx

where:
- W₀ is frozen
- B ∈ ℝ^(d×r) is trainable
- A ∈ ℝ^(r×k) is trainable
- r << min(d, k)
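To make the two forms concrete, here is a minimal PyTorch check using small, made-up sizes (d = 64, k = 32, r = 4 are assumptions chosen for illustration). It confirms two things: the factored update B(Ax) gives the same output as folding ΔW = BA into the weight, and ΔW never has rank above r.

```python
import torch

torch.manual_seed(0)
d, k, r = 64, 32, 4          # illustrative sizes (assumed, not from the text)
W0 = torch.randn(d, k)       # frozen pre-trained weight, W₀ ∈ R^(d×k)
B = torch.randn(d, r)        # trainable, B ∈ R^(d×r)
A = torch.randn(r, k)        # trainable, A ∈ R^(r×k)
x = torch.randn(k)           # a single input vector

# Factored form used during training: never materializes the d×k matrix BA
h_factored = W0 @ x + B @ (A @ x)

# Merged form: ΔW = BA folded into the weight
delta_W = B @ A
h_merged = (W0 + delta_W) @ x

print(torch.allclose(h_factored, h_merged, atol=1e-4))  # same output either way
print(torch.linalg.matrix_rank(delta_W).item())         # rank of ΔW is at most r
```

During training the factored form is the cheaper one, because it only ever multiplies by the thin matrices A and B. The merged form is what you deploy.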

Initialization

  • A: Random initialization (Gaussian in the original paper; the implementation below uses Kaiming-uniform)
  • B: Zero initialization (so ΔW = BA = 0 at start)

This ensures the adapted model starts at the pre-trained state.
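Here is a quick PyTorch check of why this initialization works. The sizes and the squared-output loss are arbitrary choices for illustration. With B = 0, the adapted layer's output equals W₀x exactly. On the first backward pass, only B receives a non-zero gradient, because the gradient with respect to A is multiplied by B.

```python
import torch

torch.manual_seed(0)
d, k, r = 16, 8, 2                         # illustrative sizes (assumed)
W0 = torch.randn(d, k)                     # frozen pre-trained weight W₀
A = torch.randn(r, k, requires_grad=True)  # random Gaussian init
B = torch.zeros(d, r, requires_grad=True)  # zero init, so BA = 0
x = torch.randn(5, k)                      # batch of 5 input rows

# h = W₀x + BAx, written for row-vector inputs
h = x @ W0.T + (x @ A.T) @ B.T

# At step 0 the adapted layer reproduces the pre-trained layer
print(torch.allclose(h, x @ W0.T))

# Backpropagate an arbitrary loss. The gradient reaching A passes through B,
# which is still zero, so A.grad is exactly zero while B.grad is not.
h.pow(2).sum().backward()
print(A.grad.abs().max().item(), B.grad.abs().max().item())
```

After the first optimizer step, B is non-zero, so A starts receiving gradient from the second step onward.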

Scaling Factor

LoRA uses a scaling factor α/r to control the adaptation magnitude:

h = W₀x + (α/r)BAx

Typical: α = r (scaling = 1) or α = 2r (scaling = 2)
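The effect of the scaling factor can be seen directly. In this sketch, `lora_forward` is our own helper, the sizes are assumed, and B is random to stand in for an already-trained matrix. For the same A and B, moving from α = r to α = 2r exactly doubles the update added to W₀x:

```python
import torch


def lora_forward(x, W0, A, B, alpha, r):
    """h = W₀x + (α/r)·BAx for a batch of row-vector inputs x."""
    scaling = alpha / r
    return x @ W0.T + scaling * ((x @ A.T) @ B.T)


torch.manual_seed(0)
d, k, r = 16, 8, 4                  # illustrative sizes (assumed)
W0 = torch.randn(d, k)
A = torch.randn(r, k)
B = torch.randn(d, r)               # stand-in for an already-trained B
x = torch.randn(3, k)
base = x @ W0.T

for alpha in (r, 2 * r):            # typical settings: α = r and α = 2r
    update = lora_forward(x, W0, A, B, alpha, r) - base
    print(f"alpha={alpha}: scaling={alpha / r}, update norm={update.norm().item():.3f}")
```

If α is held fixed while r increases, the scaling shrinks. This is why α is usually set relative to r, as in the typical settings above.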

Complete Implementation

python
import torch
import torch.nn as nn
import math

class LoRALayer(nn.Module):
    """
    LoRA: Low-Rank Adaptation layer.

    Adds trainable low-rank decomposition to a frozen weight matrix.
    """

    def __init__(
        self,
        in_features,
        out_features,
        rank=8,
        alpha=16,
        dropout=0.1
    ):
        """
        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: Rank of decomposition (r)
            alpha: Scaling factor
            dropout: Dropout probability
        """
        super().__init__()

        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

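        # LoRA matrices: A is (rank x in_features), B is (out_features x rank)
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        # B: zero init, so ΔW = BA = 0 and training starts from the pre-trained layer
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))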

        # Dropout
        self.dropout = nn.Dropout(p=dropout)

        # Initialize A with Kaiming uniform
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x):
        """
        Compute LoRA update: (α/r) * B @ A @ x

        Args:
            x: Input tensor (..., in_features)

        Returns:
            LoRA update (..., out_features)
        """
        # Apply dropout to input
        x = self.dropout(x)

        # Compute low-rank update: B @ (A @ x)
        # This is more efficient than computing (B @ A) @ x
        result = (x @ self.lora_A.T)  # (..., rank)
        result = (result @ self.lora_B.T)  # (..., out_features)

        # Apply scaling
        result = result * self.scaling

        return result


class LinearWithLoRA(nn.Module):
    """
    Linear layer with LoRA adaptation.

    Combines frozen pre-trained weights with trainable LoRA parameters.
    """

    def __init__(
        self,
        in_features,
        out_features,
        rank=8,
        alpha=16,
        dropout=0.1,
        bias=True
    ):
        """
        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: LoRA rank
            alpha: LoRA scaling factor
            dropout: Dropout probability
            bias: Whether to include bias
        """
        super().__init__()

        # Frozen pre-trained linear layer
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Freeze it
        for param in self.linear.parameters():
            param.requires_grad = False

        # LoRA adaptation
        self.lora = LoRALayer(in_features, out_features, rank, alpha, dropout)

        # Whether the LoRA update has been merged into self.linear.weight
        self.merged = False

    def forward(self, x):
        """
        Forward pass: W₀x + BAx

        Args:
            x: Input tensor

        Returns:
            Output tensor
        """
        if self.merged:
            # LoRA has been merged into weights
            return self.linear(x)
        else:
            # Separate LoRA computation
            return self.linear(x) + self.lora(x)

    def merge_lora(self):
        """
        Merge LoRA weights into the base linear layer for inference.

        After merging: W' = W + BA
        No inference overhead!
        """
        if not self.merged:
            # Compute ΔW = (α/r) * B @ A
            delta_w = (self.lora.lora_B @ self.lora.lora_A) * self.lora.scaling

            # Add to base weights
            self.linear.weight.data += delta_w

            self.merged = True
            print("LoRA weights merged!")

    def unmerge_lora(self):
        """
        Unmerge LoRA weights (useful for switching tasks).
        """
        if self.merged:
            delta_w = (self.lora.lora_B @ self.lora.lora_A) * self.lora.scaling
            self.linear.weight.data -= delta_w
            self.merged = False
            print("LoRA weights unmerged!")


# Test LoRA layer
in_dim, out_dim = 768, 768
x = torch.randn(2, 10, in_dim)  # (batch, seq, dim)

# Standard linear layer
standard_linear = nn.Linear(in_dim, out_dim)
standard_params = sum(p.numel() for p in standard_linear.parameters())

# LoRA linear layer
lora_linear = LinearWithLoRA(in_dim, out_dim, rank=8)
trainable_params = sum(p.numel() for p in lora_linear.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in lora_linear.parameters())

print(f"Standard Linear: {standard_params:,} parameters")
print(f"LoRA Linear: {trainable_params:,} trainable ({100*trainable_params/total_params:.2f}%)")
print(f"Reduction: {standard_params/trainable_params:.1f}x")

# Forward pass
output = lora_linear(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")

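# Test merging. eval() disables dropout so both forward passes are deterministic.
lora_linear.eval()
# B starts at zero (ΔW = 0), which would make this check trivially pass,
# so give it non-zero values to stand in for a trained adapter
with torch.no_grad():
    lora_linear.lora.lora_B.normal_(std=0.02)

output_before = lora_linear(x)
lora_linear.merge_lora()
output_after = lora_linear(x)

print(f"Outputs equal after merge: {torch.allclose(output_before, output_after, atol=1e-5)}")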

LoRA Key Parameters:

  1. Rank (r): Higher rank = more capacity but more parameters

    • Typical: r = 4, 8, 16, 32
    • Start with r=8 for most tasks
  2. Alpha (α): Controls update magnitude

    • Typical: α = r or α = 2r
    • Higher α = stronger adaptation
  3. Dropout: Applied to the input of the LoRA branch to reduce overfitting

    • Typical: 0.05 - 0.1
    • Higher for smaller datasets
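To see what the rank trade-off means on a real model, the sketch below counts LoRA parameters for GPT-2 small. It uses GPT-2 small's standard configuration: 12 layers, hidden size 768, MLP size 3072, a 50,257-token vocabulary and 1,024 positions. It also uses the same targets as the GPT-2 example in the next section, namely the fused `c_attn` plus both `c_proj` projections. The helper names are ours, and the count is pure arithmetic; it does not load the model.

```python
def gpt2_small_total_params(vocab=50257, n_positions=1024, d=768, n_layer=12, d_ff=3072):
    """Total parameters of GPT-2 small (input and output embeddings are tied)."""
    embeddings = vocab * d + n_positions * d       # wte + wpe
    per_block = (
        2 * (2 * d)                                # ln_1, ln_2 (weight + bias)
        + d * 3 * d + 3 * d                        # attn.c_attn (fused Q, K, V)
        + d * d + d                                # attn.c_proj
        + d * d_ff + d_ff                          # mlp.c_fc
        + d_ff * d + d                             # mlp.c_proj
    )
    return embeddings + n_layer * per_block + 2 * d  # + final ln_f


def gpt2_small_lora_params(r, d=768, n_layer=12, d_ff=3072):
    """LoRA parameters, r * (in + out) per adapted layer, for c_attn and both c_proj."""
    per_block = r * (d + 3 * d) + r * (d + d) + r * (d_ff + d)
    return n_layer * per_block


base = gpt2_small_total_params()
for r in (4, 8, 16, 32):
    lora = gpt2_small_lora_params(r)
    share = 100 * lora / (base + lora)
    print(f"r={r:>2}: {lora:>9,} trainable parameters ({share:.2f}% of all parameters)")
```

At r = 8, this gives 811,008 trainable parameters, about 0.65% of the model. At r = 16 and above, the share passes 1%.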

Applying LoRA to Transformers

python
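import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel
# GPT-2 implements its projections with transformers' Conv1D, whose weight is
# stored as (in_features, out_features), rather than with nn.Linear
from transformers.pytorch_utils import Conv1D

class GPT2WithLoRA(nn.Module):
    """
    GPT-2 with LoRA applied to attention layers.

    Typically apply LoRA to:
    - Query and Value projections (Q, V)
    - Or all four: Query, Key, Value, Output (Q, K, V, O)

    GPT-2 fuses Q, K and V into one projection (c_attn), so adapting c_attn
    adapts all three together.
    """

    def __init__(
        self,
        model_name='gpt2',
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
        # Fused Q,K,V projection plus output projections; in GPT-2, 'c_proj'
        # names both the attention output (attn.c_proj) and the MLP output (mlp.c_proj)
        target_modules=('c_attn', 'c_proj')
    ):
        """
        Args:
            model_name: Base GPT-2 model
            lora_rank: LoRA rank
            lora_alpha: LoRA alpha
            lora_dropout: LoRA dropout
            target_modules: Names (last path component) of modules to apply LoRA to
        """
        super().__init__()

        # Load base model
        self.base_model = GPT2LMHeadModel.from_pretrained(model_name)

        # Apply LoRA to target modules
        self.apply_lora_to_model(
            self.base_model,
            target_modules,
            lora_rank,
            lora_alpha,
            lora_dropout
        )

        self.print_trainable_parameters()

    def apply_lora_to_model(self, model, target_modules, rank, alpha, dropout):
        """
        Replace target nn.Linear / Conv1D layers with LinearWithLoRA.

        Args:
            model: The model to modify
            target_modules: Names of modules to replace
            rank, alpha, dropout: LoRA parameters
        """
        # Freeze all parameters first
        for param in model.parameters():
            param.requires_grad = False

        # Collect targets before modifying the module tree. Matching on the last
        # name component compares whole module names, not arbitrary substrings.
        targets = [
            (name, module) for name, module in model.named_modules()
            if name.split('.')[-1] in target_modules
            and isinstance(module, (nn.Linear, Conv1D))
        ]

        for name, module in targets:
            # Get parent module and attribute name
            parent_name, _, attr_name = name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model

            if isinstance(module, Conv1D):
                # Conv1D weight is (in_features, out_features): transpose it
                in_features, out_features = module.weight.shape
                weight = module.weight.data.T
            else:
                in_features, out_features = module.in_features, module.out_features
                weight = module.weight.data

            # Create LoRA layer
            lora_layer = LinearWithLoRA(
                in_features,
                out_features,
                rank=rank,
                alpha=alpha,
                dropout=dropout,
                bias=module.bias is not None
            )

            # Copy pre-trained weights into the frozen base layer
            with torch.no_grad():
                lora_layer.linear.weight.copy_(weight)
                if module.bias is not None:
                    lora_layer.linear.bias.copy_(module.bias.data)

            # Replace module
            setattr(parent, attr_name, lora_layer)

            print(f"Applied LoRA to: {name}")

    def print_trainable_parameters(self):
        """Print trainable parameter statistics."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())

        print(f"\n{'='*50}")
        print(f"Trainable parameters: {trainable:,}")
        print(f"Total parameters: {total:,}")
        print(f"Trainable: {100 * trainable / total:.3f}%")
        print(f"{'='*50}\n")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Forward pass through LoRA-adapted model."""
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

    def merge_and_save(self, path):
        """
        Merge LoRA weights and save the model as a standard GPT-2 checkpoint.

        This converts the model in place: every LinearWithLoRA is replaced by
        a plain Conv1D holding the merged weight, so the saved parameter names
        and shapes match GPT-2's and the result loads with from_pretrained.
        Assumes LoRA was applied only to GPT-2's Conv1D projections (which all
        have a bias).

        Args:
            path: Directory to save the merged model to
        """
        lora_layers = [
            (name, module) for name, module in self.base_model.named_modules()
            if isinstance(module, LinearWithLoRA)
        ]

        for name, module in lora_layers:
            # W' = W + (α/r)BA
            module.merge_lora()

            # Rebuild a Conv1D(nf=out_features, nx=in_features) with the merged weight
            parent_name, _, attr_name = name.rpartition('.')
            parent = self.base_model.get_submodule(parent_name)
            conv = Conv1D(module.linear.out_features, module.linear.in_features)
            conv = conv.to(module.linear.weight.device)
            with torch.no_grad():
                conv.weight.copy_(module.linear.weight.T)
                conv.bias.copy_(module.linear.bias)
            setattr(parent, attr_name, conv)

        # Save merged model
        self.base_model.save_pretrained(path)
        print(f"Merged model saved to {path}")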


# Example: Create GPT-2 with LoRA
print("Creating GPT-2 with LoRA...")
model = GPT2WithLoRA(
    model_name='gpt2',
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=['c_attn', 'c_proj']
)
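Name matching deserves a check before training. In GPT-2, the attention block and the MLP both contain a module called `c_proj`. A target list of `['c_attn', 'c_proj']` therefore adapts the MLP output projection as well as the attention projections. The sketch below writes out GPT-2 small's projection names by hand, as an assumed stand-in for `model.named_modules()` so that nothing is downloaded. It then compares last-component matching against a stricter suffix match that keeps only attention layers. `select_targets` is a hypothetical helper.

```python
# GPT-2 small projection names (12 blocks), written out by hand so the
# example runs without downloading the model
names = []
for i in range(12):
    prefix = f"transformer.h.{i}"
    names += [f"{prefix}.attn.c_attn", f"{prefix}.attn.c_proj",
              f"{prefix}.mlp.c_fc", f"{prefix}.mlp.c_proj"]
names.append("lm_head")


def select_targets(names, target_modules):
    """Keep names whose last component is one of target_modules."""
    return [n for n in names if n.split(".")[-1] in target_modules]


both = select_targets(names, ["c_attn", "c_proj"])
# Stricter: only the attention projections
attn_only = [n for n in names if n.endswith(("attn.c_attn", "attn.c_proj"))]

print(len(both), len(attn_only))
print(both[:4])
```

With both `c_proj` layers selected, 36 modules are adapted (3 per block). Restricting to attention gives 24. To adapt only attention, match on the longer suffix `attn.c_proj`.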

Training with LoRA

python
import torch
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, get_linear_schedule_with_warmup

class LoRATrainer:
    """
    Trainer for LoRA fine-tuning.
    """

    def __init__(
        self,
        model,
        tokenizer,
        learning_rate=1e-4,
        weight_decay=0.01,
        warmup_ratio=0.1
    ):
        """
        Args:
            model: Model with LoRA layers
            tokenizer: Tokenizer
            learning_rate: Learning rate
            weight_decay: Weight decay
            warmup_ratio: Warmup ratio
        """
        self.model = model
        self.tokenizer = tokenizer
        self.warmup_ratio = warmup_ratio
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Only optimize LoRA parameters (the base model is frozen)
        lora_params = [p for p in model.parameters() if p.requires_grad]

        self.optimizer = torch.optim.AdamW(
            lora_params,
            lr=learning_rate,
            weight_decay=weight_decay
        )

        num_trainable = sum(p.numel() for p in lora_params)
        print(f"Optimizing {len(lora_params)} LoRA tensors ({num_trainable:,} parameters)")

    def train(
        self,
        train_loader,
        val_loader,
        epochs=3,
        gradient_accumulation_steps=1,
        max_grad_norm=1.0
    ):
        """
        Train the model with LoRA.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            epochs: Number of epochs
            gradient_accumulation_steps: Micro-batches per optimizer step
            max_grad_norm: Max gradient norm for clipping
        """
        # Optimizer steps in total; leftover micro-batches at the end of an
        # epoch that don't fill an accumulation window are discarded
        total_steps = (len(train_loader) // gradient_accumulation_steps) * epochs

        # Learning rate scheduler: linear warmup for warmup_ratio of training, then linear decay
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(self.warmup_ratio * total_steps),
            num_training_steps=total_steps
        )

        best_val_loss = float('inf')
        global_step = 0

        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0
            self.optimizer.zero_grad()

            for step, batch in enumerate(train_loader):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss / gradient_accumulation_steps
                loss.backward()

                train_loss += loss.item() * gradient_accumulation_steps

                # Update weights
                if (step + 1) % gradient_accumulation_steps == 0:
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(
                        [p for p in self.model.parameters() if p.requires_grad],
                        max_grad_norm
                    )

                    self.optimizer.step()
                    scheduler.step()
                    self.optimizer.zero_grad()

                    global_step += 1

                    if global_step % 100 == 0:
                        avg_loss = train_loss / (step + 1)
                        lr = scheduler.get_last_lr()[0]
                        print(f"Step {global_step}: Loss={avg_loss:.4f}, LR={lr:.2e}")

            avg_train_loss = train_loss / len(train_loader)

            # Validation
            val_loss = self.validate(val_loader)

            print(f"\nEpoch {epoch + 1}/{epochs}")
            print(f"  Train Loss: {avg_train_loss:.4f}")
            print(f"  Val Loss: {val_loss:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint('best_lora_model')
                print("  Saved best model!")

    def validate(self, val_loader):
        """Validate the model."""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                total_loss += outputs.loss.item()

        return total_loss / len(val_loader)

    def save_checkpoint(self, path):
        """Save LoRA checkpoint."""
        # Save only LoRA parameters (detached CPU copies)
        lora_state = {
            name: param.detach().cpu() for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        torch.save({
            'lora_state_dict': lora_state,
            'optimizer_state_dict': self.optimizer.state_dict()
        }, f'{path}.pt')

        print(f"Checkpoint saved to {path}.pt")


# Example training setup
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# trainer = LoRATrainer(model, tokenizer, learning_rate=1e-4)
# trainer.train(train_loader, val_loader, epochs=3)

LoRA Training Tips:

  1. Higher learning rate: LoRA typically tolerates a higher learning rate than full fine-tuning (e.g., 1e-4 vs 1e-5)
  2. Rank selection: Start with r=8, increase if underfitting
  3. Target modules: Apply to attention (Q, V) or all projections (Q, K, V, O)
  4. Regularization: The low-rank constraint already limits how far the model can drift, but dropout still helps
  5. Batch size: Gradients and optimizer state exist only for the LoRA parameters, so the freed memory can go toward larger batches
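The memory saving behind tip 5 is easy to estimate. The sketch below assumes fp32 training with AdamW, which keeps one gradient and two moment buffers per trainable parameter at 4 bytes each. It also assumes the GPT-2 small figures: 124,439,808 parameters in the base model, and 811,008 LoRA parameters at r = 8 with `c_attn` and both `c_proj` layers targeted. `training_state_bytes` is our own helper.

```python
BYTES_FP32 = 4
# Per trainable parameter: gradient + AdamW first and second moments
TRAINING_STATE_COPIES = 3


def training_state_bytes(trainable_params: int) -> int:
    """Gradient and optimizer-state memory; excludes weights and activations."""
    return trainable_params * TRAINING_STATE_COPIES * BYTES_FP32


full_ft = training_state_bytes(124_439_808)   # every GPT-2 small weight trainable
lora_ft = training_state_bytes(811_008)       # LoRA r=8 on c_attn + both c_proj

print(f"Full fine-tuning: {full_ft / 1e9:.2f} GB")
print(f"LoRA (r=8):       {lora_ft / 1e6:.2f} MB")
```

This comes to roughly 1.49 GB versus under 10 MB. The frozen weights still have to be stored, and activation memory grows with batch size either way, so the freed memory translates into larger batches only up to a point.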

Inference and Deployment

python
class LoRAInference:
    """
    LoRA inference utilities.
    """

    @staticmethod
    def load_lora_checkpoint(base_model, checkpoint_path):
        """
        Load LoRA checkpoint into model.

        Args:
            base_model: Base model with LoRA layers
            checkpoint_path: Path to LoRA checkpoint
        """
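        # map_location='cpu' lets a checkpoint saved on GPU load anywhere;
        # load_state_dict then copies the values onto the model's device
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        lora_state = checkpoint['lora_state_dict']

        # Load only LoRA parameters; strict=False leaves the frozen base weights as they are
        result = base_model.load_state_dict(lora_state, strict=False)
        if result.unexpected_keys:
            raise KeyError(f"Unexpected keys in checkpoint: {result.unexpected_keys}")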

        print("LoRA checkpoint loaded!")

    @staticmethod
    def merge_for_deployment(model):
        """
        Merge LoRA weights for deployment.

        After merging, no inference overhead!
        """
        for module in model.modules():
            if isinstance(module, LinearWithLoRA):
                module.merge_lora()

        print("All LoRA layers merged for deployment!")

    @staticmethod
    def generate_text(model, tokenizer, prompt, max_length=100):
        """
        Generate text with LoRA model.

        Args:
            model: LoRA model
            tokenizer: Tokenizer
            prompt: Input prompt
            max_length: Maximum total length in tokens (prompt included)

        Returns:
            Generated text
        """
        model.eval()
        device = next(model.parameters()).device

        # Encode prompt (the tokenizer also returns the attention mask)
        inputs = tokenizer(prompt, return_tensors='pt').to(device)

        # Generate (GPT-2 has no pad token, so EOS is used for padding)
        with torch.no_grad():
            output_ids = model.base_model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode
        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return generated_text


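# Example inference (the LoRA settings must match those used in training)
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')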
# model = GPT2WithLoRA('gpt2', lora_rank=8)
# LoRAInference.load_lora_checkpoint(model, 'best_lora_model.pt')
# LoRAInference.merge_for_deployment(model)
# text = LoRAInference.generate_text(model, tokenizer, "Once upon a time")
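The `unmerge_lora` method makes it possible to serve several tasks from one copy of the base weights: merge one task's adapter, serve it, unmerge it, then merge another. The minimal sketch below uses plain tensors. The two task adapters are random stand-ins, and `merge`/`unmerge` are hypothetical helpers that mirror `merge_lora`/`unmerge_lora` above. It shows that each merged weight matches the unmerged LoRA forward pass, and that the base weight is recovered up to floating-point rounding.

```python
import torch

torch.manual_seed(0)
d, k, r, alpha = 32, 32, 4, 8          # illustrative sizes (assumed)
scaling = alpha / r
W = torch.randn(d, k)                  # shared base weight, modified in place
W_base = W.clone()                     # reference copy, only for checking

# Two hypothetical task adapters (random stand-ins for trained A, B)
adapters = {
    task: (torch.randn(r, k) * 0.01, torch.randn(d, r) * 0.01)
    for task in ("summarize", "translate")
}


def merge(W, A, B):
    W += scaling * (B @ A)             # W ← W + (α/r)·BA, in place


def unmerge(W, A, B):
    W -= scaling * (B @ A)             # W ← W − (α/r)·BA, in place


x = torch.randn(k)
for task, (A, B) in adapters.items():
    merge(W, A, B)
    # The merged weight matches the unmerged LoRA forward pass for this task
    expected = W_base @ x + scaling * (B @ (A @ x))
    print(task, torch.allclose(W @ x, expected, atol=1e-4))
    unmerge(W, A, B)

# After all swaps the base weight is back, up to rounding error
print((W - W_base).abs().max().item())
```

Because merge and unmerge are exact inverses in exact arithmetic, any drift comes only from floating-point rounding. If bit-exact base weights matter, keep a clean copy of W₀ rather than repeatedly unmerging.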

Summary

LoRA achieves remarkable parameter efficiency through low-rank decomposition:

  1. Core idea: Weight updates ΔW = BA where r ≪ d
  2. Benefits: Large trainable-parameter reductions (48x for a single 768 × 768 layer at r=8) and no inference overhead after merging
  3. Application: Attention layers (Q, V or Q, K, V, O)
  4. Training: Typically higher learning rates than full fine-tuning, with far less gradient and optimizer memory
  5. Deployment: Merge weights or swap LoRA modules for multi-task

LoRA has become a standard choice for LLM fine-tuning thanks to its simplicity and effectiveness.