T5: Text-to-Text Transfer Transformer
T5 (Text-to-Text Transfer Transformer) introduced a revolutionary paradigm: treat every NLP task as a text-to-text problem. Whether it's translation, summarization, classification, or question answering, the input is text and the output is text.
The Text-to-Text Framework
Text-to-Text Framework: A unified approach where every NLP task is formulated as converting input text to output text, allowing a single model architecture and training procedure to handle diverse tasks like translation, classification, and summarization.
Unified Paradigm
"""
Traditional Approach (task-specific architectures):
- Classification: Text → Class label
- Translation: Source text → Target text
- QA: Question + Context → Answer span
- Summarization: Document → Summary
(Each requires different output layers)
T5 Approach (unified text-to-text):
- Classification: "classify: [text]" → "positive"
- Translation: "translate English to German: [text]" → "[German text]"
- QA: "question: [q] context: [c]" → "[answer]"
- Summarization: "summarize: [document]" → "[summary]"
(Same architecture for all tasks!)
"""
class TextToTextExample:
"""Demonstrate T5's text-to-text formulation"""
@staticmethod
def format_task(task_name, input_text, output_text=None):
"""Format different tasks as text-to-text"""
formats = {
"translation": {
"prefix": "translate English to French:",
"input": input_text,
"output": output_text or "Bonjour, comment allez-vous?"
},
"summarization": {
"prefix": "summarize:",
"input": input_text,
"output": output_text or "[brief summary]"
},
"sentiment": {
"prefix": "sentiment:",
"input": input_text,
"output": output_text or "positive"
},
"cola": { # Grammatical acceptability
"prefix": "cola sentence:",
"input": input_text,
"output": output_text or "acceptable"
},
"stsb": { # Semantic similarity (0-5 scale)
"prefix": "stsb sentence1: [s1] sentence2: [s2]",
"input": input_text,
"output": output_text or "4.2"
}
}
return formats.get(task_name, {})
# Examples of text-to-text formatting
def demonstrate_t5_tasks():
"""Show how T5 formats different NLP tasks"""
tasks = [
{
"task": "Translation",
"input": "translate English to German: The house is wonderful.",
"output": "Das Haus ist wunderbar."
},
{
"task": "Summarization",
"input": "summarize: The tech industry has seen unprecedented growth...",
"output": "Tech industry shows strong growth in recent years."
},
{
"task": "Sentiment Classification",
"input": "sentiment: This movie was absolutely terrible!",
"output": "negative"
},
{
"task": "Question Answering",
"input": "question: What is the capital of France? context: Paris is the capital and most populous city of France.",
"output": "Paris"
},
{
"task": "Grammar Acceptability",
"input": "cola sentence: The book was reading by John.",
"output": "unacceptable"
}
]
print("T5 Text-to-Text Task Formatting:\n")
for task in tasks:
print(f"{task['task']}:")
print(f" Input: {task['input']}")
print(f" Output: {task['output']}\n")
demonstrate_t5_tasks()
Unified Architecture: T5's text-to-text framework means the same model weights and architecture can handle any task - you just need to format the input appropriately with task prefixes.
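To make the point concrete, here is a minimal sketch using the public t5-small checkpoint from Hugging Face (this assumes the transformers and sentencepiece packages are installed): one set of weights serves two different tasks, distinguished only by the prefix.
# One checkpoint, two tasks: only the task prefix changes
# (illustrative sketch; assumes `transformers` + `sentencepiece` installed)
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained('t5-small')
tokenizer = T5Tokenizer.from_pretrained('t5-small')

prompts = [
    "translate English to German: The house is wonderful.",
    "summarize: The tech industry has seen unprecedented growth in cloud computing, with revenues rising every year.",
]
for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    output_ids = model.generate(input_ids, max_length=40)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))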
T5 Architecture
Model Design
"""
T5 Architecture Specifications:
T5 uses an encoder-decoder transformer (same high-level layout as the original Transformer)
Model Variants:
- T5-Small: 60M parameters (6 layers each, d_model=512)
- T5-Base: 220M parameters (12 layers each, d_model=768)
- T5-Large: 770M parameters (24 layers each, d_model=1024)
- T5-3B: 3B parameters (24 layers each, d_model=1024, d_ff=16384)
- T5-11B: 11B parameters (24 layers each, d_model=1024, d_ff=65536)
Key architectural details:
- Relative position bias added to attention logits (no absolute position embeddings)
- Simplified layer normalization (RMSNorm-style: no bias, no mean subtraction)
- ReLU activation in the feed-forward layers (T5 v1.1 switched to gated GELU)
- Pre-layer normalization (norm applied before each sub-layer)
"""
import math
import torch
import torch.nn as nn
class T5RelativePositionBias(nn.Module):
"""
T5's relative position bias instead of absolute position embeddings
Instead of adding position embeddings to inputs, T5 adds a bias
to the attention logits based on relative positions
"""
def __init__(self, num_heads, num_buckets=32, max_distance=128):
super().__init__()
self.num_heads = num_heads
self.num_buckets = num_buckets
self.max_distance = max_distance
# Relative attention bias table
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
"""Map relative positions to buckets"""
        # Half the buckets encode direction (keys before vs. after the query);
        # within each half, small distances get exact buckets and larger
        # distances share log-spaced buckets
num_buckets //= 2
ret = (relative_position >= 0).to(torch.long) * num_buckets
n = torch.abs(relative_position)
# Exact positions for small distances
max_exact = num_buckets // 2
is_small = n < max_exact
# Log-spaced positions for large distances
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) /
            math.log(max_distance / max_exact) *
            (num_buckets - max_exact)
        ).long()
val_if_large = torch.min(
val_if_large,
torch.full_like(val_if_large, num_buckets - 1)
)
ret = ret + torch.where(is_small, n, val_if_large)
return ret
def forward(self, query_length, key_length):
"""Compute relative position bias"""
# Create relative position matrix
query_pos = torch.arange(query_length, dtype=torch.long)
key_pos = torch.arange(key_length, dtype=torch.long)
relative_position = key_pos[None, :] - query_pos[:, None]
# Map to buckets
buckets = self._relative_position_bucket(
relative_position,
num_buckets=self.num_buckets,
max_distance=self.max_distance
)
# Get bias values
bias = self.relative_attention_bias(buckets) # [q_len, k_len, heads]
bias = bias.permute(2, 0, 1).unsqueeze(0) # [1, heads, q_len, k_len]
return bias
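As an illustrative check of the module above: the returned bias has shape [1, heads, query_len, key_len], and because distant positions fall into shared log-spaced buckets, the same table covers sequences longer than any seen during training.
# Illustrative check of the relative position bias module
bias_module = T5RelativePositionBias(num_heads=8)
bias = bias_module(query_length=10, key_length=10)
print(bias.shape)  # torch.Size([1, 8, 10, 10])

# The same table also covers much longer sequences, since large
# distances share log-spaced buckets rather than needing new entries
long_bias = bias_module(query_length=512, key_length=512)
print(long_bias.shape)  # torch.Size([1, 8, 512, 512])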
class T5LayerNorm(nn.Module):
"""
T5's simplified layer normalization
- No bias term
- No mean subtraction (only variance normalization)
"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x):
# Variance normalization only
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x
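A quick sanity check of this simplified norm: each position is rescaled to roughly unit root-mean-square (the weight starts at ones), and no mean is subtracted.
# Sanity check: T5LayerNorm normalizes the RMS, not mean and variance
x = torch.randn(2, 4, 16) + 5.0  # deliberately non-zero mean
norm = T5LayerNorm(hidden_size=16)
out = norm(x)
print(out.pow(2).mean(-1))   # ~1.0 everywhere
print(out.mean(-1))          # mean is NOT zero: no centering happens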
class T5Attention(nn.Module):
"""T5 multi-head attention with relative position bias"""
def __init__(self, d_model, num_heads, is_decoder=False, has_relative_bias=True):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_kv = d_model // num_heads
self.is_decoder = is_decoder
# Q, K, V projections
self.q = nn.Linear(d_model, d_model, bias=False)
self.k = nn.Linear(d_model, d_model, bias=False)
self.v = nn.Linear(d_model, d_model, bias=False)
self.o = nn.Linear(d_model, d_model, bias=False)
        # Relative position bias (in full T5, only the first layer of each
        # stack owns this table; later layers reuse the bias it returns
        # via the position_bias argument in forward)
if has_relative_bias:
self.relative_position_bias = T5RelativePositionBias(num_heads)
else:
self.relative_position_bias = None
def forward(self, hidden_states, key_value_states=None, mask=None,
position_bias=None):
batch_size, seq_len, _ = hidden_states.shape
# Self-attention or cross-attention
if key_value_states is None:
key_value_states = hidden_states
# Project Q, K, V
q = self.q(hidden_states).view(batch_size, -1, self.num_heads, self.d_kv)
k = self.k(key_value_states).view(batch_size, -1, self.num_heads, self.d_kv)
v = self.v(key_value_states).view(batch_size, -1, self.num_heads, self.d_kv)
# Transpose for attention: [batch, heads, seq, d_kv]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1))
# Add position bias
if position_bias is None and self.relative_position_bias is not None:
position_bias = self.relative_position_bias(seq_len, k.size(2))
if position_bias is not None:
scores = scores + position_bias
# Apply mask if provided
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax and apply to values
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
# Reshape and project
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.o(attn_output)
return output, position_bias
class T5Block(nn.Module):
"""T5 encoder/decoder block"""
def __init__(self, d_model, num_heads, d_ff, is_decoder=False, dropout=0.1):
super().__init__()
self.is_decoder = is_decoder
# Self-attention
self.self_attn = T5Attention(d_model, num_heads, is_decoder, has_relative_bias=True)
self.self_attn_norm = T5LayerNorm(d_model)
# Cross-attention (decoder only)
if is_decoder:
self.cross_attn = T5Attention(d_model, num_heads, has_relative_bias=False)
self.cross_attn_norm = T5LayerNorm(d_model)
# Feed-forward
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),  # original T5 uses ReLU (v1.1 uses gated GELU)
nn.Dropout(dropout),
nn.Linear(d_ff, d_model, bias=False),
nn.Dropout(dropout)
)
self.ffn_norm = T5LayerNorm(d_model)
def forward(self, hidden_states, encoder_hidden_states=None,
self_attn_mask=None, cross_attn_mask=None, position_bias=None):
# Self-attention with pre-normalization
normed_hidden = self.self_attn_norm(hidden_states)
attn_output, position_bias = self.self_attn(
normed_hidden, mask=self_attn_mask, position_bias=position_bias
)
hidden_states = hidden_states + attn_output
# Cross-attention (decoder only)
if self.is_decoder and encoder_hidden_states is not None:
normed_hidden = self.cross_attn_norm(hidden_states)
cross_output, _ = self.cross_attn(
normed_hidden,
key_value_states=encoder_hidden_states,
mask=cross_attn_mask
)
hidden_states = hidden_states + cross_output
# Feed-forward
normed_hidden = self.ffn_norm(hidden_states)
ffn_output = self.ffn(normed_hidden)
hidden_states = hidden_states + ffn_output
return hidden_states, position_bias
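The toy snippet below (illustrative sizes, encoder-only) shows how the bias sharing plays out in practice: only the first block computes the position bias, and every later block reuses the tensor it returned. (In this simplified code each block still allocates its own bias table; full T5 gives the table to the first layer only.)
# Sketch: a 3-block encoder stack sharing one position bias
d_model, num_heads, d_ff = 64, 4, 256
blocks = nn.ModuleList(
    [T5Block(d_model, num_heads, d_ff, is_decoder=False) for _ in range(3)]
)
hidden = torch.randn(2, 10, d_model)
position_bias = None  # first block fills this in; the rest reuse it
for block in blocks:
    hidden, position_bias = block(hidden, position_bias=position_bias)
print(hidden.shape)  # torch.Size([2, 10, 64])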
# Model instantiation
def create_t5_model(variant='base'):
"""Create T5 model of specified size"""
configs = {
'small': {'d_model': 512, 'num_layers': 6, 'd_ff': 2048, 'num_heads': 8},
'base': {'d_model': 768, 'num_layers': 12, 'd_ff': 3072, 'num_heads': 12},
'large': {'d_model': 1024, 'num_layers': 24, 'd_ff': 4096, 'num_heads': 16},
}
config = configs[variant]
print(f"T5-{variant.capitalize()} Configuration:")
print(f" Model dimension: {config['d_model']}")
print(f" Layers: {config['num_layers']}")
print(f" Feed-forward dimension: {config['d_ff']}")
print(f" Attention heads: {config['num_heads']}")
return config
create_t5_model('base')
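For a rough sense of where the parameters live, here is a back-of-the-envelope estimator built on the configs above. It counts only the attention and feed-forward matrices plus a shared 32,128-token embedding, ignoring layer norms and the relative position bias table, yet lands close to the reported 220M for T5-Base.
def estimate_t5_params(config, vocab_size=32128):
    """Rough parameter estimate (illustrative; ignores norms and bias tables)"""
    d, ff, L = config['d_model'], config['d_ff'], config['num_layers']
    attn = 4 * d * d                # Q, K, V, O projection matrices
    ffn = 2 * d * ff                # two feed-forward matrices
    encoder = L * (attn + ffn)      # self-attention + FFN per layer
    decoder = L * (2 * attn + ffn)  # decoder adds cross-attention
    embedding = vocab_size * d      # shared input/output embedding
    return encoder + decoder + embedding

print(f"Estimated T5-Base parameters: {estimate_t5_params(create_t5_model('base')):,}")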
Relative Position Bias: Instead of adding position information to input embeddings, T5 adds learned biases to attention scores based on relative distances. This allows better generalization to sequences longer than those seen during training.
C4 Dataset
C4 (Colossal Clean Crawled Corpus): A massive 750GB dataset created from Common Crawl web data, cleaned using filters for language, quality, and appropriateness. The cleaned nature significantly improves model performance compared to raw web text.
T5's training data: the Colossal Clean Crawled Corpus.
Dataset Creation
"""
C4 Dataset (Colossal Clean Crawled Corpus):
Source: Common Crawl (web scraping project)
Size: ~750GB of cleaned English text
Cleaning Pipeline:
1. Language filtering (keep only English, detected with langdetect)
2. Deduplicate repeated text (the paper deduplicates three-sentence spans)
3. Remove pages containing "bad words" (a public blocklist)
4. Remove code (lines with curly braces) and placeholder text like "lorem ipsum"
5. Retain only lines ending with terminal punctuation
6. Remove pages with fewer than 3 sentences (and lines with fewer than 5 words)
Result: High-quality, diverse web text for pre-training
"""
class C4DatasetProcessor:
"""Simulate C4 cleaning pipeline"""
def __init__(self):
        self.min_sentences = 3  # C4 discards pages with fewer than 3 sentences
self.bad_words = {'badword1', 'badword2'} # Simplified
def is_english(self, text):
"""Check if text is English (simplified)"""
# In reality, uses langdetect or similar
return True # Placeholder
def remove_duplicates(self, lines):
"""Remove duplicate lines"""
seen = set()
unique_lines = []
for line in lines:
if line not in seen:
seen.add(line)
unique_lines.append(line)
return unique_lines
def has_bad_words(self, text):
"""Check for inappropriate content"""
words = text.lower().split()
return any(word in self.bad_words for word in words)
def is_code(self, line):
"""Detect code-like content"""
# Simple heuristic: lines with { or }
return '{' in line or '}' in line
def ends_with_punctuation(self, sentence):
"""Check if sentence ends with terminal punctuation"""
return sentence.rstrip().endswith(('.', '!', '?'))
def clean_page(self, text):
"""Apply C4 cleaning pipeline to a web page"""
# Split into lines
lines = text.split('\n')
# Remove duplicates
lines = self.remove_duplicates(lines)
# Filter lines
clean_lines = []
for line in lines:
line = line.strip()
# Skip empty lines
if not line:
continue
# Skip code
if self.is_code(line):
continue
# Skip bad words
if self.has_bad_words(line):
continue
# Keep only sentences ending with punctuation
if self.ends_with_punctuation(line):
clean_lines.append(line)
# Reject pages with too few sentences
if len(clean_lines) < self.min_sentences:
return None
return '\n'.join(clean_lines)
# Demonstrate cleaning
def demonstrate_c4_cleaning():
"""Show C4 cleaning process"""
processor = C4DatasetProcessor()
raw_text = """
Welcome to our website!
This is a great product
Check out this code: function() { return true; }
We offer the best services.
Click here now!
Buy now for only $9.99!
"""
print("Raw text:")
print(raw_text)
print("\nAfter C4 cleaning:")
cleaned = processor.clean_page(raw_text)
if cleaned:
print(cleaned)
else:
print("Page rejected (too few sentences)")
demonstrate_c4_cleaning()
# C4 statistics
c4_stats = """
C4 Dataset Statistics:
Size:
- ~750GB of text
- ~365 million documents
- ~156 billion tokens
Coverage:
- Sourced from April 2019 Common Crawl
- ~15 million domains
- Highly diverse topics and styles
Comparison to other datasets:
- Wikipedia: ~20GB
- BookCorpus: ~5GB
- C4: ~750GB (37x larger than Wikipedia)
Quality vs Scale:
- More data generally improves performance
- But data quality matters too
- C4 balances both: large AND clean
"""
print(c4_stats)
Data Quality Matters: T5 experiments showed that training on C4 (cleaned web text) significantly outperformed training on raw Common Crawl, demonstrating that data quality is as important as quantity.
Pre-training Objectives
Span Corruption: T5's pre-training objective where contiguous spans of tokens are replaced with sentinel tokens, and the model must predict the original content at each sentinel position. This is more effective than single-token masking for learning context.
T5 explored multiple pre-training approaches.
Objective Comparison
"""
T5 evaluated multiple pre-training objectives:
1. BERT-style (mask tokens, reconstruct the full original):
Input: "Thank you <M> <M> me to your party <M> week."
Target: "Thank you for inviting me to your party last week."
2. I.I.D. Denoising (one sentinel per corrupted token):
Input: "Thank you <X> inviting <Y> to <Z> party <W> week."
Target: "<X> for <Y> me <Z> your <W> last"
3. Replace Corrupted Spans (CHOSEN: one sentinel per span):
Input: "Thank you <X> me to your party <Y> week."
Target: "<X> for inviting <Y> last <Z>"
Result: Replace spans worked best!
"""
class T5PretrainingObjective:
"""T5's span corruption pre-training"""
def __init__(self, corruption_rate=0.15, mean_span_length=3):
self.corruption_rate = corruption_rate
self.mean_span_length = mean_span_length
def corrupt_spans(self, text, vocab):
"""
Corrupt spans of tokens for T5 pre-training
Args:
text: List of tokens
vocab: Vocabulary (for sentinel tokens)
Returns:
corrupted_input: Input with corrupted spans replaced by sentinels
targets: Target sequence with sentinels and original spans
"""
import numpy as np
tokens = text.split()
num_tokens = len(tokens)
# Determine which tokens to corrupt
num_corrupt = int(num_tokens * self.corruption_rate)
        # Sample corrupted token indices, span by span
        corrupted_indices = set()
        while len(corrupted_indices) < num_corrupt:
            # Sample span start
            start = np.random.randint(0, num_tokens)
            # Sample span length (geometric distribution)
            length = np.random.geometric(1.0 / self.mean_span_length)
            length = min(length, num_tokens - start)
            corrupted_indices.update(range(start, start + length))
        # Merge sampled indices into contiguous, non-overlapping spans
        spans = []
        for idx in sorted(corrupted_indices):
            if spans and idx == spans[-1][1]:
                spans[-1][1] = idx + 1
            else:
                spans.append([idx, idx + 1])
        # Create input and target
        input_tokens = []
        target_tokens = []
        sentinel_id = 0
        last_end = 0
        for start, end in spans:
            # Add uncorrupted tokens before the span
            input_tokens.extend(tokens[last_end:start])
            # Replace the whole span with a single sentinel in the input
            sentinel = f"<extra_id_{sentinel_id}>"
            input_tokens.append(sentinel)
            # Target holds the sentinel followed by the original span
            target_tokens.append(sentinel)
            target_tokens.extend(tokens[start:end])
            sentinel_id += 1
            last_end = end
        # Add remaining uncorrupted tokens
        input_tokens.extend(tokens[last_end:])
        # Add final sentinel to mark the end of the target
        target_tokens.append(f"<extra_id_{sentinel_id}>")
        return ' '.join(input_tokens), ' '.join(target_tokens)
# Demonstrate span corruption
def demonstrate_span_corruption():
"""Show T5's span corruption objective"""
objective = T5PretrainingObjective(corruption_rate=0.15, mean_span_length=3)
original = "Thank you for inviting me to your party last week"
print("T5 Span Corruption Pre-training:\n")
print(f"Original: {original}\n")
# Generate multiple corrupted versions
for i in range(3):
corrupted_input, target = objective.corrupt_spans(original, vocab={})
print(f"Example {i+1}:")
print(f" Input: {corrupted_input}")
print(f" Target: {target}\n")
demonstrate_span_corruption()
# Objective comparison from T5 paper
objective_comparison = """
T5 Pre-training Objective Ablation Results:
Objective | GLUE Score | SQuAD EM
---------------------------|------------|----------
BERT-style (full mask) | 83.2 | 80.1
Deshuffling | 82.9 | 79.5
MASS (50% mask) | 83.7 | 80.8
Replace Spans (T5) | 84.1 | 81.3 ← Best!
Corruption Rate Ablation:
- 10%: slightly worse
- 15%: chosen default
- 25%: comparable to 15%
- 50%: clearly worse (too much corruption)
Mean Span Length:
- 1 token: Similar to BERT
- 3 tokens: Best performance ← T5 default
- 10 tokens: Worse
Conclusion: Replace corrupted spans with 15% corruption
and mean span length of 3 tokens works best.
"""
print(objective_comparison)
Span Corruption Insight: Corrupting spans instead of individual tokens forces the model to understand longer-range context and dependencies, leading to better representations.
Transfer Learning with T5
Multi-task Training
"""
T5 Training Strategy:
1. Pre-training:
- Unsupervised span corruption on C4
- Learns general language understanding
2. Multi-task Fine-tuning (optional):
- Train on mixture of supervised tasks
- Format all as text-to-text
- Improves generalization
3. Task-specific Fine-tuning:
- Fine-tune on target task
- Usually gives best performance
"""
class T5MultiTaskTraining:
"""T5 multi-task training setup"""
def __init__(self):
self.tasks = {
'translation': {
'prefix': 'translate English to French:',
'dataset': 'WMT',
'examples': 1000000
},
'summarization': {
'prefix': 'summarize:',
'dataset': 'CNN/DailyMail',
'examples': 300000
},
'question_answering': {
'prefix': 'question: {q} context: {c}',
'dataset': 'SQuAD',
'examples': 100000
},
'sentiment': {
'prefix': 'sentiment:',
'dataset': 'SST-2',
'examples': 67000
}
}
def create_task_mixture(self, mixing_strategy='proportional'):
"""
Create mixture of tasks for multi-task training
Strategies:
- equal: Sample equally from each task
- proportional: Sample proportional to dataset size
- temperature: Use temperature to adjust sampling
"""
if mixing_strategy == 'equal':
# Each task has equal probability
for task in self.tasks:
self.tasks[task]['weight'] = 1.0 / len(self.tasks)
elif mixing_strategy == 'proportional':
# Sample proportional to dataset size
total_examples = sum(t['examples'] for t in self.tasks.values())
for task in self.tasks:
self.tasks[task]['weight'] = (
self.tasks[task]['examples'] / total_examples
)
elif mixing_strategy == 'temperature':
# Use temperature to flatten/sharpen distribution
temperature = 2.0
total = sum(t['examples'] ** (1/temperature) for t in self.tasks.values())
for task in self.tasks:
weight = (self.tasks[task]['examples'] ** (1/temperature)) / total
self.tasks[task]['weight'] = weight
return self.tasks
# Demonstrate multi-task training
def demonstrate_multitask():
"""Show T5 multi-task training setup"""
trainer = T5MultiTaskTraining()
print("T5 Multi-task Training:\n")
for strategy in ['equal', 'proportional', 'temperature']:
print(f"{strategy.capitalize()} mixing:")
tasks = trainer.create_task_mixture(strategy)
for task_name, task_info in tasks.items():
print(f" {task_name}: {task_info['weight']:.3f}")
print()
demonstrate_multitask()
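Once the weights exist, one simple way to realize the mixture is to sample a task for each training example. The snippet below uses random.choices for illustration; the T5 paper itself used examples-proportional mixing with an artificial cap on very large datasets so they cannot drown out the small ones.
# Sample a task per training example according to the mixture weights
import random

tasks = T5MultiTaskTraining().create_task_mixture('temperature')
task_names = list(tasks)
weights = [tasks[name]['weight'] for name in task_names]
batch_tasks = random.choices(task_names, weights=weights, k=8)
print(batch_tasks)  # e.g. ['translation', 'sentiment', 'translation', ...]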
Fine-tuning Examples
# Fine-tune T5 for specific tasks
def finetune_t5_summarization():
"""Fine-tune T5 for summarization"""
from transformers import T5ForConditionalGeneration, T5Tokenizer
model = T5ForConditionalGeneration.from_pretrained('t5-base')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Example training data
documents = [
"The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
"Machine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed."
]
summaries = [
"The Eiffel Tower in Paris was designed by Gustave Eiffel.",
"Machine learning allows systems to learn from experience without explicit programming."
]
print("T5 Summarization Fine-tuning:\n")
for doc, summary in zip(documents, summaries):
# Format as text-to-text
input_text = f"summarize: {doc}"
# Tokenize
input_ids = tokenizer(input_text, return_tensors='pt',
max_length=512, truncation=True).input_ids
labels = tokenizer(summary, return_tensors='pt',
max_length=128, truncation=True).input_ids
# Training step (simplified)
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
print(f"Input length: {input_ids.shape[1]}")
print(f"Target length: {labels.shape[1]}")
print(f"Loss: {loss.item():.4f}\n")
# Inference
print("Inference Example:")
test_doc = "Neural networks are computing systems inspired by biological neural networks. They learn from examples without being programmed with task-specific rules."
input_ids = tokenizer(f"summarize: {test_doc}",
return_tensors='pt').input_ids
outputs = model.generate(input_ids, max_length=50, num_beams=4,
early_stopping=True)
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Document: {test_doc}")
print(f"Summary: {summary}")
finetune_t5_summarization()
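Note that the loop above computes a loss but never updates any weights. A minimal sketch of an actual update step follows; it uses AdamW and an illustrative learning rate for simplicity, whereas the T5 authors fine-tuned with the Adafactor optimizer.
# Minimal single training step (sketch: AdamW, illustrative learning rate)
def t5_training_step():
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained('t5-base')
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    input_ids = tokenizer("summarize: Neural networks learn from examples.",
                          return_tensors='pt').input_ids
    labels = tokenizer("Neural networks learn from data.",
                       return_tensors='pt').input_ids
    model.train()
    loss = model(input_ids=input_ids, labels=labels).loss
    loss.backward()        # compute gradients
    optimizer.step()       # update the weights
    optimizer.zero_grad()  # clear gradients for the next step
    print(f"Training loss: {loss.item():.4f}")

t5_training_step()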
# Translation example
def t5_translation_example():
"""Use T5 for translation"""
from transformers import T5ForConditionalGeneration, T5Tokenizer
model = T5ForConditionalGeneration.from_pretrained('t5-base')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# English to German
text = "translate English to German: The house is wonderful."
input_ids = tokenizer(text, return_tensors='pt').input_ids
outputs = model.generate(input_ids, max_length=40)
translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\nT5 Translation:")
print(f"English: The house is wonderful.")
print(f"German: {translation}")
t5_translation_example()
T5 Variants
"""
T5 Model Family:
1. T5 (2019):
- Original text-to-text model
- 5 sizes: Small (60M) to 11B
- Trained on C4
2. mT5 (Multilingual T5, 2020):
- Supports 101 languages
- Trained on mC4 (multilingual C4)
- Same architecture, multilingual data
3. ByT5 (Byte-level T5, 2021):
- Operates on raw bytes instead of subword tokens
- No vocabulary needed
- Better for multilingual and noisy text
4. Flan-T5 (2022):
- Instruction-tuned T5
- Trained on 1,800+ tasks formatted as instructions
- Much better zero-shot and few-shot performance
5. UL2 (Unified Language Learner, 2022):
- Unifies different denoising objectives
- Better performance than T5
- Later instruction-tuned and released as Flan-UL2
"""
# Model comparison
def compare_t5_variants():
"""Compare T5 model variants"""
variants = {
'T5-Base': {
'params': '220M',
'languages': '1 (English)',
'tokenization': 'SentencePiece (32K vocab)',
'use_case': 'General English NLP'
},
'mT5-Base': {
'params': '580M',
'languages': '101',
'tokenization': 'SentencePiece (250K vocab)',
'use_case': 'Multilingual tasks'
},
'ByT5-Base': {
'params': '300M',
'languages': 'All (byte-level)',
'tokenization': 'Bytes (256 vocab)',
'use_case': 'Noisy/multilingual text'
},
'Flan-T5-Base': {
'params': '250M',
'languages': '1 (English)',
'tokenization': 'SentencePiece (32K vocab)',
'use_case': 'Instruction following'
}
}
print("T5 Variants Comparison:\n")
for name, specs in variants.items():
print(f"{name}:")
for key, value in specs.items():
print(f" {key}: {value}")
print()
compare_t5_variants()
# Flan-T5 instruction format
flan_t5_examples = """
Flan-T5 Instruction Format:
Instead of task prefixes, uses natural instructions:
Standard T5:
"translate English to French: Hello"
Flan-T5:
"Please translate the following sentence to French: Hello"
OR
"Convert this English text to French: Hello"
OR
"What is the French translation of: Hello"
Benefits:
- More natural interaction
- Better zero-shot generalization
- Instruction diversity improves robustness
- Can follow novel instructions not seen during training
"""
print(flan_t5_examples)
Flan-T5 Performance: Flan-T5 shows remarkable zero-shot performance on tasks it wasn't explicitly trained on, often matching or exceeding much larger models. This makes it highly practical for real-world applications.
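As a quick illustration, the public google/flan-t5-base checkpoint can be queried zero-shot with a natural-language instruction (a sketch assuming transformers is installed; the exact answer depends on the checkpoint):
# Zero-shot instruction following with Flan-T5 (illustrative)
from transformers import T5ForConditionalGeneration, T5Tokenizer

flan_model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
flan_tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')

prompt = "Answer the following yes/no question. Can a penguin fly?"
input_ids = flan_tokenizer(prompt, return_tensors='pt').input_ids
output_ids = flan_model.generate(input_ids, max_length=20)
print(flan_tokenizer.decode(output_ids[0], skip_special_tokens=True))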
Practice Exercise
# Exercise: Implement task-specific formatters for T5
class T5TaskFormatter:
"""Format different NLP tasks for T5"""
@staticmethod
def format_classification(text, label=None, task_name="sentiment"):
"""Format classification task"""
input_text = f"{task_name}: {text}"
output_text = label if label else None
return input_text, output_text
@staticmethod
def format_ner(text, entities=None):
"""Format NER task"""
input_text = f"extract entities: {text}"
if entities:
# Format: "PER: John, ORG: Google"
entity_strings = [f"{ent['type']}: {ent['text']}"
for ent in entities]
output_text = ", ".join(entity_strings)
else:
output_text = None
return input_text, output_text
@staticmethod
def format_qa(question, context, answer=None):
"""Format QA task"""
input_text = f"question: {question} context: {context}"
output_text = answer if answer else None
return input_text, output_text
@staticmethod
def format_paraphrase(text, paraphrase=None):
"""Format paraphrase generation"""
input_text = f"paraphrase: {text}"
output_text = paraphrase if paraphrase else None
return input_text, output_text
# Demonstrate formatting
def demonstrate_task_formatting():
"""Show T5 task formatting"""
formatter = T5TaskFormatter()
print("T5 Task Formatting Examples:\n")
# Classification
input_text, output = formatter.format_classification(
"This product is amazing!",
label="positive",
task_name="sentiment"
)
print("Classification:")
print(f" Input: {input_text}")
print(f" Output: {output}\n")
# NER
entities = [
{"type": "PER", "text": "John Smith"},
{"type": "ORG", "text": "Google"}
]
input_text, output = formatter.format_ner(
"John Smith works at Google.",
entities=entities
)
print("Named Entity Recognition:")
print(f" Input: {input_text}")
print(f" Output: {output}\n")
# QA
input_text, output = formatter.format_qa(
question="What is the capital of France?",
context="Paris is the capital and largest city of France.",
answer="Paris"
)
print("Question Answering:")
print(f" Input: {input_text}")
print(f" Output: {output}\n")
demonstrate_task_formatting()
# Exercise questions
exercise_questions = """
Practice Exercises:
1. Why does T5 use relative position bias instead of absolute
position embeddings? What advantage does this provide?
2. Design a text-to-text format for a table-to-text task
(converting structured data to natural language).
3. Calculate: T5-Base has 12 encoder layers (~7M parameters each),
12 decoder layers (~9.5M each), and a ~25M shared embedding matrix.
Estimate the total parameter count and compare with the reported 220M.
4. Compare: When would you use T5 vs BERT vs GPT? List use cases
for each architecture.
5. Implement: Create a corrupted span example with 20% corruption
rate and mean span length of 2 tokens.
"""
print(exercise_questions)
Further Reading
- Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer (original T5 paper)
- mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer
- ByT5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models
- Scaling Instruction-Finetuned Language Models (Flan-T5)
- The C4 Dataset Documentation
- T5 Model Documentation (Hugging Face)