Transfer Learning for LLMs
Transfer learning is the foundation of modern NLP. Instead of training models from scratch, we leverage pre-trained language models and adapt them to specific tasks through fine-tuning.
What is Transfer Learning?
Transfer learning involves two stages:
- Pre-training: Learn general language understanding from massive unlabeled text
- Fine-tuning: Adapt the model to specific tasks with smaller labeled datasets
Key Insight:
Pre-training captures general linguistic knowledge (grammar, facts, reasoning). Fine-tuning specializes this knowledge for specific applications (classification, QA, generation).
This paradigm enables models to perform well on tasks with limited data by leveraging broad knowledge from pre-training.
Pre-Training vs Fine-Tuning
Pre-Training Phase
Pre-training uses self-supervised learning on billions of tokens:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader
class PreTrainingDataset(Dataset):
    """Causal language-modeling dataset.

    Each item yields a next-token-prediction pair: inputs are the sequence
    minus its last token, labels are the sequence shifted left by one.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        """
        Args:
            texts: Iterable of raw text strings.
            tokenizer: Tokenizer exposing an ``encode`` method.
            max_length: Maximum number of tokens kept per example.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        # A sequence needs at least 2 tokens to form an (input, target) pair;
        # shorter encodings are dropped.
        encoded = (
            tokenizer.encode(text, truncation=True, max_length=max_length)
            for text in texts
        )
        self.examples = [tokens for tokens in encoded if len(tokens) > 1]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        tokens = self.examples[idx]
        # Next-token objective: predict tokens[1:] from tokens[:-1].
        return {
            'input_ids': torch.tensor(tokens[:-1]),
            'labels': torch.tensor(tokens[1:])
        }
def pretrain_language_model(model, train_loader, epochs=1, lr=1e-4):
    """
    Pre-train a language model with the causal language-modeling objective.

    Args:
        model: The language model (HF-style: forward(labels=...) returns an
            object with a ``.loss`` attribute).
        train_loader: Iterable of batches with 'input_ids' and 'labels'.
        epochs: Number of passes over the data.
        lr: AdamW learning rate.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for step, batch in enumerate(train_loader, start=1):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            # The model computes the shifted cross-entropy loss internally.
            loss = model(input_ids=input_ids, labels=labels).loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Periodic progress report every 100 batches.
            if step % 100 == 0:
                avg_loss = running_loss / step
                print(f"Epoch {epoch+1}, Batch {step}, Loss: {avg_loss:.4f}")

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
# Example pre-training setup.
# GPT-2 ships without a dedicated pad token; reuse EOS so padded batches can
# be built — NOTE(review): padding positions are not masked here, confirm
# this is acceptable for the demo.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Simulate pre-training corpus (in practice: billions of tokens).
pretrain_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is a subset of artificial intelligence.",
    "Natural language processing enables computers to understand human language.",
    # ... millions more documents
]

# Create dataset and loader.
pretrain_dataset = PreTrainingDataset(pretrain_texts, tokenizer)
pretrain_loader = DataLoader(pretrain_dataset, batch_size=8, shuffle=True)

# Initialize model from pre-trained GPT-2 weights.
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Pre-train (simplified example; uncomment to actually run).
# pretrain_language_model(model, pretrain_loader, epochs=3)
Pre-Training is Expensive:
Real pre-training requires:
- Compute: Thousands of GPU-days (e.g., GPT-3 was estimated at ~355 V100 GPU-years)
- Data: Hundreds of GB to TB of text (CommonCrawl, books, Wikipedia)
- Cost: Millions of dollars for large models
Most practitioners use existing pre-trained models rather than pre-training from scratch.
Fine-Tuning Phase
Fine-tuning adapts the pre-trained model to specific tasks:
from torch.utils.data import Dataset
import torch.nn.functional as F
class FineTuningDataset(Dataset):
    """Dataset for task-specific fine-tuning.

    Each item provides padded input ids, the matching attention mask, and
    label ids with padding positions set to -100 so the loss ignores them.
    """

    def __init__(self, examples, tokenizer, max_length=512):
        """
        Args:
            examples: List of (input_text, target_text) tuples.
            tokenizer: HF-style callable tokenizer.
            max_length: Maximum sequence length after padding/truncation.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        input_text, target_text = self.examples[idx]

        # Tokenize input and target to a fixed length.
        input_encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encoding = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Bug fix: the original returned raw padded target ids as labels, so
        # the loss was also computed over padding tokens. Hugging Face models
        # ignore label id -100 in their cross-entropy, so mask every position
        # the target attention mask marks as padding.
        labels = target_encoding['input_ids'].squeeze()
        labels = labels.masked_fill(
            target_encoding['attention_mask'].squeeze() == 0, -100
        )

        return {
            'input_ids': input_encoding['input_ids'].squeeze(),
            'attention_mask': input_encoding['attention_mask'].squeeze(),
            'labels': labels
        }
def finetune_model(model, train_loader, val_loader, epochs=3, lr=5e-5):
    """
    Fine-tune a pre-trained model on a specific task.

    Runs a train/validate loop each epoch and checkpoints the weights with
    the lowest validation loss to 'best_finetuned_model.pt'.

    Args:
        model: Pre-trained model (HF-style: forward returns ``.loss``).
        train_loader: Training data loader.
        val_loader: Validation data loader.
        epochs: Number of fine-tuning epochs.
        lr: Learning rate (typically lower than pre-training).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def _batch_loss(batch):
        # Move one batch to the device and return the model's loss tensor.
        outputs = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            labels=batch['labels'].to(device)
        )
        return outputs.loss

    best_val_loss = float('inf')
    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        accumulated = 0.0
        for batch in train_loader:
            loss = _batch_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            accumulated += loss.item()
        avg_train_loss = accumulated / len(train_loader)

        # --- validation pass (no gradients) ---
        model.eval()
        accumulated = 0.0
        with torch.no_grad():
            for batch in val_loader:
                accumulated += _batch_loss(batch).item()
        avg_val_loss = accumulated / len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f" Train Loss: {avg_train_loss:.4f}")
        print(f" Val Loss: {avg_val_loss:.4f}")

        # Checkpoint whenever validation improves.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_finetuned_model.pt')
            print(" Saved best model!")
# Example: fine-tuning data for summarization, as (input, target) pairs.
finetune_examples = [
    (
        "Summarize: The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
        "The Eiffel Tower is an iron tower in Paris designed by Gustave Eiffel's company."
    ),
    (
        "Summarize: Machine learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.",
        "Machine learning is an AI method that enables systems to learn from data and make decisions automatically."
    ),
    # ... more examples
]

# Create fine-tuning dataset (reuses the tokenizer configured above).
finetune_dataset = FineTuningDataset(finetune_examples, tokenizer)
finetune_loader = DataLoader(finetune_dataset, batch_size=4, shuffle=True)

# Fine-tune the model.
# NOTE(review): `val_loader` is not defined anywhere above — a validation
# split/loader must be created before uncommenting this call.
# finetune_model(model, finetune_loader, val_loader, epochs=3, lr=5e-5)
Comparison: Pre-Training vs Fine-Tuning
import pandas as pd
# Side-by-side comparison of the two transfer-learning phases,
# one row per aspect: (aspect, pre-training, fine-tuning).
_rows = [
    ('Objective', 'Learn general language understanding', 'Adapt to specific task'),
    ('Data', 'Raw text (web, books, etc.)', 'Task-specific labeled data'),
    ('Data Size', 'Billions of tokens', 'Thousands to millions of tokens'),
    ('Supervision', 'Self-supervised (no labels)', 'Supervised (labeled)'),
    ('Duration', 'Weeks to months', 'Hours to days'),
    ('Learning Rate', '1e-4 to 1e-3', '1e-5 to 5e-5'),
    ('Compute Cost', 'Very high ($M)', 'Low to moderate ($)'),
    ('Who Does It', 'Large organizations', 'Individual researchers/companies'),
]
comparison = pd.DataFrame(_rows, columns=['Aspect', 'Pre-Training', 'Fine-Tuning'])
print(comparison.to_string(index=False))
Key Differences:
- Scale: Pre-training uses 1000x more data and compute
- Learning Rate: Fine-tuning uses ~10x lower learning rate to preserve pre-trained knowledge
- Objective: Pre-training learns general patterns; fine-tuning specializes
- Accessibility: Pre-training requires massive resources; fine-tuning is accessible to most researchers
When to Fine-Tune
Decision framework for when fine-tuning makes sense:
def should_finetune(
    task_specific_data_size,
    task_similarity_to_pretraining,
    available_compute,
    performance_requirements
):
    """
    Decision framework for whether to fine-tune a pre-trained model.

    Args:
        task_specific_data_size: Number of labeled examples available.
        task_similarity_to_pretraining: Similarity of the task to the
            pre-training distribution, in [0, 1].
        available_compute: Available GPU hours.
        performance_requirements: Required task performance, in [0, 1].

    Returns:
        List of recommendation dicts, each with 'decision', 'reason',
        and 'confidence' keys.
    """
    recs = []

    def _add(decision, reason, confidence):
        recs.append({
            'decision': decision,
            'reason': reason,
            'confidence': confidence
        })

    # Data size drives the basic method choice.
    if task_specific_data_size < 100:
        _add('Few-shot prompting or retrieval',
             'Too few examples for effective fine-tuning',
             'high')
    elif task_specific_data_size < 1000:
        _add('Parameter-efficient fine-tuning (LoRA/Adapters)',
             'Moderate data; efficient methods prevent overfitting',
             'medium')
    else:
        _add('Full fine-tuning possible',
             'Sufficient data for full model adaptation',
             'high')

    # Task similarity: near-identical tasks may need no tuning; very
    # different tasks almost certainly do.
    if task_similarity_to_pretraining > 0.8:
        _add('Prompting may be sufficient',
             'Task very similar to pre-training; model already has needed knowledge',
             'medium')
    elif task_similarity_to_pretraining < 0.3:
        _add('Fine-tuning strongly recommended',
             'Task very different from pre-training; adaptation needed',
             'high')

    # Compute budget, measured in GPU hours.
    if available_compute < 10:
        _add('Use parameter-efficient methods or smaller models',
             'Limited compute budget',
             'high')

    # Strict performance targets generally require task-specific adaptation.
    if performance_requirements > 0.9:
        _add('Fine-tuning required',
             'High performance requirements need task-specific adaptation',
             'high')

    return recs
# Example scenarios for the decision framework above.
scenarios = [
    {
        'name': 'Medical diagnosis from reports',
        'data_size': 5000,
        'similarity': 0.4,
        'compute': 50,
        'performance': 0.95
    },
    {
        'name': 'General question answering',
        'data_size': 100,
        'similarity': 0.9,
        'compute': 5,
        'performance': 0.7
    },
    {
        'name': 'Custom chatbot for company',
        'data_size': 2000,
        'similarity': 0.6,
        'compute': 20,
        'performance': 0.85
    }
]

# Run each scenario through should_finetune and print its recommendations.
for scenario in scenarios:
    print(f"\nScenario: {scenario['name']}")
    print(f"Data: {scenario['data_size']} examples")
    recommendations = should_finetune(
        scenario['data_size'],
        scenario['similarity'],
        scenario['compute'],
        scenario['performance']
    )
    print("\nRecommendations:")
    # 1-based numbering for display.
    for i, rec in enumerate(recommendations, 1):
        print(f"{i}. {rec['decision']}")
        print(f" Reason: {rec['reason']}")
        print(f" Confidence: {rec['confidence']}")
When to Fine-Tune:
Fine-tune when:
- You have 1000+ labeled examples
- Task differs significantly from general language modeling
- High performance is critical
- Domain-specific knowledge is required (medical, legal, etc.)
Don't fine-tune when:
- You have < 100 examples (use prompting instead)
- Task is general and well-covered in pre-training
- You lack compute resources
- Prompting achieves acceptable results
Fine-Tuning Approaches Spectrum
import matplotlib.pyplot as plt
import numpy as np

# Visualize the fine-tuning spectrum: grouped bars comparing four
# normalized metrics across seven adaptation approaches.
fig, ax = plt.subplots(figsize=(12, 6))

approaches = [
    'Zero-Shot\nPrompting',
    'Few-Shot\nPrompting',
    'Adapter\nTuning',
    'LoRA',
    'Prefix\nTuning',
    'Full\nFine-Tuning',
    'Continued\nPre-Training'
]

# Metrics (normalized 0-1; values are illustrative, not measured).
trainable_params = [0, 0, 0.05, 0.1, 0.15, 1.0, 1.0]
data_required = [0, 0.05, 0.3, 0.4, 0.35, 0.7, 0.9]
compute_cost = [0.01, 0.01, 0.2, 0.25, 0.3, 1.0, 1.2]
task_adaptation = [0.3, 0.5, 0.7, 0.75, 0.7, 0.95, 1.0]

x = np.arange(len(approaches))
width = 0.2  # bar width: four bars fit within each group

# Offset each metric's bars around the group center.
ax.bar(x - 1.5*width, trainable_params, width, label='Trainable Params', alpha=0.8)
ax.bar(x - 0.5*width, data_required, width, label='Data Required', alpha=0.8)
ax.bar(x + 0.5*width, compute_cost, width, label='Compute Cost', alpha=0.8)
ax.bar(x + 1.5*width, task_adaptation, width, label='Task Adaptation', alpha=0.8)

ax.set_xlabel('Approach', fontsize=12)
ax.set_ylabel('Relative Scale (0-1)', fontsize=12)
ax.set_title('Fine-Tuning Approaches: Tradeoffs', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(approaches)
ax.legend()
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
Transfer Learning Best Practices
class TransferLearningPipeline:
    """
    Best-practices pipeline for transfer learning with LLMs.

    Bundles base-model selection, data preparation, and hyperparameter
    heuristics behind one object.
    """

    def __init__(self, base_model_name, task_type):
        """
        Args:
            base_model_name: Name of the pre-trained model to start from.
            task_type: Kind of task (classification, generation, etc.).
        """
        self.base_model_name = base_model_name
        self.task_type = task_type

    def select_base_model(self):
        """Print recommended base models for this pipeline's task type."""
        choices = {
            'classification': ['bert-base', 'roberta-base'],
            'generation': ['gpt2', 't5-base'],
            'seq2seq': ['t5-base', 'bart-base'],
            'embedding': ['sentence-transformers/all-mpnet-base-v2']
        }
        print(f"Task: {self.task_type}")
        # Unknown task types fall back to gpt2.
        print(f"Recommended models: {choices.get(self.task_type, ['gpt2'])}")

    def prepare_data(self, examples, train_split=0.8):
        """
        Shuffle examples and split them into train/validation sets.

        Best practices: stratified split for classification, random split
        for generation, minimum validation size of 500 examples or 20%.
        NOTE: shuffles the caller's list in place (original behavior).

        Args:
            examples: List of examples; mutated by the shuffle.
            train_split: Fraction of examples assigned to training.

        Returns:
            (train_data, val_data) tuple of lists.
        """
        import random
        random.shuffle(examples)
        cut = int(len(examples) * train_split)
        train_data, val_data = examples[:cut], examples[cut:]
        print(f"Train examples: {len(train_data)}")
        print(f"Val examples: {len(val_data)}")

        # Warn when the validation set falls below 500 examples or 20%.
        min_val_size = max(500, int(len(examples) * 0.2))
        if len(val_data) < min_val_size:
            print(f"Warning: Validation set too small (< {min_val_size})")
        return train_data, val_data

    def choose_learning_rate(self, model_size, data_size):
        """
        Heuristic learning-rate choice by model and dataset size.

        Rules of thumb: larger models get smaller rates (>1B params: 1e-5,
        >100M: 2e-5, otherwise 5e-5); small datasets (<1000) halve the rate.
        """
        if model_size > 1e9:
            base_lr = 1e-5
        elif model_size > 100e6:
            base_lr = 2e-5
        else:
            base_lr = 5e-5
        if data_size < 1000:
            base_lr *= 0.5
        print(f"Recommended learning rate: {base_lr:.0e}")
        return base_lr

    def set_training_hyperparameters(self, data_size):
        """
        Pick epochs, batch size, and warmup ratio from the dataset size,
        print the configuration, and return it as a dict.
        """
        # Fewer examples -> more epochs, smaller batches, longer warmup.
        if data_size < 1000:
            epochs, batch_size, warmup_ratio = 10, 8, 0.1
        elif data_size < 10000:
            epochs, batch_size, warmup_ratio = 5, 16, 0.05
        else:
            epochs, batch_size, warmup_ratio = 3, 32, 0.02

        config = {
            'epochs': epochs,
            'batch_size': batch_size,
            'warmup_ratio': warmup_ratio,
            'weight_decay': 0.01,
            'gradient_accumulation_steps': 1
        }
        print("Recommended hyperparameters:")
        for key, value in config.items():
            print(f" {key}: {value}")
        return config
# Example usage of the pipeline.
pipeline = TransferLearningPipeline('gpt2', 'generation')
pipeline.select_base_model()

# Simulate labeled data as (input, target) pairs.
examples = [('input', 'target')] * 5000
train_data, val_data = pipeline.prepare_data(examples)

# GPT-2 small has ~124M parameters.
model_size = 124e6  # GPT2 small
lr = pipeline.choose_learning_rate(model_size, len(train_data))
config = pipeline.set_training_hyperparameters(len(train_data))
Summary
Transfer learning revolutionized NLP by making powerful models accessible:
- Pre-training learns general language understanding (expensive, done once)
- Fine-tuning adapts to specific tasks (affordable, task-specific)
- When to fine-tune: 1000+ examples, task-specific needs, performance requirements
- Spectrum of approaches: From zero-shot prompting to full fine-tuning
Modern LLM development is almost always transfer learning: start with a pre-trained model and adapt it to your needs.