Advanced Fine-Tuning

DPO: Direct Preference Optimization

DPO simplifies RLHF by directly optimizing for preferences without reward models or RL. Learn the theory, implementation, and advantages over traditional RLHF.

25 min read · DPO · Preference Learning · RLHF · Alignment


DPO (Direct Preference Optimization) is a simpler, more stable alternative to RLHF that directly optimizes language models from preference data without reward models or reinforcement learning.

The Problem with RLHF

RLHF's complexity creates challenges:

python
class RLHFvsDPO:
    """
    Compare RLHF and DPO approaches.
    """

    def compare_pipelines(self):
        """Compare RLHF vs DPO pipelines."""
        print("RLHF Pipeline:")
        print("  1. Train SFT model")
        print("  2. Collect preference data")
        print("  3. Train reward model")
        print("  4. Use PPO to optimize policy")
        print("\n  Challenges:")
        print("    - Reward model can be inaccurate")
        print("    - PPO training is unstable")
        print("    - Typically 4 models in memory (policy, value, reward, reference)")
        print("    - Hyperparameter sensitive")
        print("    - Reward hacking possible")
        print()

        print("DPO Pipeline:")
        print("  1. Train SFT model")
        print("  2. Collect preference data")
        print("  3. Directly optimize policy on preferences")
        print("\n  Advantages:")
        print("    - No reward model needed")
        print("    - No RL needed")
        print("    - Only 2 models in memory (policy + reference)")
        print("    - Simpler, more stable")
        print("    - A single supervised loss on preference pairs")

comparer = RLHFvsDPO()
comparer.compare_pipelines()

Key Insight:

RLHF's reward model is only an intermediate step: it exists to train the policy. DPO asks: can we skip the reward model and optimize for preferences directly?

Answer: Yes. DPO reparameterizes the reward in terms of the policy itself, which turns the RLHF objective into a simple loss on preference pairs.

DPO Theory

From RLHF to DPO

RLHF objective:

max_π_θ E_{x~D, y~π_θ(·|x)}[r(x,y)] - β KL(π_θ(·|x) || π_ref(·|x))

where:
- D: dataset of prompts
- r(x,y): reward model score for response y to prompt x
- π_θ: policy being optimized
- π_ref: reference policy (SFT model)
- β: KL coefficient (larger β keeps π_θ closer to π_ref)

DPO insight: the optimal policy for this objective has a closed form:

π*(y|x) = 1/Z(x) * π_ref(y|x) * exp(r(x,y)/β)

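where Z(x) = Σ_y π_ref(y|x) exp(r(x,y)/β) is the partition function that normalizes π* into a probability distribution. Because Z(x) sums over every possible response y, it is intractable to compute for a language model.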

Taking the log of both sides and rearranging expresses the reward in terms of the optimal and reference policies:

r(x,y) = β log(π*(y|x) / π_ref(y|x)) + β log Z(x)
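Both identities can be checked numerically on a toy problem. The sketch below assumes a hypothetical prompt with only three candidate responses, so Z(x) can be summed exactly; the probabilities and rewards are made up for illustration. It builds π* from π_ref and r, recovers r from the policy ratio, and spot-checks that π* scores at least as well on the KL-regularized objective as randomly drawn policies.

```python
import math
import random

def optimal_policy(pi_ref, rewards, beta):
    """Closed-form optimum pi*(y) = pi_ref(y) * exp(r(y)/beta) / Z over a toy response set."""
    weights = [p * math.exp(r / beta) for p, r in zip(pi_ref, rewards)]
    Z = sum(weights)  # partition function Z(x), tractable only because the set is tiny
    return [w / Z for w in weights], Z

def objective(pi, pi_ref, rewards, beta):
    """E_{y~pi}[r(y)] - beta * KL(pi || pi_ref) for a single prompt."""
    expected_reward = sum(p * r for p, r in zip(pi, rewards))
    kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_ref) if p > 0)
    return expected_reward - beta * kl

beta = 0.5
pi_ref = [0.5, 0.3, 0.2]     # hypothetical reference probabilities
rewards = [1.0, 2.0, -1.0]   # hypothetical reward scores r(x, y)

pi_star, Z = optimal_policy(pi_ref, rewards, beta)

# Reparameterization: r = beta * log(pi*/pi_ref) + beta * log Z
recovered = [beta * math.log(ps / pr) + beta * math.log(Z)
             for ps, pr in zip(pi_star, pi_ref)]

# pi* should score at least as well as any other distribution; spot-check a few
random.seed(0)
for _ in range(5):
    raw = [random.random() for _ in pi_ref]
    other = [v / sum(raw) for v in raw]
    assert objective(other, pi_ref, rewards, beta) <= objective(pi_star, pi_ref, rewards, beta) + 1e-12

print([round(r, 6) for r in recovered])  # [1.0, 2.0, -1.0]
```

At the optimum, the objective evaluates to β log Z(x), which is another way to see that Z(x) carries no information about which response is better.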

DPO Loss Function

Using the Bradley-Terry preference model:

P(y_w > y_l | x) = σ(r(x,y_w) - r(x,y_l))

where:
- y_w: winning (chosen) response
- y_l: losing (rejected) response
- σ: sigmoid function
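A property of this model makes the next step work: the preference probability depends only on the reward difference between two responses to the same prompt. Any prompt-dependent constant added to both rewards therefore drops out. The reward values below are hypothetical, and `shift` stands in for a term like β log Z(x).

```python
import math

def bt_preference_prob(r_chosen: float, r_rejected: float) -> float:
    """Bradley-Terry: P(y_w > y_l | x) = sigmoid(r(x, y_w) - r(x, y_l))."""
    return 1.0 / (1.0 + math.exp(-(r_chosen - r_rejected)))

# Hypothetical reward scores for two responses to the same prompt
r_w, r_l = 1.5, 0.5
p = bt_preference_prob(r_w, r_l)

# Adding the same prompt-dependent constant to both rewards changes nothing
shift = 3.7
p_shifted = bt_preference_prob(r_w + shift, r_l + shift)

print(round(p, 4), round(p_shifted, 4))  # 0.7311 0.7311
```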

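Substitute the reparameterized reward into the Bradley-Terry model, with the trainable policy π_θ in place of π*. The β log Z(x) terms cancel because both responses share the same prompt x. This leaves a loss that depends only on π_θ and π_ref: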

python
import torch
import torch.nn as nn
import torch.nn.functional as F

def dpo_loss_formula():
    """
    Explain DPO loss mathematically.
    """
    print("DPO Loss Formula:")
    print()
    print("L_DPO(π_θ; π_ref) = -E[(x,y_w,y_l) ~ D] [")
    print("  log σ(β log(π_θ(y_w|x)/π_ref(y_w|x)) - β log(π_θ(y_l|x)/π_ref(y_l|x)))")
    print("]")
    print()
    print("Where:")
    print("  - π_θ: policy being optimized")
    print("  - π_ref: reference policy (frozen SFT model)")
    print("  - y_w: chosen response")
    print("  - y_l: rejected response")
    print("  - β: temperature parameter")
    print("  - σ: sigmoid function")
    print()
    print("Intuition:")
    print("  Increase probability ratio π_θ/π_ref for chosen responses")
    print("  Decrease probability ratio π_θ/π_ref for rejected responses")
    print("  β controls how much policy can deviate from reference")

dpo_loss_formula()

DPO Loss Intuition:

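The loss encourages the policy to:

  1. Increase likelihood of chosen responses (y_w) relative to reference
  2. Decrease likelihood of rejected responses (y_l) relative to reference
  3. Focus on pairs it currently ranks wrongly. With the implicit reward r̂(x,y) = β log(π_θ(y|x)/π_ref(y|x)), each pair's gradient is weighted by σ(r̂(x,y_l) - r̂(x,y_w)), and β scales how strongly the margin is pushed

Only the difference between the two log-ratios enters the loss. Mathematically, both likelihoods can fall as long as the gap between them widens. No reward model is ever queried: r̂ plays the role of the reward.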
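A minimal numeric sketch of the loss and its per-pair gradient weight, using made-up summed log-probs in place of real model outputs (every tensor value here is hypothetical). Autograd confirms that the gradient on each chosen log-prob is -β σ(-logit) / batch_size. Pairs with a negative logit, meaning wrongly ranked, receive the larger weight.

```python
import torch
import torch.nn.functional as F

beta = 0.1

# Hypothetical summed log-probs log pi(y|x) for a batch of 3 preference pairs
policy_chosen = torch.tensor([-12.0, -20.0, -15.0], requires_grad=True)
policy_rejected = torch.tensor([-14.0, -18.0, -15.0], requires_grad=True)
ref_chosen = torch.tensor([-13.0, -19.0, -15.0])
ref_rejected = torch.tensor([-13.0, -19.0, -15.0])

# Implicit rewards: beta * log(pi_theta / pi_ref)
chosen_reward = beta * (policy_chosen - ref_chosen)
rejected_reward = beta * (policy_rejected - ref_rejected)
logits = chosen_reward - rejected_reward

loss = -F.logsigmoid(logits).mean()
loss.backward()

# Per-pair weight sigma(r_rejected - r_chosen): larger when the pair is ranked wrongly
weight = torch.sigmoid(-logits).detach()
print("logits:", logits.detach())
print("weights:", weight)
print("grad wrt chosen log-probs:", policy_chosen.grad)  # -beta * weight / 3
```

The second pair starts out ranked the wrong way (negative logit), so it receives the largest weight. The third pair is tied and sits at the neutral weight of 0.5.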

DPO Implementation

Complete DPO Trainer

python
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from typing import List, Dict
import torch
import torch.nn.functional as F  # used by get_log_probs and compute_dpo_loss

@dataclass
class PreferenceExample:
    """Preference data example."""
    prompt: str
    chosen: str
    rejected: str


class DPODataset(Dataset):
    """Dataset for DPO training."""

    def __init__(
        self,
        examples: List[PreferenceExample],
        tokenizer,
        max_length: int = 512
    ):
        """
        Args:
            examples: List of preference examples
            tokenizer: Tokenizer
            max_length: Maximum sequence length
        """
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """
        Tokenize prompt, chosen, and rejected responses.
        """
        example = self.examples[idx]

        # Combine prompt with responses
        chosen_text = f"{example.prompt}\n{example.chosen}"
        rejected_text = f"{example.prompt}\n{example.rejected}"

        # Tokenize
        chosen_tokens = self.tokenizer(
            chosen_text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        rejected_tokens = self.tokenizer(
            rejected_text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'chosen_input_ids': chosen_tokens['input_ids'].squeeze(),
            'chosen_attention_mask': chosen_tokens['attention_mask'].squeeze(),
            'rejected_input_ids': rejected_tokens['input_ids'].squeeze(),
            'rejected_attention_mask': rejected_tokens['attention_mask'].squeeze(),
        }


class DPOTrainer:
    """
    Direct Preference Optimization trainer.

    Simpler alternative to RLHF - no reward model or RL needed!
    """

    def __init__(
        self,
        model_name: str,
        beta: float = 0.1,
        use_lora: bool = True,
        lora_rank: int = 8
    ):
        """
        Args:
            model_name: Base model (should be SFT model)
            beta: DPO temperature parameter
            use_lora: Whether to use LoRA
            lora_rank: LoRA rank if using LoRA
        """
        self.beta = beta
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load policy model (will be trained)
        self.policy_model = AutoModelForCausalLM.from_pretrained(model_name)

        # Apply LoRA if requested
        if use_lora:
            from peft import LoraConfig, get_peft_model

            lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=16,
                target_modules=["q_proj", "v_proj"],  # LLaMA-style names; GPT-2 uses "c_attn"
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.policy_model = get_peft_model(self.policy_model, lora_config)
            self.policy_model.print_trainable_parameters()

        # Load reference model (frozen)
        self.ref_model = AutoModelForCausalLM.from_pretrained(model_name)
        for param in self.ref_model.parameters():
            param.requires_grad = False

        # Move to device
        self.policy_model.to(self.device)
        self.ref_model.to(self.device)

    def get_log_probs(
        self,
        model,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Get log probabilities of sequences under model.

        Args:
            model: Language model
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask

        Returns:
            Log probability of each sequence
        """
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Shift logits and labels for next-token prediction
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_attention_mask = attention_mask[:, 1:].contiguous()

        # Get log probabilities
        log_probs = F.log_softmax(shift_logits, dim=-1)

        # Gather log probs for actual tokens
        # Shape: (batch, seq_len - 1)
        token_log_probs = torch.gather(
            log_probs,
            dim=2,
            index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)

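        # Mask padding tokens
        token_log_probs = token_log_probs * shift_attention_mask

        # Sum (not average) over non-padding tokens: log π(y|x) = Σ_t log π(y_t | x, y_<t).
        # Prompt tokens are included too; assuming right padding and no dropout,
        # their terms are identical for the chosen and rejected sequences and
        # cancel in the DPO margin.
        seq_log_probs = token_log_probs.sum(dim=1)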

        return seq_log_probs

    def compute_dpo_loss(
        self,
        chosen_input_ids: torch.Tensor,
        chosen_attention_mask: torch.Tensor,
        rejected_input_ids: torch.Tensor,
        rejected_attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Compute DPO loss.

        L = -E[log σ(β * log(π/π_ref)_chosen - β * log(π/π_ref)_rejected)]

        Args:
            chosen_input_ids: Chosen response token IDs
            chosen_attention_mask: Chosen response attention mask
            rejected_input_ids: Rejected response token IDs
            rejected_attention_mask: Rejected response attention mask

        Returns:
            Dict with loss and metrics
        """
        # Get log probs from policy model
        policy_chosen_log_probs = self.get_log_probs(
            self.policy_model, chosen_input_ids, chosen_attention_mask
        )
        policy_rejected_log_probs = self.get_log_probs(
            self.policy_model, rejected_input_ids, rejected_attention_mask
        )

        # Get log probs from reference model
        with torch.no_grad():
            ref_chosen_log_probs = self.get_log_probs(
                self.ref_model, chosen_input_ids, chosen_attention_mask
            )
            ref_rejected_log_probs = self.get_log_probs(
                self.ref_model, rejected_input_ids, rejected_attention_mask
            )

        # Compute log ratios: log(π_θ / π_ref)
        chosen_log_ratio = policy_chosen_log_probs - ref_chosen_log_probs
        rejected_log_ratio = policy_rejected_log_probs - ref_rejected_log_probs

        # DPO loss: -log σ(β * (log_ratio_chosen - log_ratio_rejected))
        logits = self.beta * (chosen_log_ratio - rejected_log_ratio)
        loss = -F.logsigmoid(logits).mean()

        # Compute metrics
        with torch.no_grad():
            # Implicit reward
            chosen_rewards = self.beta * chosen_log_ratio
            rejected_rewards = self.beta * rejected_log_ratio

            # Accuracy: how often chosen > rejected
            accuracy = (chosen_rewards > rejected_rewards).float().mean()

        return {
            'loss': loss,
            'chosen_rewards': chosen_rewards.mean(),
            'rejected_rewards': rejected_rewards.mean(),
            'accuracy': accuracy
        }

    def train(
        self,
        train_examples: List[PreferenceExample],
        val_examples: List[PreferenceExample],
        epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 5e-7
    ):
        """
        Train policy with DPO.

        Args:
            train_examples: Training preference data
            val_examples: Validation preference data
            epochs: Number of epochs
            batch_size: Batch size
            learning_rate: Learning rate (typically lower than SFT)
        """
        # Create datasets
        train_dataset = DPODataset(train_examples, self.tokenizer)
        val_dataset = DPODataset(val_examples, self.tokenizer)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        # Optimizer
        optimizer = torch.optim.AdamW(
            [p for p in self.policy_model.parameters() if p.requires_grad],
            lr=learning_rate
        )

        best_val_loss = float('inf')

        for epoch in range(epochs):
            # Training
            self.policy_model.train()
            train_metrics = {
                'loss': 0,
                'chosen_rewards': 0,
                'rejected_rewards': 0,
                'accuracy': 0
            }

            num_batches = 0

            for batch in train_loader:
                # Move to device
                chosen_input_ids = batch['chosen_input_ids'].to(self.device)
                chosen_attention_mask = batch['chosen_attention_mask'].to(self.device)
                rejected_input_ids = batch['rejected_input_ids'].to(self.device)
                rejected_attention_mask = batch['rejected_attention_mask'].to(self.device)

                # Compute DPO loss
                metrics = self.compute_dpo_loss(
                    chosen_input_ids,
                    chosen_attention_mask,
                    rejected_input_ids,
                    rejected_attention_mask
                )

                loss = metrics['loss']

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    [p for p in self.policy_model.parameters() if p.requires_grad],
                    max_norm=1.0
                )

                optimizer.step()

                # Accumulate metrics
                for key in train_metrics:
                    train_metrics[key] += metrics[key].item()
                num_batches += 1

            # Average metrics
            for key in train_metrics:
                train_metrics[key] /= num_batches

            # Validation
            val_metrics = self.validate(val_loader)

            print(f"\nEpoch {epoch+1}/{epochs}")
            print(f"  Train Loss: {train_metrics['loss']:.4f}")
            print(f"  Train Accuracy: {train_metrics['accuracy']:.2%}")
            print(f"  Train Chosen Rewards: {train_metrics['chosen_rewards']:.4f}")
            print(f"  Train Rejected Rewards: {train_metrics['rejected_rewards']:.4f}")
            print(f"  Val Loss: {val_metrics['loss']:.4f}")
            print(f"  Val Accuracy: {val_metrics['accuracy']:.2%}")

            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                self.save_model('best_dpo_model')
                print("  Saved best model!")

    def validate(self, val_loader):
        """Validate the model."""
        self.policy_model.eval()

        val_metrics = {
            'loss': 0,
            'chosen_rewards': 0,
            'rejected_rewards': 0,
            'accuracy': 0
        }
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                chosen_input_ids = batch['chosen_input_ids'].to(self.device)
                chosen_attention_mask = batch['chosen_attention_mask'].to(self.device)
                rejected_input_ids = batch['rejected_input_ids'].to(self.device)
                rejected_attention_mask = batch['rejected_attention_mask'].to(self.device)

                metrics = self.compute_dpo_loss(
                    chosen_input_ids,
                    chosen_attention_mask,
                    rejected_input_ids,
                    rejected_attention_mask
                )

                for key in val_metrics:
                    val_metrics[key] += metrics[key].item()
                num_batches += 1

        for key in val_metrics:
            val_metrics[key] /= num_batches

        return val_metrics

    def save_model(self, path: str):
        """Save the DPO-trained model."""
        self.policy_model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)


# Example usage
print("Creating DPO trainer...")
# trainer = DPOTrainer("gpt2", beta=0.1, use_lora=False)  # GPT-2 has no q_proj/v_proj modules for LoRA

# Example preference data
example_preferences = [
    PreferenceExample(
        prompt="Explain gravity to a child.",
        chosen="Gravity is like an invisible force that pulls things down to Earth! It's why when you drop a ball, it falls to the ground instead of floating away. Everything with mass has gravity - even you have a tiny bit! The Earth is so big and heavy that its gravity is strong enough to keep us and everything else from floating into space.",
        rejected="Gravity is a fundamental force described by Einstein's general relativity as the curvature of spacetime caused by mass-energy. The gravitational field strength is proportional to mass and inversely proportional to the square of the distance."
    ),
    # ... more examples
]

# trainer.train(train_examples, val_examples, epochs=3)
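The trainer above scores every non-padding token, including the prompt. A common refinement is to score only the response tokens, marking the prompt and padding with the label -100. This is the ignore value PyTorch's `CrossEntropyLoss` uses by default. The helpers `build_labels` and `response_log_probs` below are hypothetical, not part of the trainer. Random logits stand in for a model's output, and the per-example prompt lengths are assumed to come from tokenizing the prompt separately.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value for tokens excluded from the log-prob sum

def build_labels(input_ids, attention_mask, prompt_lens):
    """Copy input_ids, masking prompt positions and padding so only response tokens count."""
    labels = input_ids.clone()
    positions = torch.arange(input_ids.size(1), device=input_ids.device)
    in_prompt = positions.unsqueeze(0) < prompt_lens.unsqueeze(1)
    labels[in_prompt | (attention_mask == 0)] = IGNORE_INDEX
    return labels

def response_log_probs(logits, labels):
    """Sum of log pi(y_t | x, y_<t) over positions whose label is not IGNORE_INDEX."""
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    mask = shift_labels != IGNORE_INDEX
    safe_labels = shift_labels.masked_fill(~mask, 0)  # any valid index; zeroed out below
    token_log_probs = torch.gather(
        F.log_softmax(shift_logits, dim=-1), 2, safe_labels.unsqueeze(-1)
    ).squeeze(-1)
    return (token_log_probs * mask).sum(dim=1)

# Toy check: 2 prompt tokens, 2 response tokens, 2 right-padding tokens
torch.manual_seed(0)
vocab, seq_len = 11, 6
input_ids = torch.randint(0, vocab, (1, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
prompt_lens = torch.tensor([2])
logits = torch.randn(1, seq_len, vocab)

labels = build_labels(input_ids, attention_mask, prompt_lens)
print(response_log_probs(logits, labels))
```

With this masking, the implicit rewards reported during training reflect only the responses, and the loss no longer relies on the prompt terms cancelling.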

DPO Hyperparameters:

  • β (beta): Temperature / KL-penalty strength (0.1 - 0.5)

    • Higher β: More conservative, stays closer to reference
    • Lower β: Allows larger deviation from reference (more aggressive optimization)
    • Start with β=0.1
  • Learning rate: Lower than SFT (1e-7 to 5e-6)

    • Too high: Unstable, policy diverges
    • Too low: Slow convergence
  • Batch size: Larger is generally better (4-16 here; gradient accumulation can raise the effective batch size)

    • More stable gradient estimates from averaging over more preference pairs
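The direction of β's effect can be read off the loss itself. The loss depends on the log-ratio margin (chosen minus rejected) only through the product β × margin. To reach a given loss, a larger β therefore needs a smaller margin, which means less movement of π_θ away from π_ref. The margins below are hypothetical values chosen for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical margins: log(pi/pi_ref)_chosen - log(pi/pi_ref)_rejected, in nats
margins = torch.tensor([0.0, 1.0, 5.0, 20.0])

for beta in (0.05, 0.1, 0.5):
    loss = -F.logsigmoid(beta * margins)
    print(f"beta={beta}: " + ", ".join(f"{v:.3f}" for v in loss.tolist()))
```

At a zero margin every β gives log 2 ≈ 0.693. As the margin grows, the β=0.5 row drops toward zero much sooner than the β=0.05 row.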

DPO vs RLHF Comparison

python
import pandas as pd

comparison = pd.DataFrame({
    'Aspect': [
        'Stages',
        'Models needed',
        'Training stability',
        'Implementation complexity',
        'Memory usage',
        'Training speed',
        'Performance',
        'Hyperparameter sensitivity'
    ],
    'RLHF (PPO)': [
        '3 (SFT, RM, PPO)',
        '4 (policy, ref, reward, value)',
        'Unstable',
        'High',
        'High (4 models)',
        'Slow',
        'Strong',
        'Very sensitive'
    ],
    'DPO': [
        '2 (SFT, DPO)',
        '2 (policy, ref)',
        'Stable',
        'Medium',
        'Medium (2 models)',
        'Fast',
        'Comparable',
        'Less sensitive'
    ]
})

print("\nRLHF vs DPO Comparison:")
print(comparison.to_string(index=False))

print("\n" + "="*70)
print("When to use each:")
print("="*70)
print("Use RLHF when:")
print("  - You need explicit reward signals")
print("  - You have complex reward shaping requirements")
print("  - You want to combine multiple reward sources")
print()
print("Use DPO when:")
print("  - You want simplicity and stability")
print("  - You have good preference data")
print("  - You want faster training")
print("  - You have limited compute resources")
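The memory gap can shrink further. Because π_ref is frozen and the preference pairs are fixed, the reference log-probs only need to be computed once, before training. After that, only the policy must stay in memory. The trainer above does not do this. The sketch below uses a hypothetical `ToyLM` (embedding plus linear head) in place of a real causal LM, and it arbitrarily treats even-indexed sequences as chosen and odd-indexed ones as rejected.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyLM(nn.Module):
    """Hypothetical stand-in for a causal LM: embedding + linear head, no attention."""
    def __init__(self, vocab: int = 16, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))  # (batch, seq, vocab) logits

def seq_log_probs(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    """Summed next-token log-probs per sequence (this toy data has no padding)."""
    logits = model(input_ids)[:, :-1, :]
    targets = input_ids[:, 1:]
    token_lp = torch.gather(F.log_softmax(logits, dim=-1), 2, targets.unsqueeze(-1)).squeeze(-1)
    return token_lp.sum(dim=1)

torch.manual_seed(0)
ref_model = ToyLM().eval()
policy = copy.deepcopy(ref_model).train()     # policy starts as a copy of the reference
sequences = torch.randint(0, 16, (10, 5))     # 10 hypothetical tokenized sequences

# One pass with the frozen reference model, then free it
with torch.no_grad():
    cached_ref_log_probs = seq_log_probs(ref_model, sequences)
del ref_model

# One DPO step that reads reference log-probs from the cache
beta = 0.1
chosen_idx, rejected_idx = torch.arange(0, 10, 2), torch.arange(1, 10, 2)
optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)

policy_lp = seq_log_probs(policy, sequences)
logits = beta * ((policy_lp[chosen_idx] - cached_ref_log_probs[chosen_idx])
                 - (policy_lp[rejected_idx] - cached_ref_log_probs[rejected_idx]))
loss = -F.logsigmoid(logits).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"initial DPO loss: {loss.item():.4f}")  # log 2, since policy == reference at start
```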

Summary

DPO simplifies alignment by:

  1. Eliminating the reward model: Optimize the policy directly from preference pairs
  2. Removing RL: A supervised, classification-style loss replaces PPO
  3. Maintaining performance: Reported to be comparable to PPO-based RLHF
  4. Improving stability: No sampling during training and fewer hyperparameters to tune

DPO and its variants have become widely used for preference-based alignment because of this simplicity and effectiveness.