Fine-Tuning Methods Explained
Fine-tuning adapts pre-trained models to specific tasks, but there are multiple approaches with different tradeoffs. Let's explore full fine-tuning, parameter-efficient methods, and when to use each.
The Fine-Tuning Spectrum
Modern fine-tuning methods range from updating all parameters to updating less than 1%:
Fine-Tuning Methods by Parameter Count:
- Full Fine-Tuning: Update all parameters (100%)
- Adapters: Add small trainable layers (~1-5%)
- Prefix/Prompt Tuning: Add trainable virtual tokens (~0.01-1%)
- LoRA: Low-rank decomposition (~0.1-1%)
- BitFit: Only bias terms (~0.1%)
Each method trades off between adaptation quality and computational efficiency.
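To make those fractions concrete, here is a quick back-of-envelope sketch for a model the size of GPT-2 small (~124M parameters). The percentages are representative midpoints of the ranges above, not measured values:
total_params = 124_000_000
for method, pct in [
    ('Full Fine-Tuning', 100.0),
    ('Adapters', 3.0),
    ('Prefix Tuning', 0.5),
    ('Prompt Tuning', 0.05),
    ('LoRA', 0.5),
    ('BitFit', 0.1),
]:
    # pct is an illustrative midpoint, not an exact figure
    print(f"{method:<17} ~{pct:>6.2f}% -> ~{int(total_params * pct / 100):>11,} trainable params")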
1. Full Fine-Tuning
Full fine-tuning updates all model parameters for maximum task adaptation.
Implementation
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW  # transformers' own AdamW has been removed in recent versions
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
class FullFineTuner:
"""
Full fine-tuning: Update all model parameters.
Pros: Maximum task adaptation, simple to implement
Cons: High memory, requires large dataset, risk of catastrophic forgetting
"""
def __init__(self, model_name='gpt2', learning_rate=5e-5):
"""
Args:
model_name: Pre-trained model name
learning_rate: Learning rate for fine-tuning
"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load model and tokenizer
self.model = GPT2LMHeadModel.from_pretrained(model_name)
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.to(self.device)
self.learning_rate = learning_rate
# Print trainable parameters
self.print_trainable_parameters()
def print_trainable_parameters(self):
"""Print number of trainable parameters."""
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in self.model.parameters())
print(f"Trainable parameters: {trainable_params:,}")
print(f"All parameters: {all_params:,}")
print(f"Trainable: {100 * trainable_params / all_params:.2f}%")
def train(self, train_loader, val_loader, epochs=3):
"""
Train the model on task-specific data.
Args:
train_loader: Training data loader
val_loader: Validation data loader
epochs: Number of training epochs
"""
# Optimizer for all parameters
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate)
best_val_loss = float('inf')
for epoch in range(epochs):
# Training phase
self.model.train()
train_loss = 0
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
for batch in progress_bar:
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
# Forward pass
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss
# Backward pass
optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
train_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
avg_train_loss = train_loss / len(train_loader)
# Validation phase
val_loss = self.validate(val_loader)
print(f"\nEpoch {epoch+1}/{epochs}")
print(f" Average training loss: {avg_train_loss:.4f}")
print(f" Validation loss: {val_loss:.4f}")
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
self.save_model('best_full_finetuned_model')
print(" Saved best model!")
def validate(self, val_loader):
"""Validate the model."""
self.model.eval()
val_loss = 0
with torch.no_grad():
for batch in val_loader:
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
val_loss += outputs.loss.item()
return val_loss / len(val_loader)
def save_model(self, path):
"""Save the fine-tuned model."""
self.model.save_pretrained(path)
self.tokenizer.save_pretrained(path)
# Example usage
# tuner = FullFineTuner('gpt2', learning_rate=5e-5)
# tuner.train(train_loader, val_loader, epochs=3)
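The trainer above assumes batches containing input_ids, attention_mask, and labels. A minimal sketch of a compatible dataset (texts is a placeholder for your task data; it reuses the tuner's tokenizer, whose pad token is already set):
class TextDataset(Dataset):
    """Minimal causal-LM dataset: labels are the input IDs with padding masked out."""
    def __init__(self, texts, tokenizer, max_length=128):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding='max_length',
            return_tensors='pt'
        )
    def __len__(self):
        return self.encodings['input_ids'].shape[0]
    def __getitem__(self, idx):
        input_ids = self.encodings['input_ids'][idx]
        attention_mask = self.encodings['attention_mask'][idx]
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padding in the loss
        return {'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels}
# train_loader = DataLoader(TextDataset(train_texts, tuner.tokenizer), batch_size=8, shuffle=True)
# val_loader = DataLoader(TextDataset(val_texts, tuner.tokenizer), batch_size=8)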
Full Fine-Tuning Challenges:
- Memory: Needs memory for weights + gradients + Adam's two moment buffers (roughly 4x the model's fp32 size, before counting activations)
- Data: Requires thousands of examples to avoid overfitting
- Catastrophic forgetting: Can lose pre-trained knowledge
- Cost: Expensive for large models (7B+ parameters)
For GPT-3 (175B parameters), that works out to roughly 2.8TB for weights, gradients, and optimizer states alone, far beyond any single GPU!
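A rough way to estimate the footprint (ignoring activations, which add substantially more):
# Back-of-envelope memory for full fine-tuning with Adam: ~16 bytes/parameter
# (weights + gradients + two optimizer moment buffers in fp32, or the
# equivalent mixed-precision setup). Activations are extra.
def full_finetune_memory_gb(n_params, bytes_per_param=16):
    return n_params * bytes_per_param / 1e9
for name, n_params in [('GPT-2 small', 124e6), ('GPT-2 XL', 1.5e9),
                       ('7B model', 7e9), ('GPT-3 175B', 175e9)]:
    print(f"{name:<12} ~{full_finetune_memory_gb(n_params):,.0f} GB")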
2. Adapter Layers
Adapters insert small bottleneck layers between transformer layers, freezing the original model.
Architecture
Original Layer → Adapter → Next Layer
Adapter structure:
Input (d) → Down-project (r) → Non-linearity → Up-project (d) → Residual
where r << d (e.g., r=64, d=768)
Implementation
class AdapterLayer(nn.Module):
"""
Adapter layer: bottleneck architecture for parameter-efficient tuning.
Reduces dimension to bottleneck, applies non-linearity, then expands back.
"""
def __init__(self, input_dim, bottleneck_dim=64, dropout=0.1):
"""
Args:
input_dim: Input dimension (model hidden size)
bottleneck_dim: Bottleneck dimension (much smaller)
dropout: Dropout probability
"""
super().__init__()
# Down-projection
self.down_project = nn.Linear(input_dim, bottleneck_dim)
# Non-linearity
self.activation = nn.GELU()
# Up-projection
self.up_project = nn.Linear(bottleneck_dim, input_dim)
# Dropout
self.dropout = nn.Dropout(dropout)
# Initialize near-identity at start
nn.init.zeros_(self.up_project.weight)
nn.init.zeros_(self.up_project.bias)
def forward(self, x):
"""
Args:
x: Input tensor (batch, seq_len, input_dim)
Returns:
Output with residual connection
"""
# Residual connection
residual = x
# Adapter transformation
x = self.down_project(x)
x = self.activation(x)
x = self.up_project(x)
x = self.dropout(x)
# Add residual
return x + residual
class AdaptedBlock(nn.Module):
    """
    Wraps a GPT-2 transformer block so its output passes through an adapter
    before flowing into the next block.
    """
    def __init__(self, block, adapter):
        super().__init__()
        self.block = block
        self.adapter = adapter
    def forward(self, hidden_states, *args, **kwargs):
        outputs = self.block(hidden_states, *args, **kwargs)
        # The block returns a tuple whose first element is the hidden states
        return (self.adapter(outputs[0]),) + outputs[1:]
class GPT2WithAdapters(nn.Module):
    """
    GPT-2 with adapter layers inserted after each transformer block.
    The base model is frozen; only the adapters train.
    """
    def __init__(self, base_model_name='gpt2', bottleneck_dim=64):
        """
        Args:
            base_model_name: Base model to adapt
            bottleneck_dim: Adapter bottleneck dimension
        """
        super().__init__()
        # Load base model
        self.base_model = GPT2LMHeadModel.from_pretrained(base_model_name)
        # Freeze all base model parameters
        for param in self.base_model.parameters():
            param.requires_grad = False
        # Get model dimension
        config = self.base_model.config
        hidden_size = config.n_embd
        # One adapter per transformer block
        self.adapters = nn.ModuleList([
            AdapterLayer(hidden_size, bottleneck_dim)
            for _ in range(config.n_layer)
        ])
        # Wrap each block so the adapter output feeds the next layer
        for i in range(config.n_layer):
            self.base_model.transformer.h[i] = AdaptedBlock(
                self.base_model.transformer.h[i], self.adapters[i]
            )
        print(f"Added {len(self.adapters)} adapter layers")
        self.print_trainable_parameters()
    def print_trainable_parameters(self):
        """Print trainable parameters (shared adapter params are counted once)."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
        print(f"Total parameters: {total:,}")
    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass. The wrapped blocks apply the adapters in place, so the
        base model's standard forward pass (including the internal label shift
        and cross-entropy loss) works unchanged.
        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            labels: Labels for language modeling
        """
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
# Example: Compare parameter counts
print("Full Fine-Tuning:")
full_model = GPT2LMHeadModel.from_pretrained('gpt2')
total_params = sum(p.numel() for p in full_model.parameters())
print(f"Trainable parameters: {total_params:,} (100%)\n")
print("Adapter Tuning:")
adapter_model = GPT2WithAdapters('gpt2', bottleneck_dim=64)
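Training then touches only the adapter weights. A minimal sketch, assuming a train_loader that yields the same batch dict as the full fine-tuning example:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
adapter_model.to(device)
# Optimize only the unfrozen (adapter) parameters
optimizer = AdamW(
    [p for p in adapter_model.parameters() if p.requires_grad],
    lr=1e-4  # adapters typically tolerate a higher learning rate than full fine-tuning
)
adapter_model.train()
for batch in train_loader:
    outputs = adapter_model(
        input_ids=batch['input_ids'].to(device),
        attention_mask=batch['attention_mask'].to(device),
        labels=batch['labels'].to(device)
    )
    optimizer.zero_grad()
    outputs.loss.backward()
    optimizer.step()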
Adapter Advantages:
- Efficiency: Only ~1-5% parameters trainable
- Modularity: Can swap adapters for different tasks
- No forgetting: Base model frozen, preserves pre-trained knowledge
- Storage: Store multiple task adapters efficiently
Perfect for multi-task scenarios where you need specialized models for different domains.
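Because all task-specific state lives in the adapters, switching tasks is just a state-dict swap. A sketch with illustrative file names:
# Save only the adapter weights (a few MB, vs. ~500MB for a full GPT-2 checkpoint)
torch.save(adapter_model.adapters.state_dict(), 'adapters_task_a.pt')
# Later: load a different task's adapters into the same frozen base model
adapter_model.adapters.load_state_dict(torch.load('adapters_task_b.pt'))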
3. Prefix Tuning / Prompt Tuning
Add trainable continuous prompts (virtual tokens) while freezing the model.
Implementation
class PrefixTuning(nn.Module):
"""
Prefix tuning: Add trainable prefix embeddings before input.
Instead of adding adapter layers, we add trainable "virtual tokens"
that guide the model's behavior.
"""
def __init__(self, base_model_name='gpt2', prefix_length=20):
"""
Args:
base_model_name: Base model name
prefix_length: Number of prefix tokens to add
"""
super().__init__()
# Load base model
self.base_model = GPT2LMHeadModel.from_pretrained(base_model_name)
# Freeze base model
for param in self.base_model.parameters():
param.requires_grad = False
# Get model dimensions
config = self.base_model.config
self.n_layer = config.n_layer
self.n_head = config.n_head
self.n_embd = config.n_embd
self.prefix_length = prefix_length
# Trainable prefix parameters
# Shape: (n_layers, 2, n_heads, prefix_length, head_dim)
# 2 for key and value
head_dim = self.n_embd // self.n_head
self.prefix_embeddings = nn.Parameter(
torch.randn(
self.n_layer,
2, # key and value
self.n_head,
prefix_length,
head_dim
) * 0.01
)
print(f"Prefix length: {prefix_length} tokens")
self.print_trainable_parameters()
def print_trainable_parameters(self):
"""Print trainable parameter count."""
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
print(f"Total parameters: {total:,}")
    def get_prompt(self, batch_size):
        """
        Get prefix key/value pairs for all layers in the (legacy) past_key_values
        format: one (key, value) pair per layer, each of shape
        (batch, n_head, prefix_length, head_dim).
        Args:
            batch_size: Batch size
        Returns:
            Tuple of per-layer (key, value) prefix tensors
        """
        # Expand prefix embeddings for the batch:
        # (batch, n_layer, 2, n_head, prefix_length, head_dim)
        prefix_kv = self.prefix_embeddings.unsqueeze(0).expand(
            batch_size, -1, -1, -1, -1, -1
        )
        return tuple(
            (prefix_kv[:, i, 0], prefix_kv[:, i, 1])
            for i in range(self.n_layer)
        )
    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass with prefix tuning: the trainable keys/values are injected
        into every attention layer via past_key_values, so every real token
        attends to the prefix.
        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            labels: Labels for language modeling
        """
        batch_size = input_ids.shape[0]
        # Get prefix key-values (recent transformers versions convert this
        # legacy tuple format to a Cache object internally)
        past_key_values = self.get_prompt(batch_size)
        # Extend the attention mask to cover the prefix positions
        if attention_mask is not None:
            prefix_mask = torch.ones(
                batch_size, self.prefix_length,
                device=attention_mask.device,
                dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        # Labels need no adjustment: logits are produced only for input_ids
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            labels=labels
        )
        return outputs
# Simpler version: Soft Prompt Tuning
class PromptTuning(nn.Module):
"""
Prompt tuning: Add trainable embeddings as soft prompts.
Simpler than prefix tuning - just prepend trainable embeddings.
"""
def __init__(self, base_model_name='gpt2', n_prompt_tokens=20):
"""
Args:
base_model_name: Base model name
n_prompt_tokens: Number of soft prompt tokens
"""
super().__init__()
# Load base model
self.base_model = GPT2LMHeadModel.from_pretrained(base_model_name)
self.tokenizer = GPT2Tokenizer.from_pretrained(base_model_name)
# Freeze base model
for param in self.base_model.parameters():
param.requires_grad = False
# Get embedding dimension
embed_dim = self.base_model.config.n_embd
# Trainable soft prompt embeddings
self.soft_prompt = nn.Parameter(
torch.randn(n_prompt_tokens, embed_dim) * 0.01
)
self.n_prompt_tokens = n_prompt_tokens
print(f"Soft prompt tokens: {n_prompt_tokens}")
self.print_trainable_parameters()
def print_trainable_parameters(self):
"""Print trainable parameters."""
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")
print(f"Total parameters: {total:,}")
def forward(self, input_ids, attention_mask=None, labels=None):
"""
Forward with soft prompts prepended.
Args:
input_ids: Input token IDs (batch, seq_len)
attention_mask: Attention mask
labels: Labels for LM loss
"""
batch_size = input_ids.shape[0]
# Get input embeddings
input_embeds = self.base_model.transformer.wte(input_ids)
# Expand soft prompt for batch
soft_prompt_batch = self.soft_prompt.unsqueeze(0).expand(
batch_size, -1, -1
)
# Concatenate soft prompt with input embeddings
inputs_embeds = torch.cat([soft_prompt_batch, input_embeds], dim=1)
# Extend attention mask for soft prompts
if attention_mask is not None:
prefix_mask = torch.ones(
batch_size, self.n_prompt_tokens,
device=attention_mask.device,
dtype=attention_mask.dtype
)
attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
# Adjust labels if provided
if labels is not None:
prefix_labels = torch.full(
(batch_size, self.n_prompt_tokens),
-100, # Ignore index
device=labels.device,
dtype=labels.dtype
)
labels = torch.cat([prefix_labels, labels], dim=1)
# Forward pass
outputs = self.base_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels
)
return outputs
# Example usage
print("\nPrompt Tuning:")
prompt_model = PromptTuning('gpt2', n_prompt_tokens=20)
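Since the soft prompt is the only trainable tensor, optimizer setup is a one-liner. A sketch of a training step, with the same batch format as the full fine-tuning example:
# The soft prompt is the single trainable parameter
optimizer = AdamW([prompt_model.soft_prompt], lr=1e-3)  # soft prompts often use larger learning rates
prompt_model.train()
# outputs = prompt_model(input_ids, attention_mask=attention_mask, labels=labels)
# optimizer.zero_grad()
# outputs.loss.backward()
# optimizer.step()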
Method Comparison
import pandas as pd
import matplotlib.pyplot as plt
# Create comparison table
comparison_data = {
'Method': ['Full Fine-Tuning', 'Adapters', 'Prefix Tuning', 'Prompt Tuning', 'LoRA'],
'Trainable Params': ['100%', '1-5%', '0.1-1%', '0.01-0.1%', '0.1-1%'],
'Memory (Train)': ['High', 'Medium', 'Low', 'Very Low', 'Low'],
'Task Performance': ['Best', 'Very Good', 'Good', 'Good', 'Very Good'],
'Training Speed': ['Slow', 'Medium', 'Fast', 'Fast', 'Fast'],
'Inference Overhead': ['None', 'Small', 'Small (longer KV cache)', 'Small (extra tokens)', 'None (merged)'],
'Multi-task': ['Poor', 'Excellent', 'Good', 'Good', 'Good'],
'Best For': [
'High-resource tasks',
'Multi-task scenarios',
'Low-resource tasks',
'Extremely low data',
'General purpose PEFT'
]
}
df = pd.DataFrame(comparison_data)
print(df.to_string(index=False))
# Visualize parameter efficiency
methods = ['Full FT', 'Adapters', 'Prefix', 'Prompt', 'LoRA']
params_pct = [100, 3, 0.5, 0.05, 0.5]
performance = [100, 95, 85, 80, 95]  # Illustrative relative performance, not measured benchmarks
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Parameters comparison
ax1.barh(methods, params_pct, color='steelblue')
ax1.set_xlabel('Trainable Parameters (%)', fontsize=12)
ax1.set_title('Parameter Efficiency', fontsize=14, fontweight='bold')
ax1.set_xscale('log')
ax1.grid(axis='x', alpha=0.3)
# Performance vs efficiency
ax2.scatter(params_pct, performance, s=200, alpha=0.6, c=range(len(methods)), cmap='viridis')
for i, method in enumerate(methods):
ax2.annotate(method, (params_pct[i], performance[i]),
xytext=(10, 5), textcoords='offset points')
ax2.set_xlabel('Trainable Parameters (%, log scale)', fontsize=12)
ax2.set_ylabel('Relative Performance', fontsize=12)
ax2.set_title('Efficiency vs Performance Tradeoff', fontsize=14, fontweight='bold')
ax2.set_xscale('log')
ax2.grid(alpha=0.3)
plt.tight_layout()
plt.show()
Choosing a Method:
Use Full Fine-Tuning when:
- You have 10,000+ high-quality examples
- Maximum performance is critical
- You have sufficient compute resources
- Single-task deployment
Use Adapters when:
- You need multiple task-specific models
- Moderate data (1,000-10,000 examples)
- Want to preserve base model
- Multi-tenant scenarios
Use Prefix/Prompt Tuning when:
- Very limited data (< 1,000 examples)
- Minimal compute budget
- Fast iteration needed
- Task can be solved with guidance
Use LoRA when:
- General-purpose PEFT needed
- Balance of efficiency and performance
- Want to merge back into base model
- 1,000-10,000 examples
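As a preview of the next lesson, the core LoRA idea fits in a few lines: freeze a linear layer's weight and learn a low-rank update BA, scaled by alpha/r. This is an illustrative sketch, not a full implementation:
class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = Wx + (alpha/r) * B(Ax), with W frozen."""
    def __init__(self, linear, r=8, alpha=16):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze the pre-trained weight
        self.lora_A = nn.Parameter(torch.randn(r, linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, r))  # zero init: no change at start
        self.scaling = alpha / r
    def forward(self, x):
        return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling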
Summary
Fine-tuning methods offer different tradeoffs:
- Full fine-tuning: Maximum adaptation, maximum cost
- Adapters: Modular, efficient, excellent for multi-task
- Prefix/Prompt tuning: Extremely efficient, good for low-resource
- LoRA: Best balance of efficiency and performance (next lesson!)
The trend is toward parameter-efficient methods that achieve comparable results with 100x fewer trainable parameters.