Back
intermediate

T5: Text-to-Text Transfer Transformer

Understand T5's unified text-to-text framework that treats every NLP task as text generation, enabling powerful transfer learning

18 min read

T5: Text-to-Text Transfer Transformer

T5 (Text-to-Text Transfer Transformer) introduced a revolutionary paradigm: treat every NLP task as a text-to-text problem. Whether it's translation, summarization, classification, or question answering, the input is text and the output is text.

The Text-to-Text Framework

Text-to-Text Framework: A unified approach where every NLP task is formulated as converting input text to output text, allowing a single model architecture and training procedure to handle diverse tasks like translation, classification, and summarization.

Unified Paradigm

python
"""
Traditional Approach (task-specific architectures):
- Classification: Text → Class label
- Translation: Source text → Target text
- QA: Question + Context → Answer span
- Summarization: Document → Summary
(Each requires different output layers)

T5 Approach (unified text-to-text):
- Classification: "classify: [text]" → "positive"
- Translation: "translate English to German: [text]" → "[German text]"
- QA: "question: [q] context: [c]" → "[answer]"
- Summarization: "summarize: [document]" → "[summary]"
(Same architecture for all tasks!)
"""

class TextToTextExample:
    """Illustrate how T5 casts heterogeneous NLP tasks as text-to-text pairs."""

    @staticmethod
    def format_task(task_name, input_text, output_text=None):
        """Return the text-to-text formatting spec for a task.

        Unknown task names yield an empty dict; when `output_text` is not
        supplied, a canned example output for the task is used.
        """
        prefixes = {
            "translation": "translate English to French:",
            "summarization": "summarize:",
            "sentiment": "sentiment:",
            "cola": "cola sentence:",  # grammatical acceptability
            "stsb": "stsb sentence1: [s1] sentence2: [s2]",  # similarity, 0-5 scale
        }
        default_outputs = {
            "translation": "Bonjour, comment allez-vous?",
            "summarization": "[brief summary]",
            "sentiment": "positive",
            "cola": "acceptable",
            "stsb": "4.2",
        }

        if task_name not in prefixes:
            return {}

        return {
            "prefix": prefixes[task_name],
            "input": input_text,
            "output": output_text or default_outputs[task_name],
        }

# Examples of text-to-text formatting
def demonstrate_t5_tasks():
    """Print representative input/output pairs for T5-formatted NLP tasks."""

    examples = (
        ("Translation",
         "translate English to German: The house is wonderful.",
         "Das Haus ist wunderbar."),
        ("Summarization",
         "summarize: The tech industry has seen unprecedented growth...",
         "Tech industry shows strong growth in recent years."),
        ("Sentiment Classification",
         "sentiment: This movie was absolutely terrible!",
         "negative"),
        ("Question Answering",
         "question: What is the capital of France? context: Paris is the capital and most populous city of France.",
         "Paris"),
        ("Grammar Acceptability",
         "cola sentence: The book was reading by John.",
         "unacceptable"),
    )

    print("T5 Text-to-Text Task Formatting:\n")
    for name, task_input, task_output in examples:
        print(f"{name}:")
        print(f"  Input:  {task_input}")
        print(f"  Output: {task_output}\n")

demonstrate_t5_tasks()

Unified Architecture: T5's text-to-text framework means the same model weights and architecture can handle any task - you just need to format the input appropriately with task prefixes.

T5 Architecture

Model Design

python
"""
T5 Architecture Specifications:

T5 uses encoder-decoder transformer (like original Transformer)

Model Variants:
- T5-Small:  60M parameters  (6 layers each, d_model=512)
- T5-Base:   220M parameters (12 layers each, d_model=768)
- T5-Large:  770M parameters (24 layers each, d_model=1024)
- T5-3B:     3B parameters   (24 layers each, d_model=1024, d_ff=16384)
- T5-11B:    11B parameters  (24 layers each, d_model=1024, d_ff=65536)

Key architectural details:
- Relative position embeddings (not absolute)
- Simplified layer normalization (no bias, no subtraction)
- Uses GELU activation
- Pre-layer normalization
"""

import math

import torch
import torch.nn as nn

class T5RelativePositionBias(nn.Module):
    """
    T5's relative position bias instead of absolute position embeddings

    Instead of adding position embeddings to inputs, T5 adds a bias
    to the attention logits based on relative positions.
    """

    def __init__(self, num_heads, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance

        # Learned bias per (bucket, head), added directly to attention logits.
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """Map signed relative positions to bucket indices.

        Half of the buckets cover positive offsets, half negative; within
        each half, small distances get exact buckets and larger distances
        get logarithmically spaced ones (capped at `max_distance`).
        """
        # Split buckets between the two sign directions.
        num_buckets //= 2
        ret = (relative_position >= 0).to(torch.long) * num_buckets
        n = torch.abs(relative_position)

        # Exact positions for small distances
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Log-spaced positions for large distances.
        # NOTE: the denominator operands are Python scalars, so math.log must
        # be used — torch.log requires a Tensor and raised a TypeError here.
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) /
            math.log(max_distance / max_exact) *
            (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(
            val_if_large,
            torch.full_like(val_if_large, num_buckets - 1)
        )

        # For n == 0 the log branch is invalid (-inf), but torch.where
        # selects the exact-position branch for all small n.
        ret = ret + torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, query_length, key_length):
        """Compute the relative position bias, shaped [1, heads, q_len, k_len]."""
        # Pairwise relative offsets: key position minus query position.
        query_pos = torch.arange(query_length, dtype=torch.long)
        key_pos = torch.arange(key_length, dtype=torch.long)

        relative_position = key_pos[None, :] - query_pos[:, None]

        # Map offsets to bucket indices.
        buckets = self._relative_position_bucket(
            relative_position,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance
        )

        # Look up per-head bias values and move heads to the front.
        bias = self.relative_attention_bias(buckets)  # [q_len, k_len, heads]
        bias = bias.permute(2, 0, 1).unsqueeze(0)  # [1, heads, q_len, k_len]

        return bias

class T5LayerNorm(nn.Module):
    """RMS-style layer norm used by T5.

    Unlike standard LayerNorm there is no bias term and no mean
    subtraction — activations are only rescaled by their root-mean-square.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Normalize by the RMS over the last dimension, then apply the
        # learned per-feature scale.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return self.weight * normalized

class T5Attention(nn.Module):
    """T5 multi-head attention with optional relative position bias.

    NOTE: T5 deliberately omits the 1/sqrt(d_k) scaling of attention
    logits; this matches the reference model.
    """

    def __init__(self, d_model, num_heads, is_decoder=False, has_relative_bias=True):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_kv = d_model // num_heads
        self.is_decoder = is_decoder

        # T5 uses bias-free linear projections throughout.
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.o = nn.Linear(d_model, d_model, bias=False)

        # Only the first layer of a stack owns the bias table; later layers
        # reuse the bias tensor handed to forward() via `position_bias`.
        if has_relative_bias:
            self.relative_position_bias = T5RelativePositionBias(num_heads)
        else:
            self.relative_position_bias = None

    def forward(self, hidden_states, key_value_states=None, mask=None,
                position_bias=None):
        """Attend over `key_value_states` (self-attention when it is None).

        Returns (output, position_bias) so the computed bias can be shared
        with subsequent layers.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Self-attention draws keys/values from the query source itself.
        kv_source = key_value_states if key_value_states is not None else hidden_states

        def split_heads(projected):
            # [batch, len, d_model] -> [batch, heads, len, d_kv]
            return projected.view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)

        queries = split_heads(self.q(hidden_states))
        keys = split_heads(self.k(kv_source))
        values = split_heads(self.v(kv_source))

        # Raw (unscaled) attention logits.
        scores = torch.matmul(queries, keys.transpose(-2, -1))

        # Compute the bias locally only if this layer owns a bias table and
        # no shared bias was passed in.
        if position_bias is None and self.relative_position_bias is not None:
            position_bias = self.relative_position_bias(seq_len, keys.size(2))

        if position_bias is not None:
            scores = scores + position_bias

        # Positions where mask == 0 are excluded from attention.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, values)

        # [batch, heads, len, d_kv] -> [batch, len, d_model]
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.d_model)

        return self.o(context), position_bias

class T5Block(nn.Module):
    """One T5 encoder or decoder block (pre-norm residual sublayers)."""

    def __init__(self, d_model, num_heads, d_ff, is_decoder=False, dropout=0.1):
        super().__init__()
        self.is_decoder = is_decoder

        # Self-attention sublayer (owns the relative-bias table).
        self.self_attn = T5Attention(d_model, num_heads, is_decoder, has_relative_bias=True)
        self.self_attn_norm = T5LayerNorm(d_model)

        # Decoder blocks additionally attend over the encoder output.
        if is_decoder:
            self.cross_attn = T5Attention(d_model, num_heads, has_relative_bias=False)
            self.cross_attn_norm = T5LayerNorm(d_model)

        # Position-wise feed-forward sublayer.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=False),
            nn.Dropout(dropout),
        )
        self.ffn_norm = T5LayerNorm(d_model)

    def forward(self, hidden_states, encoder_hidden_states=None,
                self_attn_mask=None, cross_attn_mask=None, position_bias=None):
        """Run the block; returns (hidden_states, position_bias) for reuse."""

        # Pre-norm self-attention with residual connection.
        self_attn_out, position_bias = self.self_attn(
            self.self_attn_norm(hidden_states),
            mask=self_attn_mask,
            position_bias=position_bias,
        )
        hidden_states = hidden_states + self_attn_out

        # Cross-attention over encoder states (decoder only).
        if self.is_decoder and encoder_hidden_states is not None:
            cross_out, _ = self.cross_attn(
                self.cross_attn_norm(hidden_states),
                key_value_states=encoder_hidden_states,
                mask=cross_attn_mask,
            )
            hidden_states = hidden_states + cross_out

        # Pre-norm feed-forward with residual connection.
        hidden_states = hidden_states + self.ffn(self.ffn_norm(hidden_states))

        return hidden_states, position_bias

# Model instantiation
def create_t5_model(variant='base'):
    """Print and return the configuration dict for a T5 size variant.

    Args:
        variant: One of 'small', 'base', 'large'.

    Returns:
        Dict with keys d_model, num_layers, d_ff, num_heads.

    Raises:
        ValueError: If `variant` is not a known size (previously this
            surfaced as an uninformative KeyError).
    """

    configs = {
        'small': {'d_model': 512, 'num_layers': 6, 'd_ff': 2048, 'num_heads': 8},
        'base': {'d_model': 768, 'num_layers': 12, 'd_ff': 3072, 'num_heads': 12},
        'large': {'d_model': 1024, 'num_layers': 24, 'd_ff': 4096, 'num_heads': 16},
    }

    if variant not in configs:
        raise ValueError(
            f"Unknown T5 variant {variant!r}; choose from {sorted(configs)}"
        )

    config = configs[variant]
    print(f"T5-{variant.capitalize()} Configuration:")
    print(f"  Model dimension: {config['d_model']}")
    print(f"  Layers: {config['num_layers']}")
    print(f"  Feed-forward dimension: {config['d_ff']}")
    print(f"  Attention heads: {config['num_heads']}")

    return config

create_t5_model('base')

Relative Position Bias: Instead of adding position information to input embeddings, T5 adds learned biases to attention scores based on relative distances. This allows better generalization to sequences longer than those seen during training.

C4 Dataset

C4 (Colossal Clean Crawled Corpus): A massive 750GB dataset created from Common Crawl web data, cleaned using filters for language, quality, and appropriateness. The cleaned nature significantly improves model performance compared to raw web text.

T5's training data: the Colossal Clean Crawled Corpus.

Dataset Creation

python
"""
C4 Dataset (Colossal Clean Crawled Corpus):

Source: Common Crawl (web scraping project)
Size: ~750GB of cleaned English text

Cleaning Pipeline:
1. Language filtering (keep only English)
2. Remove duplicate lines
3. Remove sentences with "bad words"
4. Remove code (lines with curly braces)
5. Retain sentences ending with punctuation
6. Remove pages with < 5 sentences
7. Remove blacklisted websites

Result: High-quality, diverse web text for pre-training
"""

class C4DatasetProcessor:
    """Toy re-creation of the C4 cleaning pipeline (heavily simplified)."""

    def __init__(self):
        # Pages with fewer surviving sentences than this are dropped entirely.
        self.min_sentences = 5
        self.bad_words = {'badword1', 'badword2'}  # Simplified

    def is_english(self, text):
        """Language filter stub (the real pipeline uses a language detector)."""
        return True  # Placeholder

    def remove_duplicates(self, lines):
        """Drop repeated lines while preserving first-seen order."""
        seen = set()
        kept = []
        for line in lines:
            if line in seen:
                continue
            seen.add(line)
            kept.append(line)
        return kept

    def has_bad_words(self, text):
        """True when any whitespace-separated word is on the blocklist."""
        return not self.bad_words.isdisjoint(text.lower().split())

    def is_code(self, line):
        """Heuristic code detector: any curly brace marks the line as code."""
        return '{' in line or '}' in line

    def ends_with_punctuation(self, sentence):
        """True when the sentence ends with terminal punctuation."""
        return sentence.rstrip().endswith(('.', '!', '?'))

    def clean_page(self, text):
        """Run the full cleaning pipeline on one page.

        Returns the cleaned text, or None when too few sentences survive.
        """
        unique_lines = self.remove_duplicates(text.split('\n'))

        kept = []
        for raw_line in unique_lines:
            line = raw_line.strip()
            if not line:
                continue  # blank line
            if self.is_code(line):
                continue  # looks like code
            if self.has_bad_words(line):
                continue  # blocklisted content
            # Only sentences with terminal punctuation are retained.
            if self.ends_with_punctuation(line):
                kept.append(line)

        # Reject sparse pages outright.
        if len(kept) < self.min_sentences:
            return None

        return '\n'.join(kept)

# Demonstrate cleaning
def demonstrate_c4_cleaning():
    """Run the toy C4 pipeline on a sample page and print the outcome."""

    processor = C4DatasetProcessor()

    raw_text = """
    Welcome to our website!
    This is a great product
    Check out this code: function() { return true; }
    We offer the best services.
    Click here now!
    Buy now for only $9.99!
    """

    print("Raw text:")
    print(raw_text)
    print("\nAfter C4 cleaning:")

    # clean_page returns None when too few sentences survive filtering.
    cleaned = processor.clean_page(raw_text)
    print(cleaned if cleaned else "Page rejected (too few sentences)")

demonstrate_c4_cleaning()

# C4 statistics
c4_stats = """
C4 Dataset Statistics:

Size:
- ~750GB of text
- ~365 billion tokens
- ~156 billion words

Coverage:
- Sourced from April 2019 Common Crawl
- ~15 million domains
- Highly diverse topics and styles

Comparison to other datasets:
- Wikipedia: ~20GB
- BookCorpus: ~5GB
- C4: ~750GB (37x larger than Wikipedia)

Quality vs Scale:
- More data generally improves performance
- But data quality matters too
- C4 balances both: large AND clean
"""

print(c4_stats)

Data Quality Matters: T5 experiments showed that training on C4 (cleaned web text) significantly outperformed training on raw Common Crawl, demonstrating that data quality is as important as quantity.

Pre-training Objectives

Span Corruption: T5's pre-training objective where contiguous spans of tokens are replaced with sentinel tokens, and the model must predict the original content at each sentinel position. This is more effective than single-token masking for learning context.

T5 explored multiple pre-training approaches.

Objective Comparison

python
"""
T5 evaluated multiple pre-training objectives:

1. BERT-style (Fill in the blank):
   Input:  "Thank you for inviting me to your party last week."
   Target: "me to your party"

2. I.I.D. Denoising:
   Input:  "Thank you <M> inviting <M> to <M> party <M> week."
   Target: "<M> for <M> me <M> your <M> last <M>"

3. Replace Corrupted Spans (CHOSEN):
   Input:  "Thank you <X> to <Y> party <Z> week."
   Target: "<X> for inviting me <Y> your <Z> last <W>"  (ends with a fresh final sentinel)

Result: Replace spans worked best!
"""

class T5PretrainingObjective:
    """T5's span-corruption pre-training objective.

    Randomly selects spans of tokens, replaces each span in the input with
    a unique sentinel token, and builds a target mapping each sentinel back
    to its original span (terminated by one final sentinel).
    """

    def __init__(self, corruption_rate=0.15, mean_span_length=3):
        # Fraction of tokens to corrupt and the mean geometric span length.
        self.corruption_rate = corruption_rate
        self.mean_span_length = mean_span_length

    def corrupt_spans(self, text, vocab):
        """
        Corrupt spans of tokens for T5 pre-training.

        Args:
            text: Whitespace-separated token string.
            vocab: Unused placeholder (sentinel tokens are generated inline).

        Returns:
            corrupted_input: Input with corrupted spans replaced by sentinels.
            targets: Target sequence with sentinels and the original spans.

        Overlapping sampled spans are merged, so every original token appears
        exactly once in either the input or the target. (The previous
        per-span bookkeeping duplicated tokens and emitted redundant
        sentinels when sampled spans overlapped or shared a start index.)
        """
        import numpy as np

        tokens = text.split()
        num_tokens = len(tokens)

        # Number of tokens to corrupt (may be 0 for very short inputs).
        num_corrupt = int(num_tokens * self.corruption_rate)

        # Sample corrupted token indices; overlaps are harmless here because
        # spans are re-derived from the index set below.
        corrupted_indices = set()
        while len(corrupted_indices) < num_corrupt:
            start = np.random.randint(0, num_tokens)

            # Span length drawn from a geometric distribution with the
            # configured mean, clipped at the end of the sequence.
            length = np.random.geometric(1.0 / self.mean_span_length)
            length = min(length, num_tokens - start)

            corrupted_indices.update(range(start, start + length))

        # Collapse the corrupted indices into disjoint contiguous spans.
        spans = []
        for idx in sorted(corrupted_indices):
            if spans and idx == spans[-1][1]:
                spans[-1][1] = idx + 1  # extend the current run
            else:
                spans.append([idx, idx + 1])  # start a new run

        input_tokens = []
        target_tokens = []
        last_end = 0

        for sentinel_id, (start, end) in enumerate(spans):
            sentinel = f"<extra_id_{sentinel_id}>"

            # Uncorrupted prefix plus the sentinel go to the input.
            input_tokens.extend(tokens[last_end:start])
            input_tokens.append(sentinel)

            # The sentinel plus the original span go to the target.
            target_tokens.append(sentinel)
            target_tokens.extend(tokens[start:end])

            last_end = end

        # Trailing uncorrupted tokens, then the closing sentinel.
        input_tokens.extend(tokens[last_end:])
        target_tokens.append(f"<extra_id_{len(spans)}>")

        return ' '.join(input_tokens), ' '.join(target_tokens)

# Demonstrate span corruption
def demonstrate_span_corruption():
    """Print a few randomly corrupted versions of a sample sentence."""

    objective = T5PretrainingObjective(corruption_rate=0.15, mean_span_length=3)
    original = "Thank you for inviting me to your party last week"

    print("T5 Span Corruption Pre-training:\n")
    print(f"Original: {original}\n")

    # Each call resamples spans, so the three examples differ.
    for example_num in range(1, 4):
        corrupted_input, target = objective.corrupt_spans(original, vocab={})
        print(f"Example {example_num}:")
        print(f"  Input:  {corrupted_input}")
        print(f"  Target: {target}\n")

demonstrate_span_corruption()

# Objective comparison from T5 paper
objective_comparison = """
T5 Pre-training Objective Ablation Results:

Objective                   | GLUE Score | SQuAD EM
---------------------------|------------|----------
BERT-style (full mask)     | 83.2       | 80.1
Deshuffling               | 82.9       | 79.5
MASS (50% mask)           | 83.7       | 80.8
Replace Spans (T5)        | 84.1       | 81.3  ← Best!

Corruption Rate Ablation:
- 10%: Slightly worse
- 15%: Best ← T5 default
- 25%: Comparable
- 50%: Worse (too much corruption)

Mean Span Length:
- 1 token:  Similar to BERT
- 3 tokens: Best performance ← T5 default
- 10 tokens: Worse

Conclusion: Replace corrupted spans with 15% corruption
and mean span length of 3 tokens works best.
"""

print(objective_comparison)

Span Corruption Insight: Corrupting spans instead of individual tokens forces the model to understand longer-range context and dependencies, leading to better representations.

Transfer Learning with T5

Multi-task Training

python
"""
T5 Training Strategy:

1. Pre-training:
   - Unsupervised span corruption on C4
   - Learns general language understanding

2. Multi-task Fine-tuning (optional):
   - Train on mixture of supervised tasks
   - Format all as text-to-text
   - Improves generalization

3. Task-specific Fine-tuning:
   - Fine-tune on target task
   - Usually gives best performance
"""

class T5MultiTaskTraining:
    """T5 multi-task training setup: a task registry plus mixing weights."""

    def __init__(self):
        # Supervised tasks mixed during multi-task fine-tuning, with the
        # approximate size of each dataset.
        self.tasks = {
            'translation': {
                'prefix': 'translate English to French:',
                'dataset': 'WMT',
                'examples': 1000000
            },
            'summarization': {
                'prefix': 'summarize:',
                'dataset': 'CNN/DailyMail',
                'examples': 300000
            },
            'question_answering': {
                'prefix': 'question: {q} context: {c}',
                'dataset': 'SQuAD',
                'examples': 100000
            },
            'sentiment': {
                'prefix': 'sentiment:',
                'dataset': 'SST-2',
                'examples': 67000
            }
        }

    def create_task_mixture(self, mixing_strategy='proportional'):
        """
        Create mixture of tasks for multi-task training.

        Assigns a sampling 'weight' to every task and returns the registry.

        Strategies:
        - equal: Sample equally from each task
        - proportional: Sample proportional to dataset size
        - temperature: Use temperature to adjust sampling

        Raises:
            ValueError: For an unrecognized `mixing_strategy` (previously a
                silent no-op that left weights unset).
        """

        if mixing_strategy == 'equal':
            # Each task has equal probability
            uniform_weight = 1.0 / len(self.tasks)
            for task in self.tasks:
                self.tasks[task]['weight'] = uniform_weight

        elif mixing_strategy == 'proportional':
            # Sample proportional to dataset size
            total_examples = sum(t['examples'] for t in self.tasks.values())
            for task in self.tasks:
                self.tasks[task]['weight'] = (
                    self.tasks[task]['examples'] / total_examples
                )

        elif mixing_strategy == 'temperature':
            # Temperature > 1 flattens the size-based distribution so small
            # datasets are not drowned out.
            temperature = 2.0
            total = sum(t['examples'] ** (1/temperature) for t in self.tasks.values())

            for task in self.tasks:
                weight = (self.tasks[task]['examples'] ** (1/temperature)) / total
                self.tasks[task]['weight'] = weight

        else:
            raise ValueError(
                f"Unknown mixing strategy {mixing_strategy!r}; "
                "expected 'equal', 'proportional', or 'temperature'"
            )

        return self.tasks

# Demonstrate multi-task training
def demonstrate_multitask():
    """Print the task weights produced by each mixing strategy."""

    trainer = T5MultiTaskTraining()

    print("T5 Multi-task Training:\n")

    for strategy in ('equal', 'proportional', 'temperature'):
        print(f"{strategy.capitalize()} mixing:")
        for task_name, task_info in trainer.create_task_mixture(strategy).items():
            print(f"  {task_name}: {task_info['weight']:.3f}")
        print()

demonstrate_multitask()

Fine-tuning Examples

python
# Fine-tune T5 for specific tasks
def finetune_t5_summarization():
    """Fine-tune T5 for summarization"""
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained('t5-base')
    tokenizer = T5Tokenizer.from_pretrained('t5-base')

    # Paired (document, reference summary) training examples.
    training_pairs = [
        ("The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
         "The Eiffel Tower in Paris was designed by Gustave Eiffel."),
        ("Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed.",
         "Machine learning allows systems to learn from experience without explicit programming."),
    ]

    print("T5 Summarization Fine-tuning:\n")

    for doc, summary in training_pairs:
        # Prepend the task prefix so T5 knows which task to perform.
        input_ids = tokenizer(f"summarize: {doc}", return_tensors='pt',
                              max_length=512, truncation=True).input_ids
        labels = tokenizer(summary, return_tensors='pt',
                           max_length=128, truncation=True).input_ids

        # One forward pass with labels yields the training loss (simplified:
        # no optimizer step is taken here).
        outputs = model(input_ids=input_ids, labels=labels)

        print(f"Input length: {input_ids.shape[1]}")
        print(f"Target length: {labels.shape[1]}")
        print(f"Loss: {outputs.loss.item():.4f}\n")

    # Inference
    print("Inference Example:")
    test_doc = "Neural networks are computing systems inspired by biological neural networks. They learn from examples without being programmed with task-specific rules."

    input_ids = tokenizer(f"summarize: {test_doc}",
                          return_tensors='pt').input_ids

    # Beam-search decoding of the generated summary.
    generated = model.generate(input_ids, max_length=50, num_beams=4,
                               early_stopping=True)

    summary = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Document: {test_doc}")
    print(f"Summary: {summary}")

finetune_t5_summarization()

# Translation example
def t5_translation_example():
    """Use T5 for translation"""
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained('t5-base')
    tokenizer = T5Tokenizer.from_pretrained('t5-base')

    # English to German: the task prefix selects the language pair.
    prompt = "translate English to German: The house is wonderful."
    encoded = tokenizer(prompt, return_tensors='pt').input_ids

    generated = model.generate(encoded, max_length=40)
    translation = tokenizer.decode(generated[0], skip_special_tokens=True)

    print(f"\nT5 Translation:")
    print(f"English: The house is wonderful.")
    print(f"German: {translation}")

t5_translation_example()

T5 Variants

python
"""
T5 Model Family:

1. T5 (2019):
   - Original text-to-text model
   - 5 sizes: Small (60M) to 11B
   - Trained on C4

2. mT5 (Multilingual T5, 2020):
   - Supports 101 languages
   - Trained on mC4 (multilingual C4)
   - Same architecture, multilingual data

3. ByT5 (Byte-level T5, 2021):
   - Operates on raw bytes instead of subword tokens
   - No vocabulary needed
   - Better for multilingual and noisy text

4. Flan-T5 (2022):
   - Instruction-tuned T5
   - Trained on 1,800+ tasks formatted as instructions
   - Much better zero-shot and few-shot performance

5. UL2 (Unified Language Learner, 2022):
   - Unifies different denoising objectives
   - Better performance than T5
   - Later adapted to T5 architecture (Flan-UL2)
"""

# Model comparison
def compare_t5_variants():
    """Print a side-by-side spec comparison of the T5 family variants."""

    variants = {
        'T5-Base': {
            'params': '220M',
            'languages': '1 (English)',
            'tokenization': 'SentencePiece (32K vocab)',
            'use_case': 'General English NLP'
        },
        'mT5-Base': {
            'params': '580M',
            'languages': '101',
            'tokenization': 'SentencePiece (250K vocab)',
            'use_case': 'Multilingual tasks'
        },
        'ByT5-Base': {
            'params': '300M',
            'languages': 'All (byte-level)',
            'tokenization': 'Bytes (256 vocab)',
            'use_case': 'Noisy/multilingual text'
        },
        'Flan-T5-Base': {
            'params': '250M',
            'languages': '1 (English)',
            'tokenization': 'SentencePiece (32K vocab)',
            'use_case': 'Instruction following'
        }
    }

    print("T5 Variants Comparison:\n")
    for variant_name, spec in variants.items():
        print(f"{variant_name}:")
        for field, value in spec.items():
            print(f"  {field}: {value}")
        print()

compare_t5_variants()

# Flan-T5 instruction format
flan_t5_examples = """
Flan-T5 Instruction Format:

Instead of task prefixes, uses natural instructions:

Standard T5:
"translate English to French: Hello"

Flan-T5:
"Please translate the following sentence to French: Hello"
OR
"Convert this English text to French: Hello"
OR
"What is the French translation of: Hello"

Benefits:
- More natural interaction
- Better zero-shot generalization
- Instruction diversity improves robustness
- Can follow novel instructions not seen during training
"""

print(flan_t5_examples)

Flan-T5 Performance: Flan-T5 shows remarkable zero-shot performance on tasks it wasn't explicitly trained on, often matching or exceeding much larger models. This makes it highly practical for real-world applications.

Practice Exercise

python
# Exercise: Implement task-specific formatters for T5
class T5TaskFormatter:
    """Build (input, output) text pairs casting common NLP tasks into T5 format."""

    @staticmethod
    def format_classification(text, label=None, task_name="sentiment"):
        """Prefix the text with the task name; output is the label (if any)."""
        return f"{task_name}: {text}", (label if label else None)

    @staticmethod
    def format_ner(text, entities=None):
        """Serialize entities as 'TYPE: text' pairs joined by commas."""
        input_text = f"extract entities: {text}"

        if not entities:
            return input_text, None

        # Format: "PER: John, ORG: Google"
        output_text = ", ".join(
            f"{ent['type']}: {ent['text']}" for ent in entities
        )
        return input_text, output_text

    @staticmethod
    def format_qa(question, context, answer=None):
        """Concatenate question and context under T5's QA prefixes."""
        input_text = f"question: {question} context: {context}"
        return input_text, (answer if answer else None)

    @staticmethod
    def format_paraphrase(text, paraphrase=None):
        """Prefix with 'paraphrase:'; output is the reference paraphrase."""
        return f"paraphrase: {text}", (paraphrase if paraphrase else None)

# Demonstrate formatting
def demonstrate_task_formatting():
    """Show T5 task formatting"""

    formatter = T5TaskFormatter()

    print("T5 Task Formatting Examples:\n")

    def show(title, pair):
        # Shared printer for each task's (input, output) pair.
        input_text, output = pair
        print(f"{title}:")
        print(f"  Input:  {input_text}")
        print(f"  Output: {output}\n")

    # Classification
    show("Classification", formatter.format_classification(
        "This product is amazing!",
        label="positive",
        task_name="sentiment"
    ))

    # NER
    show("Named Entity Recognition", formatter.format_ner(
        "John Smith works at Google.",
        entities=[
            {"type": "PER", "text": "John Smith"},
            {"type": "ORG", "text": "Google"}
        ]
    ))

    # QA
    show("Question Answering", formatter.format_qa(
        question="What is the capital of France?",
        context="Paris is the capital and largest city of France.",
        answer="Paris"
    ))

demonstrate_task_formatting()

# Exercise questions
exercise_questions = """
Practice Exercises:

1. Why does T5 use relative position bias instead of absolute
   position embeddings? What advantage does this provide?

2. Design a text-to-text format for a table-to-text task
   (converting structured data to natural language).

3. Calculate: If T5-Base has 12 encoder and 12 decoder layers,
   and each layer has ~18M parameters, estimate total parameters.

4. Compare: When would you use T5 vs BERT vs GPT? List use cases
   for each architecture.

5. Implement: Create a corrupted span example with 20% corruption
   rate and mean span length of 2 tokens.
"""

print(exercise_questions)

Quiz

Further Reading