Fine-Tuning Methods Explained

Deep dive into full fine-tuning, parameter-efficient methods (PEFT), adapters, and prompt tuning. Learn when to use each approach with complete implementations.

22 min read· Fine-Tuning· PEFT· Adapters· Prompt Tuning

Fine-tuning adapts pre-trained models to specific tasks, but there are multiple approaches with different tradeoffs. Let's explore full fine-tuning, parameter-efficient methods, and when to use each.

The Fine-Tuning Spectrum

Modern fine-tuning methods range from updating all parameters to updating less than 1%:

Fine-Tuning Methods by Parameter Count:

  1. Full Fine-Tuning: Update all parameters (100%)
  2. Adapters: Add small trainable layers (~1-5%)
  3. Prefix/Prompt Tuning: Add trainable tokens (~0.01-1%)
  4. LoRA: Low-rank decomposition (~0.1-1%)
  5. BitFit: Only bias terms (~0.1%)

Each method trades adaptation quality against computational efficiency.
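
To make these percentages concrete, here is a rough sketch that counts trainable parameters for GPT-2 small with plain arithmetic. The dimensions (hidden size 768, 12 layers, 124,439,808 total parameters) are those of the public `gpt2` checkpoint. The bottleneck size of 64 and the 20 virtual tokens are illustrative choices, and the helper names are hypothetical. Percentages are taken relative to the base model size.

```python
# Rough trainable-parameter counts for GPT-2 small (assumed dimensions below).
D_MODEL = 768                 # hidden size of the `gpt2` checkpoint
N_LAYER = 12                  # number of transformer blocks
TOTAL_PARAMS = 124_439_808    # parameter count of the `gpt2` checkpoint


def adapter_params(bottleneck=64):
    """One bottleneck adapter per block: down- and up-projection, with biases."""
    down = D_MODEL * bottleneck + bottleneck
    up = bottleneck * D_MODEL + D_MODEL
    return N_LAYER * (down + up)


def prefix_params(prefix_length=20):
    """One key vector and one value vector per prefix position in every layer."""
    return N_LAYER * 2 * prefix_length * D_MODEL


def prompt_params(n_tokens=20):
    """Soft prompt: one embedding vector per virtual token, at the input only."""
    return n_tokens * D_MODEL


counts = {
    "Full fine-tuning": TOTAL_PARAMS,
    "Adapters (r=64)": adapter_params(),
    "Prefix tuning (20 tokens)": prefix_params(),
    "Prompt tuning (20 tokens)": prompt_params(),
}

for name, n in counts.items():
    print(f"{name:<28} {n:>12,}  ({100 * n / TOTAL_PARAMS:.3f}%)")
```

The exact figures shift with model size and hyperparameters, so treat the ranges in the list above as orders of magnitude rather than fixed values.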

1. Full Fine-Tuning

Full fine-tuning updates all model parameters for maximum task adaptation.

Implementation

python
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW  # transformers' own AdamW is deprecated and removed in newer releases
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class FullFineTuner:
    """
    Full fine-tuning: Update all model parameters.

    Pros: Maximum task adaptation, simple to implement
    Cons: High memory, requires large dataset, risk of catastrophic forgetting
    """

    def __init__(self, model_name='gpt2', learning_rate=5e-5):
        """
        Args:
            model_name: Pre-trained model name
            learning_rate: Learning rate for fine-tuning
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load model and tokenizer
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.to(self.device)
        self.learning_rate = learning_rate

        # Print trainable parameters
        self.print_trainable_parameters()

    def print_trainable_parameters(self):
        """Print number of trainable parameters."""
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in self.model.parameters())

        print(f"Trainable parameters: {trainable_params:,}")
        print(f"All parameters: {all_params:,}")
        print(f"Trainable: {100 * trainable_params / all_params:.2f}%")

    def train(self, train_loader, val_loader, epochs=3):
        """
        Train the model on task-specific data.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            epochs: Number of training epochs
        """
        # Optimizer for all parameters
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate)

        best_val_loss = float('inf')

        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = 0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
            for batch in progress_bar:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                optimizer.step()

                train_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

            avg_train_loss = train_loss / len(train_loader)

            # Validation phase
            val_loss = self.validate(val_loader)

            print(f"\nEpoch {epoch+1}/{epochs}")
            print(f"  Average training loss: {avg_train_loss:.4f}")
            print(f"  Validation loss: {val_loss:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_model('best_full_finetuned_model')
                print("  Saved best model!")

    def validate(self, val_loader):
        """Validate the model."""
        self.model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                val_loss += outputs.loss.item()

        return val_loss / len(val_loader)

    def save_model(self, path):
        """Save the fine-tuned model."""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)


# Example usage
# tuner = FullFineTuner('gpt2', learning_rate=5e-5)
# tuner.train(train_loader, val_loader, epochs=3)

Full Fine-Tuning Challenges:

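  1. Memory: Needs memory for all parameters + gradients + optimizer states (about 4x model size with Adam/AdamW, which keeps two moment buffers per parameter), plus activations
  2. Data: Requires thousands of examples to avoid overfitting
  3. Catastrophic forgetting: Can lose pre-trained knowledge
  4. Cost: Expensive for large models (7B+ parameters)

For GPT-3 (175B params), full fine-tuning in fp32 with AdamW needs roughly 2.8TB of GPU memory for weights, gradients, and optimizer states alone (175B × 16 bytes), before counting activations.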
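
The memory arithmetic can be sketched as a small helper. This is a back-of-the-envelope estimate under stated assumptions: every value is stored at the same precision, the optimizer keeps a fixed number of extra buffers per parameter (two for Adam-style optimizers), and activations are ignored. The function name and defaults are illustrative.

```python
def training_memory_gb(n_params, bytes_per_value=4, optimizer_buffers=2):
    """
    Estimate memory for weights + gradients + optimizer buffers, in GB (10^9 bytes).

    Activations, temporary buffers, and framework overhead are NOT included.
    """
    copies = 1 + 1 + optimizer_buffers  # weights, gradients, optimizer state
    return n_params * bytes_per_value * copies / 1e9


models = [
    ("GPT-2 small", 124_439_808),
    ("7B model", 7_000_000_000),
    ("GPT-3", 175_000_000_000),
]

for name, n_params in models:
    print(f"{name:<12} ~{training_memory_gb(n_params):,.1f} GB")
```

Changing `bytes_per_value` or `optimizer_buffers` shows how strongly the total depends on precision and optimizer choice.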

2. Adapter Layers

Adapters insert small bottleneck layers between transformer layers, freezing the original model.

Architecture

Original Layer → Adapter → Next Layer

Adapter structure:
  Input (d) → Down-project (r) → Non-linearity → Up-project (d) → Residual

where r << d (e.g., r=64, d=768)
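
As a framework-free illustration of this structure, the toy sketch below runs one vector through a bottleneck adapter using plain Python lists. It assumes tiny dimensions (d=8, r=2), omits biases and dropout, and uses hypothetical helper names. Its purpose is to show why zero-initializing the up-projection makes the adapter start out as the identity.

```python
import math
import random


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matvec(weight, vec):
    """Multiply a (rows x cols) matrix, stored as a list of rows, by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def adapter_forward(x, w_down, w_up):
    """Down-project (d -> r), apply GELU, up-project (r -> d), add the residual."""
    hidden = [gelu(h) for h in matvec(w_down, x)]
    update = matvec(w_up, hidden)
    return [xi + ui for xi, ui in zip(x, update)]


d, r = 8, 2  # toy sizes; the text uses d=768, r=64
rng = random.Random(0)
w_down = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(r)]  # r x d
w_up = [[0.0] * r for _ in range(d)]  # d x r, zero-initialized

x = [rng.gauss(0.0, 1.0) for _ in range(d)]
y = adapter_forward(x, w_down, w_up)

# With a zero up-projection the update is zero, so only the residual remains
print(y == x)
```

Because the adapter initially passes its input through unchanged, training starts from the pre-trained model's behavior and moves away from it only as the up-projection learns non-zero weights.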

Implementation

python
class AdapterLayer(nn.Module):
    """
    Adapter layer: bottleneck architecture for parameter-efficient tuning.

    Reduces dimension to bottleneck, applies non-linearity, then expands back.
    """

    def __init__(self, input_dim, bottleneck_dim=64, dropout=0.1):
        """
        Args:
            input_dim: Input dimension (model hidden size)
            bottleneck_dim: Bottleneck dimension (much smaller)
            dropout: Dropout probability
        """
        super().__init__()

        # Down-projection
        self.down_project = nn.Linear(input_dim, bottleneck_dim)

        # Non-linearity
        self.activation = nn.GELU()

        # Up-projection
        self.up_project = nn.Linear(bottleneck_dim, input_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Zero-initialize the up-projection so the adapter starts as an exact identity
        nn.init.zeros_(self.up_project.weight)
        nn.init.zeros_(self.up_project.bias)

    def forward(self, x):
        """
        Args:
            x: Input tensor (batch, seq_len, input_dim)

        Returns:
            Output with residual connection
        """
        # Residual connection
        residual = x

        # Adapter transformation
        x = self.down_project(x)
        x = self.activation(x)
        x = self.up_project(x)
        x = self.dropout(x)

        # Add residual
        return x + residual


class GPT2WithAdapters(nn.Module):
    """
    GPT-2 with adapter layers inserted after each transformer block.
    """

    def __init__(self, base_model_name='gpt2', bottleneck_dim=64):
        """
        Args:
            base_model_name: Base model to adapt
            bottleneck_dim: Adapter bottleneck dimension
        """
        super().__init__()

        # Load base model
        self.base_model = GPT2LMHeadModel.from_pretrained(base_model_name)

        # Freeze all base model parameters
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Get model dimension
        config = self.base_model.config
        hidden_size = config.n_embd

        # Add one adapter per transformer block
        self.adapters = nn.ModuleList([
            AdapterLayer(hidden_size, bottleneck_dim)
            for _ in range(config.n_layer)
        ])

        # Hook each adapter onto its block so the adapted hidden states
        # actually feed into the next block
        for block, adapter in zip(self.base_model.transformer.h, self.adapters):
            block.register_forward_hook(self._make_adapter_hook(adapter))

        print(f"Added {len(self.adapters)} adapter layers")
        self.print_trainable_parameters()

    @staticmethod
    def _make_adapter_hook(adapter):
        """Build a forward hook that passes a block's hidden states through `adapter`."""
        def hook(module, inputs, output):
            # GPT-2 blocks usually return a tuple whose first item is the hidden states
            if isinstance(output, tuple):
                return (adapter(output[0]),) + output[1:]
            return adapter(output)
        return hook

    def print_trainable_parameters(self):
        """Print trainable parameters."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())

        print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
        print(f"Total parameters: {total:,}")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass with adapters.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            labels: Labels for language modeling
        """
        # The hooks apply the adapters inside the base model's forward pass,
        # so every block sees the adapted output of the previous block.
        # The base model also computes the shifted language-modeling loss.
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        return {'loss': outputs.loss, 'logits': outputs.logits}


# Example: Compare parameter counts
print("Full Fine-Tuning:")
full_model = GPT2LMHeadModel.from_pretrained('gpt2')
total_params = sum(p.numel() for p in full_model.parameters())
print(f"Trainable parameters: {total_params:,} (100%)\n")

print("Adapter Tuning:")
adapter_model = GPT2WithAdapters('gpt2', bottleneck_dim=64)

Adapter Advantages:

  1. Efficiency: Only ~1-5% parameters trainable
  2. Modularity: Can swap adapters for different tasks
  3. No forgetting: Base model frozen, preserves pre-trained knowledge
  4. Storage: Store multiple task adapters efficiently

Perfect for multi-task scenarios where you need specialized models for different domains.
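
The storage argument can be made concrete with a short calculation. The sketch below assumes fp32 checkpoints, GPT-2 small as the base model, and one adapter set of 1,189,632 parameters per task (12 adapters, hidden size 768, bottleneck 64, with biases). The function name is hypothetical.

```python
BYTES_PER_PARAM = 4  # assuming fp32 checkpoints

BASE_PARAMS = 124_439_808    # GPT-2 small
ADAPTER_PARAMS = 1_189_632   # 12 adapters, hidden size 768, bottleneck 64


def storage_mb(n_tasks, shared_base=True):
    """Storage in MB (10^6 bytes) for n_tasks task-specific models."""
    if shared_base:
        # One frozen base model plus one small adapter set per task
        n_params = BASE_PARAMS + n_tasks * ADAPTER_PARAMS
    else:
        # One fully fine-tuned copy of the model per task
        n_params = n_tasks * BASE_PARAMS
    return n_params * BYTES_PER_PARAM / 1e6


for n_tasks in (1, 10, 100):
    full = storage_mb(n_tasks, shared_base=False)
    adapters = storage_mb(n_tasks, shared_base=True)
    print(f"{n_tasks:>3} tasks: full copies {full:>10,.0f} MB | base + adapters {adapters:>8,.0f} MB")
```

The gap widens with every added task, because each new task costs only one adapter set rather than another full copy of the model.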

3. Prefix Tuning / Prompt Tuning

Add trainable continuous prompts (virtual tokens) while freezing the model.

Implementation

python
class PrefixTuning(nn.Module):
    """
    Prefix tuning: learn trainable key/value prefixes ("virtual tokens")
    that every attention layer attends to.

    Unlike adapters, no new layers are added; the prefixes steer the
    frozen model's behavior.
    """

    def __init__(self, base_model_name='gpt2', prefix_length=20):
        """
        Args:
            base_model_name: Base model name
            prefix_length: Number of prefix tokens to add
        """
        super().__init__()

        # Load base model
        self.base_model = GPT2LMHeadModel.from_pretrained(base_model_name)

        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Get model dimensions
        config = self.base_model.config
        self.n_layer = config.n_layer
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.prefix_length = prefix_length

        # Trainable prefix parameters
        # Shape: (n_layers, 2, n_heads, prefix_length, head_dim)
        # 2 for key and value
        head_dim = self.n_embd // self.n_head

        self.prefix_embeddings = nn.Parameter(
            torch.randn(
                self.n_layer,
                2,  # key and value
                self.n_head,
                prefix_length,
                head_dim
            ) * 0.01
        )

        print(f"Prefix length: {prefix_length} tokens")
        self.print_trainable_parameters()

    def print_trainable_parameters(self):
        """Print trainable parameter count."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())

        print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
        print(f"Total parameters: {total:,}")

    def get_prompt(self, batch_size):
        """
        Get prefix key/value pairs for all layers.

        Args:
            batch_size: Batch size

        Returns:
            Prefix key-value pairs for each layer
        """
        # Expand prefix embeddings for batch
        prefix_kv = self.prefix_embeddings.unsqueeze(0).expand(
            batch_size, -1, -1, -1, -1, -1
        )

        return prefix_kv

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass with prefix tuning.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            labels: Labels for language modeling
        """
        batch_size = input_ids.shape[0]

        # Get prefix key-values
        prefix_kv = self.get_prompt(batch_size)

        # NOTE: this simplified version does not inject prefix_kv, so the
        # output equals the frozen base model and the prefix receives no
        # gradient. A real implementation would prepend prefix_kv to the
        # keys and values of each attention layer.

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs


# Simpler version: Soft Prompt Tuning
class PromptTuning(nn.Module):
    """
    Prompt tuning: Add trainable embeddings as soft prompts.

    Simpler than prefix tuning - just prepend trainable embeddings.
    """

    def __init__(self, base_model_name='gpt2', n_prompt_tokens=20):
        """
        Args:
            base_model_name: Base model name
            n_prompt_tokens: Number of soft prompt tokens
        """
        super().__init__()

        # Load base model
        self.base_model = GPT2LMHeadModel.from_pretrained(base_model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(base_model_name)

        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Get embedding dimension
        embed_dim = self.base_model.config.n_embd

        # Trainable soft prompt embeddings
        self.soft_prompt = nn.Parameter(
            torch.randn(n_prompt_tokens, embed_dim) * 0.01
        )

        self.n_prompt_tokens = n_prompt_tokens

        print(f"Soft prompt tokens: {n_prompt_tokens}")
        self.print_trainable_parameters()

    def print_trainable_parameters(self):
        """Print trainable parameters."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())

        print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
        print(f"Total parameters: {total:,}")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward with soft prompts prepended.

        Args:
            input_ids: Input token IDs (batch, seq_len)
            attention_mask: Attention mask
            labels: Labels for LM loss
        """
        batch_size = input_ids.shape[0]

        # Get input embeddings
        input_embeds = self.base_model.transformer.wte(input_ids)

        # Expand soft prompt for batch
        soft_prompt_batch = self.soft_prompt.unsqueeze(0).expand(
            batch_size, -1, -1
        )

        # Concatenate soft prompt with input embeddings
        inputs_embeds = torch.cat([soft_prompt_batch, input_embeds], dim=1)

        # Extend attention mask for soft prompts
        if attention_mask is not None:
            prefix_mask = torch.ones(
                batch_size, self.n_prompt_tokens,
                device=attention_mask.device,
                dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        # Adjust labels if provided
        if labels is not None:
            prefix_labels = torch.full(
                (batch_size, self.n_prompt_tokens),
                -100,  # Ignore index
                device=labels.device,
                dtype=labels.dtype
            )
            labels = torch.cat([prefix_labels, labels], dim=1)

        # Forward pass
        outputs = self.base_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs


# Example usage
print("\nPrompt Tuning:")
prompt_model = PromptTuning('gpt2', n_prompt_tokens=20)
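
The mask and label bookkeeping in `PromptTuning.forward` can be checked in isolation with plain lists. This is a minimal sketch for a single example: the helper name and the token IDs are hypothetical, and -100 is used because it is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def extend_for_soft_prompt(attention_mask, labels, n_prompt_tokens):
    """Prepend soft-prompt positions to one example's attention mask and labels."""
    # Soft-prompt positions are always attended to...
    new_mask = [1] * n_prompt_tokens + list(attention_mask)
    # ...but never contribute to the loss, since they have no target token
    new_labels = [IGNORE_INDEX] * n_prompt_tokens + list(labels)
    return new_mask, new_labels


mask = [1, 1, 1, 0]               # last position is padding
labels = [464, 3290, 318, -100]   # hypothetical token IDs; padding already ignored
new_mask, new_labels = extend_for_soft_prompt(mask, labels, n_prompt_tokens=3)

print(new_mask)    # [1, 1, 1, 1, 1, 1, 0]
print(new_labels)  # [-100, -100, -100, 464, 3290, 318, -100]
```

Both sequences grow by the number of soft-prompt tokens, which keeps them aligned with the concatenated embeddings passed to the model.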

Method Comparison

python
import pandas as pd
import matplotlib.pyplot as plt

# Create comparison table
comparison_data = {
    'Method': ['Full Fine-Tuning', 'Adapters', 'Prefix Tuning', 'Prompt Tuning', 'LoRA'],
    'Trainable Params': ['100%', '1-5%', '0.1-1%', '0.01-0.1%', '0.1-1%'],
    'Memory (Train)': ['High', 'Medium', 'Low', 'Very Low', 'Low'],
    'Task Performance': ['Best', 'Very Good', 'Good', 'Good', 'Very Good'],
    'Training Speed': ['Slow', 'Medium', 'Fast', 'Fast', 'Fast'],
    'Inference Overhead': ['None', 'Small', 'Small (extra prefix)', 'Small (longer input)', 'None (merged)'],
    'Multi-task': ['Poor', 'Excellent', 'Good', 'Good', 'Good'],
    'Best For': [
        'High-resource tasks',
        'Multi-task scenarios',
        'Low-resource tasks',
        'Extremely low data',
        'General purpose PEFT'
    ]
}

df = pd.DataFrame(comparison_data)
print(df.to_string(index=False))

# Visualize parameter efficiency
methods = ['Full FT', 'Adapters', 'Prefix', 'Prompt', 'LoRA']
params_pct = [100, 3, 0.5, 0.05, 0.5]
performance = [100, 95, 85, 80, 95]  # Illustrative relative scores, not measured benchmarks

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Parameters comparison
ax1.barh(methods, params_pct, color='steelblue')
ax1.set_xlabel('Trainable Parameters (%)', fontsize=12)
ax1.set_title('Parameter Efficiency', fontsize=14, fontweight='bold')
ax1.set_xscale('log')
ax1.grid(axis='x', alpha=0.3)

# Performance vs efficiency
ax2.scatter(params_pct, performance, s=200, alpha=0.6, c=range(len(methods)), cmap='viridis')
for i, method in enumerate(methods):
    ax2.annotate(method, (params_pct[i], performance[i]),
                xytext=(10, 5), textcoords='offset points')
ax2.set_xlabel('Trainable Parameters (%, log scale)', fontsize=12)
ax2.set_ylabel('Relative Performance', fontsize=12)
ax2.set_title('Efficiency vs Performance Tradeoff', fontsize=14, fontweight='bold')
ax2.set_xscale('log')
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.show()

Choosing a Method:

Use Full Fine-Tuning when:

  • You have 10,000+ high-quality examples
  • Maximum performance is critical
  • You have sufficient compute resources
  • Single-task deployment

Use Adapters when:

  • You need multiple task-specific models
  • Moderate data (1,000-10,000 examples)
  • Want to preserve base model
  • Multi-tenant scenarios

Use Prefix/Prompt Tuning when:

  • Very limited data (< 1,000 examples)
  • Minimal compute budget
  • Fast iteration needed
  • Task can be solved with guidance

Use LoRA when:

  • General-purpose PEFT needed
  • Balance of efficiency and performance
  • Want to merge back into base model
  • 1,000-10,000 examples
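
These rules of thumb can be encoded as a small helper. The function below is a hypothetical sketch: its name, arguments, and the order in which the checks run are assumptions, while the thresholds (1,000 and 10,000 examples) come from the lists above. Treat the output as a starting point, not a guarantee.

```python
def suggest_method(n_examples, multi_task=False, ample_compute=False):
    """Map the data-size and deployment heuristics to a single suggestion."""
    if n_examples < 1_000:
        # Very limited data: train as few parameters as possible
        return "Prefix/Prompt Tuning"
    if n_examples >= 10_000 and ample_compute and not multi_task:
        # Plenty of data and compute for a single task
        return "Full Fine-Tuning"
    if multi_task:
        # Several task-specific models sharing one frozen base
        return "Adapters"
    # Moderate data, single task: general-purpose PEFT
    return "LoRA"


print(suggest_method(500))                          # Prefix/Prompt Tuning
print(suggest_method(50_000, ample_compute=True))   # Full Fine-Tuning
print(suggest_method(5_000, multi_task=True))       # Adapters
print(suggest_method(5_000))                        # LoRA
```

Real decisions also depend on factors the helper ignores, such as latency budgets and how far the target task is from the pre-training data.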

Summary

Fine-tuning methods offer different tradeoffs:

  1. Full fine-tuning: Maximum adaptation, maximum cost
  2. Adapters: Modular, efficient, excellent for multi-task
  3. Prefix/Prompt tuning: Extremely efficient, good for low-resource
  4. LoRA: Best balance of efficiency and performance (next lesson!)

The trend is toward parameter-efficient methods, which on many tasks come close to full fine-tuning while training 100x fewer parameters (or fewer still).