Fine-Tuning Techniques

Transfer Learning for LLMs

Understanding transfer learning in large language models: pre-training vs fine-tuning, when to fine-tune, and adapting pre-trained models to specific tasks.



Transfer learning is the foundation of modern NLP. Instead of training models from scratch, we leverage pre-trained language models and adapt them to specific tasks through fine-tuning.

What is Transfer Learning?

Transfer learning involves two stages:

  1. Pre-training: Learn general language understanding from massive unlabeled text
  2. Fine-tuning: Adapt the model to specific tasks with smaller labeled datasets

Key Insight:

Pre-training captures general knowledge about language and the world (grammar, facts, some reasoning ability). Fine-tuning specializes this knowledge for specific applications (classification, QA, generation).

This paradigm enables models to perform well on tasks with limited data by leveraging broad knowledge from pre-training.

Pre-Training vs Fine-Tuning

Pre-Training Phase

Pre-training uses self-supervised learning (for decoder-only models like GPT-2, next-token prediction) on billions to trillions of tokens:

python
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader

class PreTrainingDataset(Dataset):
    """Dataset for causal language modeling pre-training."""

    def __init__(self, texts, tokenizer, max_length=512):
        """
        Args:
            texts: List of text strings
            tokenizer: Tokenizer instance (must define pad_token)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = []

        for text in texts:
            tokens = tokenizer.encode(text, truncation=True, max_length=max_length)
            if len(tokens) > 1:  # Need at least 2 tokens (input + target)
                self.examples.append(tokens)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        tokens = self.examples[idx]
        # Hugging Face causal LM heads shift labels internally (position i is
        # scored against token i+1), so labels are the *unshifted* input IDs.
        # Right-pad to max_length so the default collate can stack a batch.
        # Padded label positions are -100, which the loss ignores; with causal
        # attention, real tokens never attend to the padding that follows them.
        pad_len = self.max_length - len(tokens)
        input_ids = tokens + [self.tokenizer.pad_token_id] * pad_len
        labels = tokens + [-100] * pad_len
        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels)
        }


def pretrain_language_model(model, train_loader, epochs=1, lr=1e-4):
    """
    Pre-train a language model with causal language modeling objective.

    Args:
        model: The language model
        train_loader: Training data loader
        epochs: Number of training epochs
        lr: Learning rate
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, labels=labels)
            loss = outputs.loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if (batch_idx + 1) % 100 == 0:
                avg_loss = total_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {avg_loss:.4f}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")


# Example pre-training setup
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Simulate pre-training corpus (in practice: billions of tokens)
pretrain_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is a subset of artificial intelligence.",
    "Natural language processing enables computers to understand human language.",
    # ... millions more documents
]

# Create dataset and loader
pretrain_dataset = PreTrainingDataset(pretrain_texts, tokenizer)
pretrain_loader = DataLoader(pretrain_dataset, batch_size=8, shuffle=True)

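# Load GPT-2 with its pre-trained weights. To pre-train from scratch instead,
# use a randomly initialized model: GPT2LMHeadModel(GPT2Config()),
# with GPT2Config imported from transformers
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Run the pre-training loop (simplified example; on already pre-trained
# weights this amounts to continued pre-training)
# pretrain_language_model(model, pretrain_loader, epochs=3)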
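In causal language modeling the labels come from the text itself: each position is trained to predict the token that follows it. The stdlib sketch below (the helper name and token IDs are made up for illustration) lists the (context, target) pairs that one sequence yields. Hugging Face causal LM heads such as `GPT2LMHeadModel` apply this one-position shift internally when given `labels`, so with those models `labels` is normally set to the unshifted `input_ids`.

```python
def next_token_pairs(token_ids):
    """Return the (context, target) pairs a causal LM is trained on.

    The prefix ending at position i is used to predict token i + 1,
    so a sequence of n tokens yields n - 1 training targets.
    """
    return [(token_ids[: i + 1], token_ids[i + 1]) for i in range(len(token_ids) - 1)]


# Hypothetical token IDs for a four-token sentence
for context, target in next_token_pairs([5, 17, 9, 23]):
    print(context, "->", target)
```

This prints three pairs, `[5] -> 17`, `[5, 17] -> 9` and `[5, 17, 9] -> 23`, which is also why the dataset keeps only texts with at least two tokens.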

Pre-Training is Expensive:

Real pre-training requires:

  • Compute: Thousands to hundreds of thousands of GPU-days (e.g., GPT-3: an estimated ~355 V100 GPU-years)
  • Data: Hundreds of GB to TB of text (CommonCrawl, books, Wikipedia)
  • Cost: Millions of dollars for large models

Most practitioners use existing pre-trained models rather than pre-training from scratch.
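The GPT-3 figure can be roughly reproduced with a widely used back-of-the-envelope rule: training compute is about 6 FLOPs per parameter per training token (C ≈ 6ND, covering forward and backward passes). The sketch plugs in GPT-3's published size (175B parameters, about 300B training tokens) and an assumed sustained throughput of 28 TFLOP/s per V100; that throughput is an assumption, and real hardware utilization varies widely.

```python
SECONDS_PER_YEAR = 365 * 24 * 3600


def training_flops(n_params, n_tokens):
    """Approximate training compute with the C ~ 6 * N * D rule of thumb."""
    return 6 * n_params * n_tokens


def single_gpu_years(total_flops, flops_per_second):
    """Years one GPU would need at a constant sustained throughput."""
    return total_flops / flops_per_second / SECONDS_PER_YEAR


flops = training_flops(n_params=175e9, n_tokens=300e9)    # GPT-3
years = single_gpu_years(flops, flops_per_second=28e12)   # assumed V100 throughput
print(f"~{flops:.2e} FLOPs, ~{years:.0f} GPU-years")
```

This gives about 3.15e23 FLOPs and roughly 357 single-GPU years, in line with the estimate above; dividing by the number of GPUs running in parallel gives the wall-clock time.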

Fine-Tuning Phase

Fine-tuning adapts the pre-trained model to specific tasks:

python
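import torch
from torch.utils.data import Dataset, DataLoader

class FineTuningDataset(Dataset):
    """Dataset for task-specific fine-tuning of a causal (decoder-only) LM.

    Input and target are concatenated into one sequence and the loss is
    computed only on the target tokens. (Encoder-decoder models such as
    T5 or BART instead take input and target as separate sequences.)
    """

    def __init__(self, examples, tokenizer, max_length=512):
        """
        Args:
            examples: List of (input_text, target_text) tuples
            tokenizer: Tokenizer instance (must define pad_token and eos_token)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        input_text, target_text = self.examples[idx]

        # Tokenize prompt and target separately to know where the target starts;
        # a newline separates them and EOS marks the end of the target
        prompt_ids = self.tokenizer.encode(input_text + "\n")
        target_ids = self.tokenizer.encode(target_text) + [self.tokenizer.eos_token_id]

        # Concatenate, then truncate to max_length
        input_ids = (prompt_ids + target_ids)[:self.max_length]
        # Prompt positions get label -100 so only target tokens are scored
        labels = ([-100] * len(prompt_ids) + target_ids)[:self.max_length]
        attention_mask = [1] * len(input_ids)

        # Right-pad to max_length; padding is masked out and excluded from the loss
        pad_len = self.max_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * pad_len
        attention_mask += [0] * pad_len
        labels += [-100] * pad_len

        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels)
        }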


def finetune_model(model, train_loader, val_loader, epochs=3, lr=5e-5):
    """
    Fine-tune a pre-trained model on a specific task.

    Args:
        model: Pre-trained model
        train_loader: Training data loader
        val_loader: Validation data loader
        epochs: Number of fine-tuning epochs
        lr: Learning rate (typically lower than pre-training)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0

        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                val_loss += outputs.loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_finetuned_model.pt')
            print("  Saved best model!")


# Example: Fine-tune for summarization
finetune_examples = [
    (
        "Summarize: The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
        "The Eiffel Tower is an iron tower in Paris designed by Gustave Eiffel's company."
    ),
    (
        "Summarize: Machine learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.",
        "Machine learning is an AI method that enables systems to learn from data and make decisions automatically."
    ),
    # ... more examples
]

# Create fine-tuning dataset (GPT-2 tokenizer with pad_token set, as above)
finetune_dataset = FineTuningDataset(finetune_examples, tokenizer)
finetune_loader = DataLoader(finetune_dataset, batch_size=4, shuffle=True)

# Fine-tune the model; val_loader is assumed to be built the same way
# from held-out examples
# finetune_model(model, finetune_loader, val_loader, epochs=3, lr=5e-5)
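When a causal LM is fine-tuned on (input, target) pairs, a common convention is to give prompt and padding positions the label `-100`, the default `ignore_index` of PyTorch's `CrossEntropyLoss` (which Hugging Face models also rely on), so that only target tokens are scored. The stdlib sketch below uses a hypothetical helper and made-up probabilities to show what that masking does to the averaged loss.

```python
import math

IGNORE_INDEX = -100  # label value excluded from the loss


def masked_mean_nll(target_probs, labels):
    """Mean negative log-likelihood over supervised positions only.

    target_probs[i] is the probability the model assigned to the correct
    token at position i; positions labeled IGNORE_INDEX are skipped.
    """
    losses = [-math.log(p) for p, y in zip(target_probs, labels) if y != IGNORE_INDEX]
    if not losses:
        return float("nan")  # nothing supervised: the mean is undefined
    return sum(losses) / len(losses)


# Two prompt positions (ignored) followed by two target tokens
probs = [0.9, 0.8, 0.25, 0.5]
labels = [IGNORE_INDEX, IGNORE_INDEX, 314, 1049]
print(masked_mean_nll(probs, labels))             # mean of -log(0.25) and -log(0.5)
print(masked_mean_nll(probs, [1, 2, 314, 1049]))  # unmasked: easy prompt tokens pull the average down
```

Without the mask, the loss would also reward the model for reproducing the prompt, which is not the task being taught.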

Comparison: Pre-Training vs Fine-Tuning

python
import pandas as pd

comparison = pd.DataFrame({
    'Aspect': [
        'Objective',
        'Data',
        'Data Size',
        'Supervision',
        'Duration',
        'Learning Rate',
        'Compute Cost',
        'Who Does It'
    ],
    'Pre-Training': [
        'Learn general language understanding',
        'Raw text (web, books, etc.)',
        'Billions of tokens',
        'Self-supervised (no labels)',
        'Weeks to months',
        '1e-4 to 1e-3',
        'Very high ($M)',
        'Large organizations'
    ],
    'Fine-Tuning': [
        'Adapt to specific task',
        'Task-specific labeled data',
        'Thousands to millions of tokens',
        'Supervised (labeled)',
        'Hours to days',
        '1e-5 to 5e-5',
        'Low to moderate ($)',
        'Individual researchers/companies'
    ]
})

print(comparison.to_string(index=False))

Key Differences:

  1. Scale: Pre-training typically uses orders of magnitude (often 1000x or more) more data and compute
  2. Learning Rate: Fine-tuning uses ~10x lower learning rate to preserve pre-trained knowledge
  3. Objective: Pre-training learns general patterns; fine-tuning specializes
  4. Accessibility: Pre-training requires massive resources; fine-tuning is accessible to most researchers

When to Fine-Tune

Decision framework for when fine-tuning makes sense:

python
def should_finetune(
    task_specific_data_size,
    task_similarity_to_pretraining,
    available_compute,
    performance_requirements
):
    """
    Decision framework for whether to fine-tune.

    Args:
        task_specific_data_size: Amount of labeled data (examples)
        task_similarity_to_pretraining: How similar task is to pre-training (0-1)
        available_compute: Available GPU hours
        performance_requirements: Required task performance (0-1)

    Returns:
        List of recommendation dicts with 'decision', 'reason', and 'confidence' keys
    """
    recommendations = []

    # Check data size
    if task_specific_data_size < 100:
        recommendations.append({
            'decision': 'Few-shot prompting or retrieval',
            'reason': 'Too few examples for effective fine-tuning',
            'confidence': 'high'
        })
    elif task_specific_data_size < 1000:
        recommendations.append({
            'decision': 'Parameter-efficient fine-tuning (LoRA/Adapters)',
            'reason': 'Moderate data; efficient methods prevent overfitting',
            'confidence': 'medium'
        })
    else:
        recommendations.append({
            'decision': 'Full fine-tuning possible',
            'reason': 'Sufficient data for full model adaptation',
            'confidence': 'high'
        })

    # Check task similarity
    if task_similarity_to_pretraining > 0.8:
        recommendations.append({
            'decision': 'Prompting may be sufficient',
            'reason': 'Task very similar to pre-training; model already has needed knowledge',
            'confidence': 'medium'
        })
    elif task_similarity_to_pretraining < 0.3:
        recommendations.append({
            'decision': 'Fine-tuning strongly recommended',
            'reason': 'Task very different from pre-training; adaptation needed',
            'confidence': 'high'
        })

    # Check compute constraints
    if available_compute < 10:  # GPU hours
        recommendations.append({
            'decision': 'Use parameter-efficient methods or smaller models',
            'reason': 'Limited compute budget',
            'confidence': 'high'
        })

    # Check performance requirements
    if performance_requirements > 0.9:
        recommendations.append({
            'decision': 'Fine-tuning required',
            'reason': 'High performance requirements need task-specific adaptation',
            'confidence': 'high'
        })

    return recommendations


# Example scenarios
scenarios = [
    {
        'name': 'Medical diagnosis from reports',
        'data_size': 5000,
        'similarity': 0.4,
        'compute': 50,
        'performance': 0.95
    },
    {
        'name': 'General question answering',
        'data_size': 100,
        'similarity': 0.9,
        'compute': 5,
        'performance': 0.7
    },
    {
        'name': 'Custom chatbot for company',
        'data_size': 2000,
        'similarity': 0.6,
        'compute': 20,
        'performance': 0.85
    }
]

for scenario in scenarios:
    print(f"\nScenario: {scenario['name']}")
    print(f"Data: {scenario['data_size']} examples")
    recommendations = should_finetune(
        scenario['data_size'],
        scenario['similarity'],
        scenario['compute'],
        scenario['performance']
    )

    print("\nRecommendations:")
    for i, rec in enumerate(recommendations, 1):
        print(f"{i}. {rec['decision']}")
        print(f"   Reason: {rec['reason']}")
        print(f"   Confidence: {rec['confidence']}")

When to Fine-Tune:

Fine-tune when:

  • You have 1000+ labeled examples
  • Task differs significantly from general language modeling
  • High performance is critical
  • Domain-specific knowledge is required (medical, legal, etc.)

Don't fine-tune when:

  • You have < 100 examples (use prompting instead)
  • Task is general and well-covered in pre-training
  • You lack compute resources
  • Prompting achieves acceptable results
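When only a handful of labeled examples exist, they can go into the prompt instead of into a training run. A minimal sketch of assembling a few-shot prompt from (input, output) pairs; the helper name and the `Summarize:` / `Summary:` template are hypothetical choices, and the resulting string would be passed to whichever model you are prompting.

```python
def build_few_shot_prompt(demonstrations, query,
                          instruction="Summarize:", answer_label="Summary:"):
    """Format solved demonstrations, then the new query with its answer left blank."""
    blocks = [f"{instruction} {source}\n{answer_label} {target}"
              for source, target in demonstrations]
    blocks.append(f"{instruction} {query}\n{answer_label}")  # the model continues from here
    return "\n\n".join(blocks)


demos = [
    ("The Eiffel Tower is a wrought-iron lattice tower in Paris.", "An iron tower in Paris."),
    ("Machine learning lets systems learn patterns from data.", "ML learns from data."),
]
print(build_few_shot_prompt(demos, "Transfer learning adapts pre-trained models to new tasks."))
```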

Fine-Tuning Approaches Spectrum

python
import matplotlib.pyplot as plt
import numpy as np

# Create visualization of fine-tuning spectrum
fig, ax = plt.subplots(figsize=(12, 6))

approaches = [
    'Zero-Shot\nPrompting',
    'Few-Shot\nPrompting',
    'Adapter\nTuning',
    'LoRA',
    'Prefix\nTuning',
    'Full\nFine-Tuning',
    'Continued\nPre-Training'
]

# Illustrative relative scores, not measured values (full fine-tuning = 1.0)
trainable_params = [0, 0, 0.05, 0.1, 0.15, 1.0, 1.0]
data_required = [0, 0.05, 0.3, 0.4, 0.35, 0.7, 0.9]
compute_cost = [0.01, 0.01, 0.2, 0.25, 0.3, 1.0, 1.2]
task_adaptation = [0.3, 0.5, 0.7, 0.75, 0.7, 0.95, 1.0]

x = np.arange(len(approaches))
width = 0.2

ax.bar(x - 1.5*width, trainable_params, width, label='Trainable Params', alpha=0.8)
ax.bar(x - 0.5*width, data_required, width, label='Data Required', alpha=0.8)
ax.bar(x + 0.5*width, compute_cost, width, label='Compute Cost', alpha=0.8)
ax.bar(x + 1.5*width, task_adaptation, width, label='Task Adaptation', alpha=0.8)

ax.set_xlabel('Approach', fontsize=12)
ax.set_ylabel('Relative Scale (illustrative)', fontsize=12)
ax.set_title('Fine-Tuning Approaches: Tradeoffs', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(approaches)
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()
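To put rough numbers on the trainable-parameter axis: LoRA freezes a weight matrix W (d_in × d_out) and learns a low-rank update from two small matrices, A (r × d_in) and B (d_out × r), adding r·(d_in + d_out) trainable parameters per adapted matrix. The arithmetic below assumes rank r = 8 applied only to GPT-2 small's fused query/key/value projection (768 → 2304) in each of its 12 layers, against GPT-2 small's roughly 124M parameters; other placements and ranks give different counts.

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one frozen d_in x d_out weight."""
    a_params = rank * d_in   # A: projects the input down to `rank` dimensions
    b_params = d_out * rank  # B: projects back up to the output size
    return a_params + b_params


# Assumptions: GPT-2 small, 12 layers, rank 8 on the fused QKV projection only
n_layers, d_model, rank = 12, 768, 8
lora_total = n_layers * lora_params(d_model, 3 * d_model, rank)
full_total = 124e6  # approximate GPT-2 small parameter count

print(f"LoRA trainable params: {lora_total:,}")
print(f"Share of full fine-tuning: {lora_total / full_total:.2%}")
```

Under these assumptions LoRA trains 294,912 parameters, about 0.24% of what full fine-tuning updates.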

Transfer Learning Best Practices

python
class TransferLearningPipeline:
    """
    Best practices pipeline for transfer learning with LLMs.
    """

    def __init__(self, base_model_name, task_type):
        """
        Args:
            base_model_name: Pre-trained model to use
            task_type: Type of task (classification, generation, etc.)
        """
        self.base_model_name = base_model_name
        self.task_type = task_type

    def select_base_model(self):
        """Select appropriate base model for task."""
        recommendations = {
            'classification': ['bert-base-uncased', 'roberta-base'],
            'generation': ['gpt2', 't5-base'],
            'seq2seq': ['t5-base', 'facebook/bart-base'],
            'embedding': ['sentence-transformers/all-mpnet-base-v2']
        }

        print(f"Task: {self.task_type}")
        print(f"Recommended models: {recommendations.get(self.task_type, ['gpt2'])}")

    def prepare_data(self, examples, train_split=0.8, seed=42):
        """
        Prepare data with a shuffled train/val split.

        Best practices:
        - Stratified split for classification (not implemented in this simplified version)
        - Random split for generation
        - Validation size: at least 500 examples or 20% of the data, whichever is larger
        """
        import random
        examples = list(examples)  # Copy so the caller's list is not shuffled in place
        random.Random(seed).shuffle(examples)  # Seeded for a reproducible split

        split_idx = int(len(examples) * train_split)
        train_data = examples[:split_idx]
        val_data = examples[split_idx:]

        print(f"Train examples: {len(train_data)}")
        print(f"Val examples: {len(val_data)}")

        # Warn if the validation set is below the recommended minimum
        min_val_size = max(500, int(len(examples) * 0.2))
        if len(val_data) < min_val_size:
            print(f"Warning: Validation set too small (< {min_val_size})")

        return train_data, val_data

    def choose_learning_rate(self, model_size, data_size):
        """
        Choose appropriate learning rate.

        Rules of thumb:
        - Smaller models (<=100M params): 5e-5
        - Medium models (100M-1B params): 2e-5
        - Larger models (>1B params): 1e-5
        - Less data (<1,000 examples): halve the rate
        """
        if model_size > 1e9:  # > 1B parameters
            base_lr = 1e-5
        elif model_size > 100e6:  # > 100M parameters
            base_lr = 2e-5
        else:
            base_lr = 5e-5

        # Adjust for data size
        if data_size < 1000:
            base_lr *= 0.5

        print(f"Recommended learning rate: {base_lr:.1e}")
        return base_lr

    def set_training_hyperparameters(self, data_size):
        """
        Set training hyperparameters based on data size.
        """
        if data_size < 1000:
            epochs = 10
            batch_size = 8
            warmup_ratio = 0.1
        elif data_size < 10000:
            epochs = 5
            batch_size = 16
            warmup_ratio = 0.05
        else:
            epochs = 3
            batch_size = 32
            warmup_ratio = 0.02

        config = {
            'epochs': epochs,
            'batch_size': batch_size,
            'warmup_ratio': warmup_ratio,
            'weight_decay': 0.01,
            'gradient_accumulation_steps': 1
        }

        print("Recommended hyperparameters:")
        for key, value in config.items():
            print(f"  {key}: {value}")

        return config


# Example usage
pipeline = TransferLearningPipeline('gpt2', 'generation')
pipeline.select_base_model()

# Simulate data
examples = [('input', 'target')] * 5000
train_data, val_data = pipeline.prepare_data(examples)

model_size = 124e6  # GPT2 small
lr = pipeline.choose_learning_rate(model_size, len(train_data))
config = pipeline.set_training_hyperparameters(len(train_data))
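`warmup_ratio` is the fraction of optimizer steps during which the learning rate ramps up from zero. A stdlib sketch of the linear warmup-then-linear-decay shape that `transformers.get_linear_schedule_with_warmup` implements, sized with the pipeline example's numbers (4,000 training examples, batch size 16, 5 epochs, warmup ratio 0.05, peak learning rate 2e-5). Rounding warmup steps up and ignoring gradient accumulation are simplifying assumptions of this sketch.

```python
import math


def lr_multiplier(step, warmup_steps, total_steps):
    """Linear ramp from 0 to 1 over warmup, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def schedule_lengths(num_examples, batch_size, epochs, warmup_ratio):
    """Total optimizer steps and warmup steps (no gradient accumulation)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return total_steps, warmup_steps


total_steps, warmup_steps = schedule_lengths(4000, batch_size=16, epochs=5, warmup_ratio=0.05)
peak_lr = 2e-5
for step in (0, warmup_steps // 2, warmup_steps, total_steps // 2, total_steps):
    lr = peak_lr * lr_multiplier(step, warmup_steps, total_steps)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

With these numbers there are 1,250 optimizer steps, of which the first 63 are warmup; the learning rate peaks at step 63 and reaches zero at the final step.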

Summary

Transfer learning revolutionized NLP by making powerful models accessible:

  1. Pre-training learns general language understanding (expensive, done once)
  2. Fine-tuning adapts to specific tasks (affordable, task-specific)
  3. When to fine-tune: 1000+ examples, task-specific needs, performance requirements
  4. Spectrum of approaches: From zero-shot prompting to full fine-tuning

Modern LLM development is almost always transfer learning: start with a pre-trained model and adapt it to your needs.